184 Commits

Author SHA1 Message Date
Alexey
c0357b2890 Merge pull request #149 from vladon/fix/ci-deprecated-actions-rs
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (rust) (push) Has been cancelled
Rust / Build (push) Has been cancelled
fix(ci): replace deprecated actions-rs/cargo with direct cross commands
2026-02-18 22:02:16 +03:00
Vladislav Yaroslavlev
4f7f7d6880 fix(ci): replace deprecated actions-rs/cargo with direct cross commands
The actions-rs organization has been archived and is no longer available.
Replace the deprecated action with direct cross installation and build commands.
2026-02-18 21:49:42 +03:00
Alexey
efba10f839 Update README.md 2026-02-18 21:34:04 +03:00
Alexey
6ba12f35d0 Update README.md 2026-02-18 21:31:58 +03:00
Alexey
6a57c23700 Update README.md 2026-02-18 20:56:03 +03:00
Alexey
94b85afbc5 Update Cargo.toml 2026-02-18 20:25:17 +03:00
Alexey
cf717032a1 Merge pull request #144 from telemt/flow
ME Polishing
2026-02-18 20:05:15 +03:00
Alexey
d905de2dad Nonce in Log only in DEBUG
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-18 20:02:43 +03:00
Alexey
c7bd1c98e7 Autofallback on ME-Init 2026-02-18 19:50:16 +03:00
Alexey
d3302d77d2 Blackmagics... 2026-02-18 19:49:19 +03:00
Alexey
df4494c37a New reroute algo + flush() optimized + new IPV6 Parser
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-18 19:08:27 +03:00
Alexey
b84189b21b Update ROADMAP.md 2026-02-18 19:04:39 +03:00
Alexey
9243661f56 Update ROADMAP.md 2026-02-18 18:59:54 +03:00
Alexey
bffe97b2b7 Merge pull request #143 from telemt/plannung
Create ROADMAP.md
2026-02-18 18:52:25 +03:00
Alexey
bee1dd97ee Create ROADMAP.md 2026-02-18 17:53:32 +03:00
Alexey
16670e36f5 Merge pull request #138 from LinFor/LinFor-patch-1
Just a very simple Grafana dashboard
2026-02-18 14:13:41 +03:00
Alexey
5dad663b25 Autobuild: merge pull request #123 from vladon/git-action-for-build-for-x86_64-and-aarch64
Add GitHub Actions release workflow for multi-platform builds
2026-02-18 13:43:04 +03:00
LinFor
8375608aaa Create grafana-dashboard.json
Just a simple Grafana dashboard
2026-02-18 12:26:40 +03:00
Vladislav Yaroslavlev
0057377ac6 Fix CodeQL warnings: add permissions and pin action versions 2026-02-18 11:38:20 +03:00
Alexey
078ed65a0e Update Cargo.toml 2026-02-18 06:38:01 +03:00
Alexey
9872f0ed1b Update Cargo.toml 2026-02-18 06:09:55 +03:00
Alexey
fb0cb54776 Merge pull request #133 from telemt/flow
New [network] section + ME Fixes + small bugs coverage
2026-02-18 06:09:36 +03:00
Alexey
67bae1cf2a [network] in upstream
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-18 06:02:24 +03:00
Alexey
eb9ac7fae4 ME Fixes
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-18 06:01:52 +03:00
Alexey
8046381939 [network] in main
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-18 06:01:08 +03:00
Alexey
650f9fd2a4 [network] in docs
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-18 06:00:21 +03:00
Alexey
d4ebc7b5c6 New [network]
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-18 05:59:58 +03:00
Alexey
7a4ccf8e82 Update Cargo.toml 2026-02-18 04:24:16 +03:00
Alexey
73b40d386a Merge pull request #121 from vladon/git-action-for-build-n-test-every-pr
Add GitHub Actions workflow for build and test on every PR
2026-02-17 21:03:52 +03:00
Vladislav Yaroslavlev
3206ce50bb add manual workflow run 2026-02-17 18:17:14 +03:00
Vladislav Yaroslavlev
bdccb866fe git action for build binaries 2026-02-17 17:59:59 +03:00
Vladislav Yaroslavlev
9b5b382593 dont fail on loop error 2026-02-17 17:00:17 +03:00
Vladislav Yaroslavlev
9886c9a8e7 use -W warnings for clippy 2026-02-17 16:41:38 +03:00
Vladislav Yaroslavlev
cb3d32cc89 comment -D warnings for clippy 2026-02-17 16:35:03 +03:00
Vladislav Yaroslavlev
010eb5270f add git action to build and test every PR 2026-02-17 16:17:30 +03:00
Alexey
e33092530d Merge pull request #117 from vladon/update-cargo-lock
chore: update Cargo.lock with latest dependencies
2026-02-17 15:19:19 +03:00
Alexey
e7d649b57f Merge pull request #116 from An0nX/patch-1
feat: production system prompt — scope control, structured output, decision process
2026-02-17 14:17:28 +03:00
Vladislav Yaroslavlev
5f3d089003 chore: update Cargo.lock with latest dependencies
- Add h2 0.4.13 dependency
- Add httpdate 1.0.3 dependency
- Update hyper to include h2 and httpdate features
- Update tokio-util with additional futures and hashbrown dependencies
2026-02-17 12:49:02 +03:00
An0nX
4322509657 feat: rewrite system prompt with scope control, response format, and decision process
Rewrite the system prompt for production Rust codebase assistance.

Key changes:
- Add Priority Resolution (Section 0) implementing "Boy Scout Rule" with
  explicit scope control: coordinated style fixes are always in scope,
  architectural changes require explicit approval
- Add role definition as senior Rust systems engineer with strict code
  review responsibilities
- Rewrite negative constraints ("DO NOT") as positive instructions
  throughout all sections for better model adherence
- Add structured decision process for complex changes (Section 8):
  clarify → assess → propose → implement → verify
- Add context awareness rules (Section 9) for partial code handling
- Add mandatory response format (Section 10) with two-section structure:
  Reasoning (Russian) and Changes (English code)
- Add language policy: code/comments/commits in English,
  reasoning in Russian
- Add out-of-scope observations reporting mechanism — model reports
  issues it finds but is not allowed to fix
- Add splitting protocol for responses exceeding output limits
- Add file size thresholds for full-file vs contextual-diff responses
  (200 lines boundary)
- Preserve permission for todo!() and unimplemented!() as idiomatic
  Rust markers
- Preserve all existing rules: file size limits, formatting preservation,
  warning/dead-code protection, architectural integrity, git discipline
2026-02-17 12:42:03 +03:00
Alexey
43990c9dc9 Merge pull request #113 from telemt/me-fixes
Me fixes
2026-02-17 04:26:20 +03:00
Alexey
c03db683a5 Improved perf for ME
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-17 04:16:16 +03:00
Alexey
168fd59187 Fixed critical ME Problems 2026-02-17 03:40:39 +03:00
Alexey
8bd02d8099 Merge pull request #111 from VeryBigSad/feat/metrics-endpoint
Add Prometheus /metrics HTTP endpoint
2026-02-17 01:39:29 +03:00
Mikhail
a1db082ec0 Add Prometheus /metrics HTTP endpoint
Wire up unused metrics_port/metrics_whitelist config into working
HTTP server exposing proxy stats in Prometheus text format.
2026-02-17 01:24:49 +03:00
Alexey
9b9c11e7ab Merge pull request #110 from telemt/neurosl0pe
Create AGENTS_SYSTEM_PROMT.md
2026-02-16 23:41:59 +03:00
Alexey
274b9d5e94 Update AGENTS_SYSTEM_PROMT.md 2026-02-16 23:34:52 +03:00
Alexey
d888df6382 Update AGENTS.md 2026-02-16 23:33:09 +03:00
Alexey
011b9a3cbf Create AGENTS_SYSTEM_PROMT.md 2026-02-16 23:30:46 +03:00
Alexey
d67a587f3d Merge pull request #106 from vladon/docs/update-announce-readme
docs: update README with new 'announce' parameter
2026-02-16 22:33:25 +03:00
Vladislav Yaroslavlev
478fc5dd89 docs: update README with new 'announce' parameter
Replace deprecated 'announce_ip' example with new 'announce' parameter
that supports both hostnames and IP addresses.
2026-02-16 18:51:21 +03:00
Alexey
a0e7210dff Merge pull request #100 from vladon/feature/announce-hostname
feat: extend announce_ip to accept hostnames
2026-02-16 17:36:22 +03:00
vladon
16b5dc56f0 feat: extend announce_ip to accept hostnames
Add new 'announce' field to ListenerConfig that accepts both IP addresses
and hostnames for proxy link generation. The old 'announce_ip' field is
deprecated but still supported via automatic migration.

Changes:
- Add 'announce: Option<String>' field to ListenerConfig
- Add migration logic: announce_ip → announce if announce not set
- Update main.rs to use announce field for link generation
- Support both hostnames (e.g., 'proxy.example.com') and IPs

Backward compatible: existing configs using announce_ip continue to work.
2026-02-16 17:26:46 +03:00
vladon
303a6896bf AGENTS.md 2026-02-16 16:59:29 +03:00
Alexey
9e84528801 Update main.rs 2026-02-16 15:48:22 +03:00
Alexey
685c228190 Update main.rs 2026-02-16 15:16:26 +03:00
Alexey
febe4d1ac0 Merge pull request #98 from telemt/me-ping
ME Ping in log
2026-02-16 12:25:25 +03:00
Alexey
e4f90cd7c1 ME Ping in log 2026-02-16 12:10:59 +03:00
Alexey
3013291ea0 Merge pull request #97 from AndreyAkifev/main
Fix ME relay HOL and reduce per-frame flush overhead
2026-02-16 10:29:40 +03:00
Alexey
5d1dce7989 Merge pull request #95 from Katze-942/main-fix
Fix: public_host/public_port + unix socket
2026-02-16 10:28:35 +03:00
AndreyAkifev
864f7fa9a5 Merge branch 'telemt:main' into main 2026-02-16 08:51:26 +03:00
Andrey Akifev
e54fb3fffc Reduce per-frame flush overhead 2026-02-16 12:49:49 +07:00
Andrey Akifev
dddf9f30dc Fix HOL 2026-02-16 12:49:16 +07:00
Жора Змейкин
3091b5168f Fix: public_host/public_port + unix socket 2026-02-16 04:22:26 +03:00
Alexey
ddc91c2d66 Merge pull request #93 from sou1jacker/main
Fix "Read-only file system" and "Permission denied" errors for proxy-secret cache
2026-02-16 02:49:25 +03:00
Артур
8072a97f7e Modify docker-compose for tmpfs
Updated volume path for config.toml and added tmpfs configuration.
2026-02-16 02:03:11 +03:00
Alexey
558155ffaa Merge pull request #92 from An0nX/patch-1
Refactor dc.py: OOP architecture, strict typing, dataclass model
2026-02-16 00:49:39 +03:00
An0nX
ed329c2075 refactor: rewrite dc.py with OOP, strict typing, and dataclass model
- Replace procedural logic with TelegramDCChecker class
- Introduce frozen DCServer dataclass with slots for DC option parsing
- Add full type hints
- Add docstrings to all classes and methods
- Use itertools.groupby for DC grouping instead of manual dict building
- Use pathlib.Path for file output
2026-02-16 00:38:13 +03:00
Alexey
305c088bb7 Grabbing unknown dc into unknown-dc.txt 2026-02-15 23:59:53 +03:00
Alexey
debdbfd73c Ping for [dc_overrides]
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-15 23:46:49 +03:00
Alexey
904c17c1b3 DC=203 by default + IP Autodetect by STUN
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-15 23:30:21 +03:00
artemws
4a80bc8988 Refactor connectivity logging for upstream results 2026-02-15 22:33:25 +03:00
Alexey
f9c41ab703 Update rust.yml 2026-02-15 19:32:29 +03:00
Alexey
2112ba22f1 Update rust.yml 2026-02-15 19:31:23 +03:00
Alexey
fbe9277f86 Update README.md 2026-02-15 18:12:37 +03:00
Alexey
d1348e809f Update README.md 2026-02-15 18:09:54 +03:00
Alexey
533613886a Update README.md 2026-02-15 17:34:47 +03:00
Alexey
84f8b786e7 Update README.md 2026-02-15 17:29:52 +03:00
artemws
32bc3e1387 Refactor client handshake handling for clarity 2026-02-15 16:30:41 +03:00
artemws
0fa5914501 Add Unix socket listener support 2026-02-15 16:30:41 +03:00
Alexey
9b790c7bf4 Update README.md 2026-02-15 15:48:42 +03:00
Alexey
eda365c21f Update README.md 2026-02-15 15:46:24 +03:00
Alexey
8de1318c9c Update README.md 2026-02-15 15:35:44 +03:00
Alexey
7e566fd655 Update README.md 2026-02-15 14:46:15 +03:00
Alexey
a80db2ddbc Merge pull request #81 from telemt/3.0.0
3.0.0 Anschluss
2026-02-15 14:18:44 +03:00
Alexey
0694183ca6 Num_bigint + Num_traits Fix 2026-02-15 14:15:56 +03:00
Alexey
1f9fb29a9b Update config.toml
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-15 14:07:16 +03:00
Alexey
eccc69b79c Merge branch '3.0.0' of https://github.com/telemt/telemt into 3.0.0 2026-02-15 14:02:15 +03:00
Alexey
da108b2d8c Middle Proxy läuft wie auf Schienen...
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-15 14:02:00 +03:00
Alexey
9d94f55cdc Update Cargo.toml 2026-02-15 13:20:19 +03:00
Alexey
94a7058cc6 Middle Proxy Minimal
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-15 13:14:50 +03:00
Alexey
3d2e996cea Delete telemt
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-15 12:35:23 +03:00
Alexey
f2455c9cb1 Middle-End Drafts
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-15 12:30:40 +03:00
Alexey
427c7dd375 Deprecated failed KDF
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-15 12:29:34 +03:00
Alexey
e911a21a93 New hash in tests
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-15 12:29:08 +03:00
Alexey
edabad87d7 Merge pull request #78 from artemws/main
Disable color logs
2026-02-15 11:28:40 +03:00
artemws
2a65d29e3b Configure color output based on user settings
Added conditional color output configuration for logging.
2026-02-15 10:12:56 +02:00
artemws
c837a9b0c6 Add disable_colors field to GeneralConfig
Add option to disable colored output in logs
2026-02-15 10:12:33 +02:00
Alexey
f7618416b6 Merge pull request #77 from telemt/revert-68-unix-socket
Revert "Unix socket listener + reverse proxy improvements"
2026-02-15 10:09:13 +03:00
Alexey
0663e71c52 Revert "Unix socket listener + reverse proxy improvements" 2026-02-15 10:09:03 +03:00
Alexey
0599a6ec8c Merge pull request #76 from telemt/revert-72-main-fix
Revert "Main fix"
2026-02-15 10:08:34 +03:00
Alexey
b2d36aac19 Revert "Main fix" 2026-02-15 10:08:20 +03:00
Alexey
3d88ec5992 Merge pull request #74 from telemt/codeql-tuning
Update codeql.yml
2026-02-15 03:36:53 +03:00
Alexey
a693ed1e33 Merge pull request #72 from telemt/main-fix
Main fix
2026-02-15 03:36:25 +03:00
Alexey
911a504e16 Update main.rs 2026-02-15 03:34:24 +03:00
Alexey
56cd0cd1a9 Update client.rs 2026-02-15 03:27:53 +03:00
Alexey
358ad65d5f Update client.rs 2026-02-15 03:24:20 +03:00
Alexey
2f5df6ade0 Update codeql.yml 2026-02-15 03:20:19 +03:00
Alexey
e3b7be81e7 Update main.rs 2026-02-15 03:18:40 +03:00
Alexey
9a25e8e810 Update client.rs 2026-02-15 03:17:45 +03:00
Alexey
1a6b39b829 Merge pull request #68 from Katze-942/unix-socket
Unix socket listener + reverse proxy improvements
2026-02-15 02:48:39 +03:00
Alexey
a419cbbcf3 Merge branch 'main' into unix-socket 2026-02-15 02:48:24 +03:00
Alexey
b97ea1293b Merge pull request #69 from artemws/main
Unique IP address restrict for users
2026-02-15 00:24:20 +03:00
artemws
5f54eb8270 Comment out user_max_unique_ips setting
Comment out user_max_unique_ips configuration
2026-02-14 23:04:15 +02:00
artemws
06161abbbc Implement IP tracking and user limit checks
Added IP tracking and cleanup functionality for users.
2026-02-14 23:02:16 +02:00
artemws
aee549f745 Integrate IP Tracker for user IP management
Added UserIpTracker for managing user IP limits.
2026-02-14 23:01:43 +02:00
artemws
50ec753c05 Add user_max_unique_ips to configuration 2026-02-14 23:01:09 +02:00
artemws
cf34c7e75c Add files via upload 2026-02-14 23:00:26 +02:00
Жора Змейкин
572e07a7fd Unix socket listener + reverse proxy improvements 2026-02-14 23:29:39 +03:00
Alexey
4b5270137b Merge pull request #67 from telemt/main-dc-overrides
Bumped version + DC Overrides
2026-02-14 22:47:33 +03:00
Alexey
246230c924 Bumped version + DC Overrides 2026-02-14 22:46:00 +03:00
Alexey
21416af153 Merge pull request #66 from telemt/2.0.0.0-build
2.0.0.0 Build, Closing Branch
2026-02-14 22:34:13 +03:00
Alexey
b03312fa2e Merge pull request #65 from telemt/2.0.0.0-h
2.0.0.1
2026-02-14 22:20:43 +03:00
Alexey
bcdbf033b2 Delete middle_proxy.rs 2026-02-14 22:15:41 +03:00
Alexey
0a054c4a01 Find DC Method in Python
Co-Authored-By: artemws <59208085+artemws@users.noreply.github.com>
2026-02-14 21:55:29 +03:00
Alexey
eae7ad43d9 Merge pull request #63 from telemt/main-emergency
Update README.md
2026-02-14 20:40:03 +03:00
Alexey
0894ef0089 Update README.md 2026-02-14 20:39:34 +03:00
Alexey
954916960b Merge pull request #62 from telemt/main-emergency
Update README.md
2026-02-14 20:36:23 +03:00
Alexey
91d16b96ee Update README.md 2026-02-14 20:35:54 +03:00
Alexey
4bbadbc764 Merge pull request #41 from vmax/feature/show-all-links
feature: support show_links = "*"
2026-02-14 18:29:05 +03:00
Alexey
e4272ac35c Merge pull request #44 from telemt/dependabot/cargo/lru-0.16.3
Bump lru from 0.12.5 to 0.16.3
2026-02-14 13:26:34 +03:00
Alexey
7f8cde8317 NAT + STUN Probes... 2026-02-14 12:44:20 +03:00
Alexey
46ee91c6b7 File descriptor limits for systemd: merge pull request #57 from sou1jacker/main
"Too many open files" - add file descriptor limits for systemd & Docker (fixes telemt#56)
2026-02-14 12:37:31 +03:00
Alexey
e32d8e6c7d ME Diagnostics 2026-02-14 04:19:44 +03:00
Артур
ad553f8fbb docs: add ulimits to docker-compose.yml (fixes #56) 2026-02-14 01:59:30 +03:00
Alexey
d405756b94 HOL Minimized + Random conn_id + Target DC Magics
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-14 01:52:49 +03:00
Артур
c0b4129209 docs: add file descriptor limits for systemd and Docker (fixes #56) 2026-02-14 01:51:29 +03:00
Alexey
a8c3128c50 Middle Proxy Magics
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-14 01:51:10 +03:00
Alexey
70859aa5cf Middle Proxy is so real 2026-02-14 01:36:14 +03:00
Max Vorobev
fc47e4d584 feature: support show_links = "*" 2026-02-14 01:02:47 +03:00
Alexey
9b850b0bfb IP Version Superfallback 2026-02-14 00:30:09 +03:00
Alexey
32b16439c8 Merge pull request #55 from telemt/katze-942-ipv6
Update config.toml
2026-02-13 23:47:38 +03:00
Alexey
fd27449a26 Update config.toml 2026-02-13 23:47:26 +03:00
Alexey
3d13301711 Added Docker support, updated README.md: merge pull request #54 from sou1jacker/main
Added Docker support, updated README.md
2026-02-13 21:37:37 +03:00
sou1jacker
963ec7206b Added Docker support, updated README.md 2026-02-13 21:19:23 +03:00
Alexey
de28655dd2 Middle Proxy Fixes
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-13 16:09:33 +03:00
Alexey
e62b41ae64 RPC Flags Fixes
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-13 14:28:47 +03:00
Alexey
f1c1f42de8 Key derivation + me_health_monitor + QuickACK
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-13 12:51:49 +03:00
Alexey
a494dfa9eb Middle Proxy Drafts
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-13 03:51:36 +03:00
Alexey
9047511256 Merge pull request #46 from telemt/codeql-tuning
CodeQL Fixes
2026-02-13 03:40:55 +03:00
Alexey
4ba907fdcd CodeQL Fixes 2026-02-13 03:39:59 +03:00
Alexey
dae19c29a0 Merge pull request #45 from telemt/codeql-tuning-1
Update codeql-config.yml
2026-02-13 03:37:09 +03:00
Alexey
25530c8c44 Update codeql-config.yml 2026-02-13 03:36:51 +03:00
dependabot[bot]
aee44d3af2 Bump lru from 0.12.5 to 0.16.3
Bumps [lru](https://github.com/jeromefroe/lru-rs) from 0.12.5 to 0.16.3.
- [Changelog](https://github.com/jeromefroe/lru-rs/blob/master/CHANGELOG.md)
- [Commits](https://github.com/jeromefroe/lru-rs/compare/0.12.5...0.16.3)

---
updated-dependencies:
- dependency-name: lru
  dependency-version: 0.16.3
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-02-13 00:31:52 +00:00
Alexey
714d83bea1 Merge pull request #43 from telemt/codeql-tuning
Updated codeql-config.yml
2026-02-13 03:11:21 +03:00
Alexey
e1bfe69b76 Updated codeql-config.yml 2026-02-13 03:11:02 +03:00
Alexey
e6bf7ac40e Merge pull request #42 from telemt/codeql-tuning
Codeql tuning
2026-02-13 03:02:08 +03:00
Alexey
889a5fa19b Add mask_unix_sock for [censorship] masking: merge pull request #33 from Katze-942/main
Add mask_unix_sock for [censorship] masking
2026-02-12 21:30:51 +03:00
Жора Змейкин
d8ff958481 Add mask_unix_sock for censorship masking via Unix socket 2026-02-12 21:11:20 +03:00
Alexey
28ee74787b Merge pull request #36 from telemt/1.2.0.3
New Relay on Tokio Copy Bidirectional
2026-02-12 20:34:35 +03:00
Alexey
a688bfe22f New Relay on Tokio Copy Bidirectional 2026-02-12 20:20:01 +03:00
Alexey
91eea914b3 Update codeql.yml 2026-02-12 19:00:12 +03:00
Alexey
3ba97a08fa Update codeql.yml 2026-02-12 18:58:42 +03:00
Alexey
6e445be108 CodeQL Tuning 2026-02-12 18:58:03 +03:00
Alexey
3c6752644a Create codeql.yml 2026-02-12 18:56:08 +03:00
Alexey
9bd12f6acb 1.2.0.2 Special DC support: merge pull request #32 from telemt/1.2.0.2
1.2.0.2 Special DC support
2026-02-12 18:46:40 +03:00
Alexey
61581203c4 Semaphore + Async Magics for Defcluster
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-12 18:38:05 +03:00
Alexey
84668e671e Default Cluster Drafts
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-12 18:25:41 +03:00
Alexey
5bde202866 Startup logging refactoring: merge pull request #26 from Katze-942/main
Startup logging refactoring
2026-02-12 11:46:22 +03:00
Жора Змейкин
9304d5256a Refactor startup logging
Move all startup output (DC pings, proxy links) from println!() to
      info!() for consistent tracing format. Add reload::Layer so startup
      messages stay visible even in silent mode.
2026-02-12 05:14:23 +03:00
Alexey
364bc6e278 Merge pull request #21 from telemt/1.2.0.0
1.2.0.0
2026-02-11 17:00:46 +03:00
Alexey
e83db704b7 Pull-up 2026-02-11 16:55:18 +03:00
Alexey
acf90043eb Merge pull request #15 from telemt/main-emergency
Update README.md
2026-02-11 00:56:12 +03:00
Alexey
0011e20653 Update README.md 2026-02-11 00:55:27 +03:00
Alexey
41fb307858 Merge pull request #14 from telemt/main-emergency
Update README.md
2026-02-11 00:41:30 +03:00
Alexey
6a78c44d2e Update README.md 2026-02-11 00:41:08 +03:00
Alexey
be9c9858ac Merge pull request #13 from telemt/main-emergency
Main emergency
2026-02-11 00:39:45 +03:00
Alexey
2fa8d85b4c Update README.md 2026-02-11 00:31:45 +03:00
Alexey
310666fd44 Update README.md 2026-02-11 00:31:02 +03:00
Alexey
6cafee153a Fire-and-Forgot™ Draft
- Added fire-and-forget ignition via `--init` CLI command:
  - New `mod cli;` module handling installation logic
  - Extended `parse_cli()` to process `--init` flag (runs synchronously before tokio runtime)
  - Expanded `--help` output with installation options

- `--init` command functionality:
  - Generates random secret if not provided via `--secret`
  - Creates `/etc/telemt/config.toml` from template with user-provided or default parameters (`--port`, `--domain`, `--user`, `--config-dir`)
  - Creates hardened systemd unit `/etc/systemd/system/telemt.service` with security features:
    - `NoNewPrivileges=true`
    - `ProtectSystem=strict`
    - `PrivateTmp=true`
  - Runs `systemctl enable --now telemt.service`
  - Outputs `tg://` proxy links for the running service

- Implementation approach:
  - `--init` handled at the very start of `main()` before any async context
  - Uses blocking operations throughout (file I/O, `std::process::Command` for systemctl)
  - IP detection for tg:// links performed via blocking HTTP request
  - Command exits after installation without entering normal proxy runtime

- New CLI parameters for installation:
  - `--port` - listening port (default: 443)
  - `--domain` - TLS domain (default: auto-detected)
  - `--secret` - custom secret (default: randomly generated)
  - `--user` - systemd service user (default: telemt)
  - `--config-dir` - configuration directory (default: /etc/telemt)

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-07 20:31:49 +03:00
Alexey
32f60f34db Fix Stats + UpstreamState + EMA Latency Tracking
- Per-DC latency tracking in UpstreamState (array of 5 EMA instances, one per DC):
  - Added `dc_latency: [LatencyEma; 5]` – per‑DC tracking instead of a single global EMA
  - `effective_latency(dc_idx)` – returns DC‑specific latency, falls back to average if unavailable
  - `select_upstream(dc_idx)` – now performs latency‑weighted selection: effective_weight = config_weight × (1000 / latency_ms)
    - Example: two upstreams with equal config weight but latencies of 50ms and 200ms → selection probabilities become 80% / 20%
  - `connect(target, dc_idx)` – extended signature, dc_idx used for upstream selection and per‑DC RTT recording
  - All ping/health‑check operations now record RTT into `dc_latency[dc_zero_index]`
  - `upstream_manager.connect(dc_addr)` changed to `upstream_manager.connect(dc_addr, Some(success.dc_idx))` – DC index now participates in upstream selection and per‑DC RTT logging
  - `client.rs` – passes dc_idx when connecting to Telegram

- Summary: Upstream selection now accounts for per‑DC latency using the formula weight × (1000/ms). With multiple upstreams (e.g., direct + socks5), traffic automatically flows to the faster route for each specific DC. With a single upstream, the data is used for monitoring without affecting routing.

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-07 20:24:12 +03:00
Alexey
158eae8d2a Antireplay Improvements + DC Ping
- Fix: LruCache::get type ambiguity in stats/mod.rs
  - Changed `self.cache.get(&key.into())` to `self.cache.get(key)` (key is already &[u8], resolved via Box<[u8]>: Borrow<[u8]>)
  - Changed `self.cache.peek(&key)` / `.pop(&key)` to `.peek(key.as_ref())` / `.pop(key.as_ref())` (explicit &[u8] instead of &Box<[u8]>)

- Startup DC ping with RTT display and improved health-check (all DCs, RTT tracking, EMA latency, 30s interval):
  - Implemented `LatencyEma` – exponential moving average (α=0.3) for RTT
  - `connect()` – measures RTT of each real connection and updates EMA
  - `ping_all_dcs()` – pings all 5 DCs via each upstream, returns `Vec<StartupPingResult>` with RTT or error
  - `run_health_checks(prefer_ipv6)` – accepts IPv6 preference parameter, rotates DC between cycles (DC1→DC2→...→DC5→DC1...), interval reduced to 30s from 60s, failed checks now mark upstream as unhealthy after 3 consecutive fails
  - `DcPingResult` / `StartupPingResult` – public structures for display
  - DC Ping at startup: calls `upstream_manager.ping_all_dcs()` before accept loop, outputs table via `println!` (always visible)
  - Health checks with `prefer_ipv6`: `run_health_checks(prefer_ipv6)` receives the parameter
  - Exported `StartupPingResult` and `DcPingResult`

- Summary: Startup DC ping with RTT, rotational health-check with EMA latency tracking, 30-second interval, correct unhealthy marking after 3 fails.

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-07 20:18:25 +03:00
Alexey
92cedabc81 Zeroize for key + log refactor + fix tests
- Fixed tests that failed to compile due to mismatched generic parameters of HandshakeResult:
  - Changed `HandshakeResult<i32>` to `HandshakeResult<i32, (), ()>`
  - Changed `HandshakeResult::BadClient` to `HandshakeResult::BadClient { reader: (), writer: () }`

- Added Zeroize for all structures holding key material:
  - AesCbc – key and IV are zeroized on drop
  - SecureRandomInner – PRNG output buffer is zeroized on drop; local key copy in constructor is zeroized immediately after being passed to the cipher
  - ObfuscationParams – all four key‑material fields are zeroized on drop
  - HandshakeSuccess – all four key‑material fields are zeroized on drop

- Added protocol‑requirement documentation for legacy hashes (CodeQL suppression) in hash.rs (MD5/SHA‑1)

- Added documentation for zeroize limitations of AesCtr (opaque cipher state) in aes.rs

- Implemented silent‑mode logging and refactored initialization:
  - Added LogLevel enum to config and CLI flags --silent / --log-level
  - Added parse_cli() to handle --silent, --log-level, --help
  - Restructured main.rs initialization order: CLI → config load → determine log level → init tracing
  - Errors before tracing initialization are printed via eprintln!
  - Proxy links (tg://) are printed via println! – always visible regardless of log level
  - Configuration summary and operational messages are logged via info! (suppressed in silent mode)
  - Connection processing errors are lowered to debug! (hidden in silent mode)
  - Warning about default tls_domain moved to main (after tracing init)

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-07 19:49:41 +03:00
Alexey
b9428d9780 Antireplay on sliding window + SecureRandom 2026-02-07 18:26:44 +03:00
Alexey
5876f0c4d5 Update rust.yml 2026-02-07 17:58:10 +03:00
70 changed files with 13228 additions and 1558 deletions

19
.github/codeql/codeql-config.yml vendored Normal file
View File

@@ -0,0 +1,19 @@
name: "Rust without tests"
disable-default-queries: false
queries:
- uses: security-extended
- uses: security-and-quality
- uses: ./.github/codeql/queries
query-filters:
- exclude:
id:
- rust/unwrap-on-option
- rust/unwrap-on-result
- rust/expect-used
analysis:
dataflow:
default-precision: high

View File

@@ -0,0 +1,20 @@
import rust
predicate isTestOnly(Item i) {
exists(ConditionalCompilation cc |
cc.getItem() = i and
cc.getCfg().toString() = "test"
)
}
predicate hasTestAttribute(Item i) {
exists(Attribute a |
a.getItem() = i and
a.getName() = "test"
)
}
predicate isProductionCode(Item i) {
not isTestOnly(i) and
not hasTestAttribute(i)
}

4
.github/codeql/queries/qlpack.yml vendored Normal file
View File

@@ -0,0 +1,4 @@
name: rust-production-only
version: 0.0.1
dependencies:
codeql/rust-all: "*"

45
.github/workflows/codeql.yml vendored Normal file
View File

@@ -0,0 +1,45 @@
name: "CodeQL Advanced"
on:
push:
branches: [ "*" ]
pull_request:
branches: [ "*" ]
schedule:
- cron: '0 0 * * 0'
jobs:
analyze:
name: Analyze (${{ matrix.language }})
runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }}
permissions:
security-events: write
packages: read
actions: read
contents: read
strategy:
fail-fast: false
matrix:
include:
- language: actions
build-mode: none
- language: rust
build-mode: none
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Initialize CodeQL
uses: github/codeql-action/init@v4
with:
languages: ${{ matrix.language }}
build-mode: ${{ matrix.build-mode }}
config-file: .github/codeql/codeql-config.yml
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v4
with:
category: "/language:${{ matrix.language }}"

98
.github/workflows/release.yml vendored Normal file
View File

@@ -0,0 +1,98 @@
name: Release
on:
push:
tags:
- '[0-9]+.[0-9]+.[0-9]+' # Matches tags like 3.0.0, 3.1.2, etc.
workflow_dispatch: # Manual trigger from GitHub Actions UI
permissions:
contents: read
env:
CARGO_TERM_COLOR: always
jobs:
build:
name: Build ${{ matrix.target }}
runs-on: ubuntu-latest
permissions:
contents: read
strategy:
fail-fast: false
matrix:
include:
- target: x86_64-unknown-linux-gnu
artifact_name: telemt
asset_name: telemt-x86_64-linux
- target: aarch64-unknown-linux-gnu
artifact_name: telemt
asset_name: telemt-aarch64-linux
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Install stable Rust toolchain
uses: dtolnay/rust-toolchain@888c2e1ea69ab0d4330cbf0af1ecc7b68f368cc1 # v1
with:
toolchain: stable
targets: ${{ matrix.target }}
- name: Install cross-compilation tools
run: |
sudo apt-get update
sudo apt-get install -y gcc-aarch64-linux-gnu
- name: Cache cargo registry & build artifacts
uses: actions/cache@d4323d4df104b026a6aa633fdb11d772146be0bf # v4.2.2
with:
path: |
~/.cargo/registry
~/.cargo/git
target
key: ${{ runner.os }}-${{ matrix.target }}-cargo-${{ hashFiles('**/Cargo.lock') }}
restore-keys: |
${{ runner.os }}-${{ matrix.target }}-cargo-
- name: Install cross
run: cargo install cross --git https://github.com/cross-rs/cross
- name: Build Release
run: cross build --release --target ${{ matrix.target }}
- name: Package binary
run: |
cd target/${{ matrix.target }}/release
tar -czvf ${{ matrix.asset_name }}.tar.gz ${{ matrix.artifact_name }}
sha256sum ${{ matrix.asset_name }}.tar.gz > ${{ matrix.asset_name }}.sha256
- name: Upload artifact
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0
with:
name: ${{ matrix.asset_name }}
path: |
target/${{ matrix.target }}/release/${{ matrix.asset_name }}.tar.gz
target/${{ matrix.target }}/release/${{ matrix.asset_name }}.sha256
release:
name: Create Release
needs: build
runs-on: ubuntu-latest
permissions:
contents: write
steps:
- name: Download all artifacts
uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
with:
path: artifacts
- name: Create Release
uses: softprops/action-gh-release@c95fe1489396fe360a41fb53f90de6ddce8c4c8a # v2.2.1
with:
files: artifacts/**/*
generate_release_notes: true
draft: false
prerelease: ${{ contains(github.ref, '-rc') || contains(github.ref, '-beta') || contains(github.ref, '-alpha') }}

View File

@@ -2,9 +2,9 @@ name: Rust
on:
push:
branches: [ main ]
branches: [ "*" ]
pull_request:
branches: [ main ]
branches: [ "*" ]
env:
CARGO_TERM_COLOR: always
@@ -14,6 +14,11 @@ jobs:
name: Build
runs-on: ubuntu-latest
permissions:
contents: read
actions: write
checks: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -37,5 +42,13 @@ jobs:
- name: Build Release
run: cargo build --release --verbose
- name: Run tests
run: cargo test --verbose
# clippy dont fail on warnings because of active development of telemt
# and many warnings
- name: Run clippy
run: cargo clippy -- --cap-lints warn
- name: Check for unused dependencies
run: cargo udeps || true

View File

@@ -0,0 +1,58 @@
# Architect Mode Rules for Telemt
## Architecture Overview
```mermaid
graph TB
subgraph Entry
Client[Clients] --> Listener[TCP/Unix Listener]
end
subgraph Proxy Layer
Listener --> ClientHandler[ClientHandler]
ClientHandler --> Handshake[Handshake Validator]
Handshake --> |Valid| Relay[Relay Layer]
Handshake --> |Invalid| Masking[Masking/TLS Fronting]
end
subgraph Transport
Relay --> MiddleProxy[Middle-End Proxy Pool]
Relay --> DirectRelay[Direct DC Relay]
MiddleProxy --> TelegramDC[Telegram DCs]
DirectRelay --> TelegramDC
end
```
## Module Dependencies
- [`src/main.rs`](src/main.rs) - Entry point, spawns all async tasks
- [`src/config/`](src/config/) - Configuration loading with auto-migration
- [`src/error.rs`](src/error.rs) - Error types, must be used by all modules
- [`src/crypto/`](src/crypto/) - AES, SHA, random number generation
- [`src/protocol/`](src/protocol/) - MTProto constants, frame encoding, obfuscation
- [`src/stream/`](src/stream/) - Stream wrappers, buffer pool, frame codecs
- [`src/proxy/`](src/proxy/) - Client handling, handshake, relay logic
- [`src/transport/`](src/transport/) - Upstream management, middle-proxy, SOCKS support
- [`src/stats/`](src/stats/) - Statistics and replay protection
- [`src/ip_tracker.rs`](src/ip_tracker.rs) - Per-user IP tracking
## Key Architectural Constraints
### Middle-End Proxy Mode
- Requires public IP on interface OR 1:1 NAT with STUN probing
- Uses separate `proxy-secret` from Telegram (NOT user secrets)
- Falls back to direct mode automatically on STUN mismatch
### TLS Fronting
- Invalid handshakes are transparently proxied to `mask_host`
- This is critical for DPI evasion - do not change this behavior
- `mask_unix_sock` and `mask_host` are mutually exclusive
### Stream Architecture
- Buffer pool is shared globally via Arc - prevents allocation storms
- Frame codecs implement tokio-util Encoder/Decoder traits
- State machine in [`src/stream/state.rs`](src/stream/state.rs) manages stream transitions
### Configuration Migration
- [`ProxyConfig::load()`](src/config/mod.rs:641) mutates config in-place
- New fields must have sensible defaults
- DC203 override is auto-injected for CDN/media support

View File

@@ -0,0 +1,23 @@
# Code Mode Rules for Telemt
## Error Handling
- Always use [`ProxyError`](src/error.rs:168) from [`src/error.rs`](src/error.rs) for proxy operations
- [`HandshakeResult<T,R,W>`](src/error.rs:292) returns streams on bad client - these MUST be returned for masking, never dropped
- Use [`Recoverable`](src/error.rs:110) trait to check if errors are retryable
## Configuration Changes
- [`ProxyConfig::load()`](src/config/mod.rs:641) auto-mutates config - new fields should have defaults
- DC203 override is auto-injected if missing - do not remove this behavior
- When adding config fields, add migration logic in [`ProxyConfig::load()`](src/config/mod.rs:641)
## Crypto Code
- [`SecureRandom`](src/crypto/random.rs) from [`src/crypto/random.rs`](src/crypto/random.rs) must be used for all crypto operations
- Never use `rand::thread_rng()` directly - use the shared `Arc<SecureRandom>`
## Stream Handling
- Buffer pool [`BufferPool`](src/stream/buffer_pool.rs) is shared via Arc - always use it instead of allocating
- Frame codecs in [`src/stream/frame_codec.rs`](src/stream/frame_codec.rs) implement tokio-util's Encoder/Decoder traits
## Testing
- Tests are inline in modules using `#[cfg(test)]`
- Use `cargo test --lib <module_name>` to run tests for specific modules

View File

@@ -0,0 +1,27 @@
# Debug Mode Rules for Telemt
## Logging
- `RUST_LOG` environment variable takes absolute priority over all config log levels
- Log levels: `trace`, `debug`, `info`, `warn`, `error`
- Use `RUST_LOG=debug cargo run` for detailed operational logs
- Use `RUST_LOG=trace cargo run` for full protocol-level debugging
## Middle-End Proxy Debugging
- Set `ME_DIAG=1` environment variable for high-precision cryptography diagnostics
- STUN probe results are logged at startup - check for mismatch between local and reflected IP
- If Middle-End fails, check `proxy_secret_path` points to valid file from https://core.telegram.org/getProxySecret
## Connection Issues
- DC connectivity is logged at startup with RTT measurements
- If DC ping fails, check `dc_overrides` for custom addresses
- Use `prefer_ipv6=false` in config if IPv6 is unreliable
## TLS Fronting Issues
- Invalid handshakes are proxied to `mask_host` - check this host is reachable
- `mask_unix_sock` and `mask_host` are mutually exclusive - only one can be set
- If `mask_unix_sock` is set, socket must exist before connections arrive
## Common Errors
- `ReplayAttack` - client replayed a handshake nonce, potential attack
- `TimeSkew` - client clock is off, can disable with `ignore_time_skew=true`
- `TgHandshakeTimeout` - upstream DC connection failed, check network

40
AGENTS.md Normal file
View File

@@ -0,0 +1,40 @@
# AGENTS.md
** Use general system promt from AGENTS_SYSTEM_PROMT.md **
** Additional techiques and architectury details are here **
This file provides guidance to agents when working with code in this repository.
## Build & Test Commands
```bash
cargo build --release # Production build
cargo test # Run all tests
cargo test --lib error # Run tests for specific module (error module)
cargo bench --bench crypto_bench # Run crypto benchmarks
cargo clippy -- -D warnings # Lint with clippy
```
## Project-Specific Conventions
### Rust Edition
- Uses **Rust edition 2024** (not 2021) - specified in Cargo.toml
### Error Handling Pattern
- Custom [`Recoverable`](src/error.rs:110) trait distinguishes recoverable vs fatal errors
- [`HandshakeResult<T,R,W>`](src/error.rs:292) returns streams on bad client for masking - do not drop them
- Always use [`ProxyError`](src/error.rs:168) from [`src/error.rs`](src/error.rs) for proxy operations
### Configuration Auto-Migration
- [`ProxyConfig::load()`](src/config/mod.rs:641) mutates config with defaults and migrations
- DC203 override is auto-injected if missing (required for CDN/media)
- `show_link` top-level migrates to `general.links.show`
### Middle-End Proxy Requirements
- Requires public IP on interface OR 1:1 NAT with STUN probing
- Falls back to direct mode on STUN/interface mismatch unless `stun_iface_mismatch_ignore=true`
- Proxy-secret from Telegram is separate from user secrets
### TLS Fronting Behavior
- Invalid handshakes are transparently proxied to `mask_host` for DPI evasion
- `fake_cert_len` is randomized at startup (1024-4096 bytes)
- `mask_unix_sock` and `mask_host` are mutually exclusive

207
AGENTS_SYSTEM_PROMT.md Normal file
View File

@@ -0,0 +1,207 @@
## System Prompt — Production Rust Codebase: Modification and Architecture Guidelines
You are a senior Rust systems engineer acting as a strict code reviewer and implementation partner. Your responses are precise, minimal, and architecturally sound. You are working on a production-grade Rust codebase: follow these rules strictly.
---
### 0. Priority Resolution — Scope Control
This section resolves conflicts between code quality enforcement and scope limitation.
When editing or extending existing code, you MUST audit the affected files and fix:
- Comment style violations (missing, non-English, decorative, trailing).
- Missing or incorrect documentation on public items.
- Comment placement issues (trailing comments → move above the code).
These are **coordinated changes** — they are always in scope.
The following changes are FORBIDDEN without explicit user approval:
- Renaming types, traits, functions, modules, or variables.
- Altering business logic, control flow, or data transformations.
- Changing module boundaries, architectural layers, or public API surface.
- Adding or removing functions, structs, enums, or trait implementations.
- Fixing compiler warnings or removing unused code.
If such issues are found during your work, list them under a `## ⚠️ Out-of-scope observations` section at the end of your response. Include file path, context, and a brief description. Do not apply these changes.
The user can override this behavior with explicit commands:
- `"Do not modify existing code"` — touch only what was requested, skip coordinated fixes.
- `"Make minimal changes"` — no coordinated fixes, narrowest possible diff.
- `"Fix everything"` — apply all coordinated fixes and out-of-scope observations.
---
### 1. Comments and Documentation
- All comments MUST be written in English.
- Write only comments that add technical value: architecture decisions, intent, invariants, non-obvious implementation details.
- Place all comments on separate lines above the relevant code.
- Use `///` doc-comments for public items. Use `//` for internal clarifications.
Correct example:
```rust
// Handles MTProto client authentication and establishes encrypted session state.
fn handle_authenticated_client(...) { ... }
```
Incorrect examples:
```rust
let x = 5; // set x to 5
```
```rust
// This function does stuff
fn do_stuff() { ... }
```
---
### 2. File Size and Module Structure
- Files MUST NOT exceed 350550 lines.
- If a file exceeds this limit, split it into submodules organized by responsibility (e.g., protocol, transport, state, handlers).
- Parent modules MUST declare and describe their submodules.
- Maintain clear architectural boundaries between modules.
Correct example:
```rust
// Client connection handling logic.
// Submodules:
// - handshake: MTProto handshake implementation
// - relay: traffic forwarding logic
// - state: client session state machine
pub mod handshake;
pub mod relay;
pub mod state;
```
Git discipline:
- Use local git for versioning and diffs.
- Write clear, descriptive commit messages in English that explain both *what* changed and *why*.
---
### 3. Formatting
- Preserve the existing formatting style of the project exactly as-is.
- Reformat code only when explicitly instructed to do so.
- Do not run `cargo fmt` unless explicitly instructed.
---
### 4. Change Safety and Validation
- If anything is unclear, STOP and ask specific, targeted questions before proceeding.
- List exactly what is ambiguous and offer possible interpretations for the user to choose from.
- Prefer clarification over assumptions. Do not guess intent, behavior, or missing requirements.
- Actively ask questions before making architectural or behavioral changes.
---
### 5. Warnings and Unused Code
- Leave all warnings, unused variables, functions, imports, and dead code untouched unless explicitly instructed to modify them.
- These may be intentional or part of work-in-progress code.
- `todo!()` and `unimplemented!()` are permitted and should not be removed or replaced unless explicitly instructed.
---
### 6. Architectural Integrity
- Preserve existing architecture unless explicitly instructed to refactor.
- Do not introduce hidden behavioral changes.
- Do not introduce implicit refactors.
- Keep changes minimal, isolated, and intentional.
---
### 7. When Modifying Code
You MUST:
- Maintain architectural consistency with the existing codebase.
- Document non-obvious logic with comments that describe *why*, not *what*.
- Limit changes strictly to the requested scope (plus coordinated fixes per Section 0).
- Keep all existing symbol names unless renaming is explicitly requested.
- Preserve global formatting as-is.
You MUST NOT:
- Use placeholders: no `// ... rest of code`, no `// implement here`, no `/* TODO */` stubs that replace existing working code. Write full, working implementation. If the implementation is unclear, ask first.
- Refactor code outside the requested scope.
- Make speculative improvements.
Note: `todo!()` and `unimplemented!()` are allowed as idiomatic Rust markers for genuinely unfinished code paths.
---
### 8. Decision Process for Complex Changes
When facing a non-trivial modification, follow this sequence:
1. **Clarify**: Restate the task in one sentence to confirm understanding.
2. **Assess impact**: Identify which modules, types, and invariants are affected.
3. **Propose**: Describe the intended change before implementing it.
4. **Implement**: Make the minimal, isolated change.
5. **Verify**: Explain why the change preserves existing behavior and architectural integrity.
---
### 9. Context Awareness
- When provided with partial code, assume the rest of the codebase exists and functions correctly unless stated otherwise.
- Reference existing types, functions, and module structures by their actual names as shown in the provided code.
- When the provided context is insufficient to make a safe change, request the missing context explicitly.
---
### 10. Response Format
#### Language Policy
- Code, comments, commit messages, documentation: **English**.
- Reasoning and explanations in response text: **Russian**.
#### Response Structure
Your response MUST consist of two sections:
**Section 1: `## Reasoning` (in Russian)**
- What needs to be done and why.
- Which files and modules are affected.
- Architectural decisions and their rationale.
- Potential risks or side effects.
**Section 2: `## Changes`**
- For each modified or created file: the filename on a separate line in backticks, followed by the code block.
- For files **under 200 lines**: return the full file with all changes applied.
- For files **over 200 lines**: return only the changed functions/blocks with at least 3 lines of surrounding context above and below. If the user requests the full file, provide it.
- New files: full file content.
- End with a suggested git commit message in English.
#### Reporting Out-of-Scope Issues
If during modification you discover issues outside the requested scope (potential bugs, unsafe code, architectural concerns, missing error handling, unused imports, dead code):
- Do not fix them silently.
- List them under `## ⚠️ Out-of-scope observations` at the end of your response.
- Include: file path, line/function context, brief description of the issue, and severity estimate.
#### Splitting Protocol
If the response exceeds the output limit:
1. End the current part with: **SPLIT: PART N — CONTINUE? (remaining: file_list)**
2. List the files that will be provided in subsequent parts.
3. Wait for user confirmation before continuing.
4. No single file may be split across parts.

2807
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1,16 +1,15 @@
[package]
name = "telemt"
version = "1.0.0"
edition = "2021"
rust-version = "1.75"
version = "3.0.4"
edition = "2024"
[dependencies]
# C
libc = "0.2"
# Async runtime
tokio = { version = "1.35", features = ["full", "tracing"] }
tokio-util = { version = "0.7", features = ["codec"] }
tokio = { version = "1.42", features = ["full", "tracing"] }
tokio-util = { version = "0.7", features = ["full"] }
# Crypto
aes = "0.8"
@@ -20,42 +19,48 @@ sha2 = "0.10"
sha1 = "0.10"
md-5 = "0.10"
hmac = "0.12"
crc32fast = "1.3"
crc32fast = "1.4"
zeroize = { version = "1.8", features = ["derive"] }
# Network
socket2 = { version = "0.5", features = ["all"] }
rustls = "0.22"
# Serial
# Serialization
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
toml = "0.8"
# Utils
bytes = "1.5"
thiserror = "1.0"
bytes = "1.9"
thiserror = "2.0"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
parking_lot = "0.12"
dashmap = "5.5"
lru = "0.12"
rand = "0.8"
lru = "0.16"
rand = "0.9"
chrono = { version = "0.4", features = ["serde"] }
hex = "0.4"
base64 = "0.21"
base64 = "0.22"
url = "2.5"
regex = "1.10"
once_cell = "1.19"
regex = "1.11"
crossbeam-queue = "0.3"
num-bigint = "0.4"
num-traits = "0.2"
# HTTP
reqwest = { version = "0.11", features = ["rustls-tls"], default-features = false }
reqwest = { version = "0.12", features = ["rustls-tls"], default-features = false }
hyper = { version = "1", features = ["server", "http1"] }
hyper-util = { version = "0.1", features = ["tokio", "server-auto"] }
http-body-util = "0.1"
httpdate = "1.0"
[dev-dependencies]
tokio-test = "0.4"
criterion = "0.5"
proptest = "1.4"
futures = "0.3"
[[bench]]
name = "crypto_bench"
harness = false
harness = false

43
Dockerfile Normal file
View File

@@ -0,0 +1,43 @@
# ==========================
# Stage 1: Build
# ==========================
FROM rust:1.85-slim-bookworm AS builder
RUN apt-get update && apt-get install -y --no-install-recommends \
pkg-config \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /build
COPY Cargo.toml Cargo.lock* ./
RUN mkdir src && echo 'fn main() {}' > src/main.rs && \
cargo build --release 2>/dev/null || true && \
rm -rf src
COPY . .
RUN cargo build --release && strip target/release/telemt
# ==========================
# Stage 2: Runtime
# ==========================
FROM debian:bookworm-slim
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
&& rm -rf /var/lib/apt/lists/*
RUN useradd -r -s /usr/sbin/nologin telemt
WORKDIR /app
COPY --from=builder /build/target/release/telemt /app/telemt
COPY config.toml /app/config.toml
RUN chown -R telemt:telemt /app
USER telemt
EXPOSE 443
EXPOSE 9090
ENTRYPOINT ["/app/telemt"]
CMD ["config.toml"]

137
README.md
View File

@@ -2,9 +2,59 @@
**Telemt** is a fast, secure, and feature-rich server written in Rust: it fully implements the official Telegram proxy algo and adds many production-ready improvements such as connection pooling, replay protection, detailed statistics, masking from "prying" eyes
💥 The configuration structure has changed since version 1.1.0.0, change it in your environment!
## NEWS and EMERGENCY
### ✈️ Telemt 3 is released!
<table>
<tr>
<td width="50%" valign="top">
⚓ Our implementation of **TLS-fronting** is one of the most deeply debugged, focused, advanced and *almost* **"behaviorally consistent to real"**: we are confident we have it right - [see evidence on our validation and traces](#recognizability-for-dpi-and-crawler)
### 🇷🇺 RU
18 февраля мы опубликовали `telemt 3.0.3`, он имеет:
- улучшенный механизм Middle-End Health Check
- высокоскоростное восстановление инициализации Middle-End
- меньше задержек на hot-path
- более корректную работу в Dualstack, а именно - IPv6 Middle-End
- аккуратное переподключение клиента без дрифта сессий между Middle-End
- автоматическая деградация на Direct-DC при массовой (>2 ME-DC-групп) недоступности Middle-End
- автодетект IP за NAT, при возможности - будет выполнен хендшейк с ME, при неудаче - автодеградация
- единственный известный специальный DC=203 уже добавлен в код: медиа загружаются с CDN в Direct-DC режиме
[Здесь вы можете найти релиз](https://github.com/telemt/telemt/releases/tag/3.0.3)
Если у вас есть компетенции в асинхронных сетевых приложениях, анализе трафика, реверс-инжиниринге или сетевых расследованиях - мы открыты к идеям и pull requests!
</td>
<td width="50%" valign="top">
### 🇬🇧 EN
On February 18, we released `telemt 3.0.3`. This version introduces:
- improved Middle-End Health Check method
- high-speed recovery of Middle-End init
- reduced latency on the hot path
- correct Dualstack support: proper handling of IPv6 Middle-End
- *clean* client reconnection without session "drift" between Middle-End
- automatic degradation to Direct-DC mode in case of large-scale (>2 ME-DC groups) Middle-End unavailability
- automatic public IP detection behind NAT; first - Middle-End handshake is performed, otherwise automatic degradation is applied
- known special DC=203 is now handled natively: media is delivered from the CDN via Direct-DC mode
[Release is available here](https://github.com/telemt/telemt/releases/tag/3.0.3)
If you have expertise in asynchronous network applications, traffic analysis, reverse engineering, or network forensics - we welcome ideas and pull requests!
</td>
</tr>
</table>
# Features
💥 The configuration structure has changed since version 1.1.0.0. change it in your environment!
⚓ Our implementation of **TLS-fronting** is one of the most deeply debugged, focused, advanced and *almost* **"behaviorally consistent to real"**: we are confident we have it right - [see evidence on our validation and traces](#recognizability-for-dpi-and-crawler)
⚓ Our ***Middle-End Pool*** is fastest by design in standard scenarios, compared to other implementations of connecting to the Middle-End Proxy: non dramatically, but usual
# GOTO
- [Features](#features)
@@ -24,7 +74,9 @@
- [Telegram Calls](#telegram-calls-via-mtproxy)
- [DPI](#how-does-dpi-see-mtproxy-tls)
- [Whitelist on Network Level](#whitelist-on-ip)
- [Too many open files](#too-many-open-files)
- [Build](#build)
- [Docker](#docker)
- [Why Rust?](#why-rust)
## Features
@@ -108,6 +160,7 @@ Type=simple
WorkingDirectory=/bin
ExecStart=/bin/telemt /etc/telemt.toml
Restart=on-failure
LimitNOFILE=65536
[Install]
WantedBy=multi-user.target
@@ -123,17 +176,20 @@ then Ctrl+X -> Y -> Enter to save
## Configuration
### Minimal Configuration for First Start
```toml
# === UI ===
# Users to show in the startup log (tg:// links)
show_link = ["hello"]
# === General Settings ===
[general]
# prefer_ipv6 is deprecated; use [network].prefer
prefer_ipv6 = false
fast_mode = true
use_middle_proxy = false
# ad_tag = "..."
[network]
ipv4 = true
ipv6 = true # set false to disable, omit for auto
prefer = 4 # 4 or 6
multipath = false
[general.modes]
classic = false
secure = false
@@ -150,11 +206,19 @@ listen_addr_ipv6 = "::"
# Listen on multiple interfaces/IPs (overrides listen_addr_*)
[[server.listeners]]
ip = "0.0.0.0"
# announce_ip = "1.2.3.4" # Optional: Public IP for tg:// links
# announce = "my.hostname.tld" # Optional: hostname for tg:// links
# OR
# announce = "1.2.3.4" # Optional: Public IP for tg:// links
[[server.listeners]]
ip = "::"
# Users to show in the startup log (tg:// links)
[general.links]
show = ["hello"] # Users to show in the startup log (tg:// links)
# public_host = "proxy.example.com" # Host (IP or domain) for tg:// links
# public_port = 443 # Port for tg:// links (default: server.port)
# === Timeouts (in seconds) ===
[timeouts]
client_handshake = 15
@@ -168,6 +232,7 @@ tls_domain = "petrovich.ru"
mask = true
mask_port = 443
# mask_host = "petrovich.ru" # Defaults to tls_domain if not set
# mask_unix_sock = "/var/run/nginx.sock" # Unix socket (mutually exclusive with mask_host)
fake_cert_len = 2048
# === Access Control & Users ===
@@ -201,6 +266,10 @@ weight = 10
# address = "127.0.0.1:9050"
# enabled = false
# weight = 1
# === DC Address Overrides ===
# [dc_overrides]
# "203" = "91.105.192.100:443"
```
### Advanced
#### Adtag
@@ -356,6 +425,23 @@ Keep-Alive: timeout=60
- in China behind the Great Firewall
- in Russia on mobile networks, less in wired networks
- in Iran during "activity"
### Too many open files
- On a fresh Linux install the default open file limit is low; under load `telemt` may fail with `Accept error: Too many open files`
- **Systemd**: add `LimitNOFILE=65536` to the `[Service]` section (already included in the example above)
- **Docker**: add `--ulimit nofile=65536:65536` to your `docker run` command, or in `docker-compose.yml`:
```yaml
ulimits:
nofile:
soft: 65536
hard: 65536
```
- **System-wide** (optional): add to `/etc/security/limits.conf`:
```
* soft nofile 1048576
* hard nofile 1048576
root soft nofile 1048576
root hard nofile 1048576
```
## Build
@@ -374,9 +460,44 @@ chmod +x /bin/telemt
telemt config.toml
```
## Docker
**Quick start (Docker Compose)**
1. Edit `config.toml` in repo root (at least: port, users secrets, tls_domain)
2. Start container:
```bash
docker compose up -d --build
```
3. Check logs:
```bash
docker compose logs -f telemt
```
4. Stop:
```bash
docker compose down
```
**Notes**
- `docker-compose.yml` maps `./config.toml` to `/app/config.toml` (read-only)
- By default it publishes `443:443` and runs with dropped capabilities (only `NET_BIND_SERVICE` is added)
- If you really need host networking (usually only for some IPv6 setups) uncomment `network_mode: host`
**Run without Compose**
```bash
docker build -t telemt:local .
docker run --name telemt --restart unless-stopped \
-p 443:443 \
-e RUST_LOG=info \
-v "$PWD/config.toml:/app/config.toml:ro" \
--read-only \
--cap-drop ALL --cap-add NET_BIND_SERVICE \
--ulimit nofile=65536:65536 \
telemt:local
```
## Why Rust?
- Long-running reliability and idempotent behavior
- Rusts deterministic resource management - RAII
- Rust's deterministic resource management - RAII
- No garbage collector
- Memory safety and reduced attack surface
- Tokio's asynchronous architecture

34
ROADMAP.md Normal file
View File

@@ -0,0 +1,34 @@
### 3.0.0 Anschluss
- **Middle Proxy now is stable**, confirmed on canary-deploy over ~20 users
- Ad-tag now is working
- DC=203/CDN now is working over ME
- `getProxyConfig` and `ProxySecret` are automated
- Version order is now in format `3.0.0` - without Windows-style "microfixes"
### 3.0.1 Kabelsammler
- Handshake timeouts fixed
- Connectivity logging refactored
- Docker: tmpfs for ProxyConfig and ProxySecret
- Public Host and Port in config
- ME Relays Head-of-Line Blocking fixed
- ME Ping
### 3.0.2 Microtrencher
- New [network] section
- ME Fixes
- Small bugs coverage
### 3.0.3 Ausrutscher
- ME as stateful, no conn-id migration
- No `flush()` on datapath after RpcWriter
- Hightech parser for IPv6 without regexp
- `nat_probe = true` by default
- Timeout for `recv()` in STUN-client
- ConnRegistry review
- Dualstack emergency reconnect
### 3.0.4 Schneeflecken
- Only WARN and Links in Normal log
- Consistent IP-family detection
- Includes for config
- `nonce_frame_hex` in log only with `DEBUG`

View File

@@ -1,13 +1,23 @@
# === UI ===
# Users to show in the startup log (tg:// links)
show_link = ["hello"]
# === General Settings ===
[general]
# prefer_ipv6 is deprecated; use [network].prefer instead
prefer_ipv6 = false
fast_mode = true
use_middle_proxy = false
# ad_tag = "..."
use_middle_proxy = true
#ad_tag = "00000000000000000000000000000000"
[network]
# Enable/disable families; ipv6 = true/false/auto(None)
ipv4 = true
ipv6 = true
# prefer = 4 or 6
prefer = 4
multipath = false
# Log level: debug | verbose | normal | silent
# Can be overridden with --silent or --log-level CLI flags
# RUST_LOG env var takes absolute priority over all of these
log_level = "normal"
[general.modes]
classic = false
@@ -19,6 +29,8 @@ tls = true
port = 443
listen_addr_ipv4 = "0.0.0.0"
listen_addr_ipv6 = "::"
# listen_unix_sock = "/var/run/telemt.sock" # Unix socket
# listen_unix_sock_perm = "0666" # Socket file permissions
# metrics_port = 9090
# metrics_whitelist = ["127.0.0.1", "::1"]
@@ -30,6 +42,12 @@ ip = "0.0.0.0"
[[server.listeners]]
ip = "::"
# Users to show in the startup log (tg:// links)
[general.links]
show = ["hello"] # Users to show in the startup log (tg:// links)
# public_host = "proxy.example.com" # Host (IP or domain) for tg:// links
# public_port = 443 # Port for tg:// links (default: server.port)
# === Timeouts (in seconds) ===
[timeouts]
client_handshake = 15
@@ -43,12 +61,13 @@ tls_domain = "petrovich.ru"
mask = true
mask_port = 443
# mask_host = "petrovich.ru" # Defaults to tls_domain if not set
# mask_unix_sock = "/var/run/nginx.sock" # Unix socket (mutually exclusive with mask_host)
fake_cert_len = 2048
# === Access Control & Users ===
# username "hello" is used for example
[access]
replay_check_len = 65536
replay_window_secs = 1800
ignore_time_skew = false
[access.users]
@@ -58,21 +77,24 @@ hello = "00000000000000000000000000000000"
# [access.user_max_tcp_conns]
# hello = 50
# [access.user_max_unique_ips]
# hello = 5
# [access.user_data_quota]
# hello = 1073741824 # 1 GB
# === Upstreams & Routing ===
# By default, direct connection is used, but you can add SOCKS proxy
# Direct - Default
[[upstreams]]
type = "direct"
enabled = true
weight = 10
# SOCKS5
# [[upstreams]]
# type = "socks5"
# address = "127.0.0.1:9050"
# address = "127.0.0.1:1080"
# enabled = false
# weight = 1
# === DC Address Overrides ===
# [dc_overrides]
# "203" = "91.105.192.100:443"

29
docker-compose.yml Normal file
View File

@@ -0,0 +1,29 @@
services:
telemt:
build: .
container_name: telemt
restart: unless-stopped
ports:
- "443:443"
- "9090:9090"
# Allow caching 'proxy-secret' in read-only container
working_dir: /run/telemt
volumes:
- ./config.toml:/run/telemt/config.toml:ro
tmpfs:
- /run/telemt:rw,mode=1777,size=1m
environment:
- RUST_LOG=info
# Uncomment this line if you want to use host network for IPv6, but bridge is default and usually better
# network_mode: host
cap_drop:
- ALL
cap_add:
- NET_BIND_SERVICE # allow binding to port 443
read_only: true
security_opt:
- no-new-privileges:true
ulimits:
nofile:
soft: 65536
hard: 65536

307
src/cli.rs Normal file
View File

@@ -0,0 +1,307 @@
//! CLI commands: --init (fire-and-forget setup)
use std::fs;
use std::path::{Path, PathBuf};
use std::process::Command;
use rand::Rng;
/// Options for the init command
pub struct InitOptions {
pub port: u16,
pub domain: String,
pub secret: Option<String>,
pub username: String,
pub config_dir: PathBuf,
pub no_start: bool,
}
impl Default for InitOptions {
fn default() -> Self {
Self {
port: 443,
domain: "www.google.com".to_string(),
secret: None,
username: "user".to_string(),
config_dir: PathBuf::from("/etc/telemt"),
no_start: false,
}
}
}
/// Parse --init subcommand options from CLI args.
///
/// Returns `Some(InitOptions)` if `--init` was found, `None` otherwise.
pub fn parse_init_args(args: &[String]) -> Option<InitOptions> {
if !args.iter().any(|a| a == "--init") {
return None;
}
let mut opts = InitOptions::default();
let mut i = 0;
while i < args.len() {
match args[i].as_str() {
"--port" => {
i += 1;
if i < args.len() {
opts.port = args[i].parse().unwrap_or(443);
}
}
"--domain" => {
i += 1;
if i < args.len() {
opts.domain = args[i].clone();
}
}
"--secret" => {
i += 1;
if i < args.len() {
opts.secret = Some(args[i].clone());
}
}
"--user" => {
i += 1;
if i < args.len() {
opts.username = args[i].clone();
}
}
"--config-dir" => {
i += 1;
if i < args.len() {
opts.config_dir = PathBuf::from(&args[i]);
}
}
"--no-start" => {
opts.no_start = true;
}
_ => {}
}
i += 1;
}
Some(opts)
}
/// Run the fire-and-forget setup.
pub fn run_init(opts: InitOptions) -> Result<(), Box<dyn std::error::Error>> {
eprintln!("[telemt] Fire-and-forget setup");
eprintln!();
// 1. Generate or validate secret
let secret = match opts.secret {
Some(s) => {
if s.len() != 32 || !s.chars().all(|c| c.is_ascii_hexdigit()) {
eprintln!("[error] Secret must be exactly 32 hex characters");
std::process::exit(1);
}
s
}
None => generate_secret(),
};
eprintln!("[+] Secret: {}", secret);
eprintln!("[+] User: {}", opts.username);
eprintln!("[+] Port: {}", opts.port);
eprintln!("[+] Domain: {}", opts.domain);
// 2. Create config directory
fs::create_dir_all(&opts.config_dir)?;
let config_path = opts.config_dir.join("config.toml");
// 3. Write config
let config_content = generate_config(&opts.username, &secret, opts.port, &opts.domain);
fs::write(&config_path, &config_content)?;
eprintln!("[+] Config written to {}", config_path.display());
// 4. Write systemd unit
let exe_path = std::env::current_exe()
.unwrap_or_else(|_| PathBuf::from("/usr/local/bin/telemt"));
let unit_path = Path::new("/etc/systemd/system/telemt.service");
let unit_content = generate_systemd_unit(&exe_path, &config_path);
match fs::write(unit_path, &unit_content) {
Ok(()) => {
eprintln!("[+] Systemd unit written to {}", unit_path.display());
}
Err(e) => {
eprintln!("[!] Cannot write systemd unit (run as root?): {}", e);
eprintln!("[!] Manual unit file content:");
eprintln!("{}", unit_content);
// Still print links and config
print_links(&opts.username, &secret, opts.port, &opts.domain);
return Ok(());
}
}
// 5. Reload systemd
run_cmd("systemctl", &["daemon-reload"]);
// 6. Enable service
run_cmd("systemctl", &["enable", "telemt.service"]);
eprintln!("[+] Service enabled");
// 7. Start service (unless --no-start)
if !opts.no_start {
run_cmd("systemctl", &["start", "telemt.service"]);
eprintln!("[+] Service started");
// Brief delay then check status
std::thread::sleep(std::time::Duration::from_secs(1));
let status = Command::new("systemctl")
.args(["is-active", "telemt.service"])
.output();
match status {
Ok(out) if out.status.success() => {
eprintln!("[+] Service is running");
}
_ => {
eprintln!("[!] Service may not have started correctly");
eprintln!("[!] Check: journalctl -u telemt.service -n 20");
}
}
} else {
eprintln!("[+] Service not started (--no-start)");
eprintln!("[+] Start manually: systemctl start telemt.service");
}
eprintln!();
// 8. Print links
print_links(&opts.username, &secret, opts.port, &opts.domain);
Ok(())
}
fn generate_secret() -> String {
let mut rng = rand::rng();
let bytes: Vec<u8> = (0..16).map(|_| rng.random::<u8>()).collect();
hex::encode(bytes)
}
fn generate_config(username: &str, secret: &str, port: u16, domain: &str) -> String {
format!(
r#"# Telemt MTProxy — auto-generated config
# Re-run `telemt --init` to regenerate
show_link = ["{username}"]
[general]
# prefer_ipv6 is deprecated; use [network].prefer
prefer_ipv6 = false
fast_mode = true
use_middle_proxy = false
log_level = "normal"
[network]
ipv4 = true
ipv6 = true
prefer = 4
multipath = false
[general.modes]
classic = false
secure = false
tls = true
[server]
port = {port}
listen_addr_ipv4 = "0.0.0.0"
listen_addr_ipv6 = "::"
[[server.listeners]]
ip = "0.0.0.0"
[[server.listeners]]
ip = "::"
[timeouts]
client_handshake = 15
tg_connect = 10
client_keepalive = 60
client_ack = 300
[censorship]
tls_domain = "{domain}"
mask = true
mask_port = 443
fake_cert_len = 2048
[access]
replay_check_len = 65536
replay_window_secs = 1800
ignore_time_skew = false
[access.users]
{username} = "{secret}"
[[upstreams]]
type = "direct"
enabled = true
weight = 10
"#,
username = username,
secret = secret,
port = port,
domain = domain,
)
}
fn generate_systemd_unit(exe_path: &Path, config_path: &Path) -> String {
format!(
r#"[Unit]
Description=Telemt MTProxy
Documentation=https://github.com/nicepkg/telemt
After=network-online.target
Wants=network-online.target
[Service]
Type=simple
ExecStart={exe} {config}
Restart=always
RestartSec=5
LimitNOFILE=65535
# Security hardening
NoNewPrivileges=true
ProtectSystem=strict
ProtectHome=true
ReadWritePaths=/etc/telemt
PrivateTmp=true
[Install]
WantedBy=multi-user.target
"#,
exe = exe_path.display(),
config = config_path.display(),
)
}
fn run_cmd(cmd: &str, args: &[&str]) {
match Command::new(cmd).args(args).output() {
Ok(output) => {
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
eprintln!("[!] {} {} failed: {}", cmd, args.join(" "), stderr.trim());
}
}
Err(e) => {
eprintln!("[!] Failed to run {} {}: {}", cmd, args.join(" "), e);
}
}
}
fn print_links(username: &str, secret: &str, port: u16, domain: &str) {
let domain_hex = hex::encode(domain);
println!("=== Proxy Links ===");
println!("[{}]", username);
println!(" EE-TLS: tg://proxy?server=YOUR_SERVER_IP&port={}&secret=ee{}{}",
port, secret, domain_hex);
println!();
println!("Replace YOUR_SERVER_IP with your server's public IP.");
println!("The proxy will auto-detect and display the correct link on startup.");
println!("Check: journalctl -u telemt.service | head -30");
println!("===================");
}

105
src/config/defaults.rs Normal file
View File

@@ -0,0 +1,105 @@
use std::net::IpAddr;
use std::collections::HashMap;
use serde::Deserialize;
// Helper defaults kept private to the config module.
pub(crate) fn default_true() -> bool {
true
}
pub(crate) fn default_port() -> u16 {
443
}
pub(crate) fn default_tls_domain() -> String {
"www.google.com".to_string()
}
pub(crate) fn default_mask_port() -> u16 {
443
}
pub(crate) fn default_fake_cert_len() -> usize {
2048
}
pub(crate) fn default_replay_check_len() -> usize {
65_536
}
pub(crate) fn default_replay_window_secs() -> u64 {
1800
}
pub(crate) fn default_handshake_timeout() -> u64 {
15
}
pub(crate) fn default_connect_timeout() -> u64 {
10
}
pub(crate) fn default_keepalive() -> u64 {
60
}
pub(crate) fn default_ack_timeout() -> u64 {
300
}
pub(crate) fn default_me_one_retry() -> u8 {
3
}
pub(crate) fn default_me_one_timeout() -> u64 {
1500
}
pub(crate) fn default_listen_addr() -> String {
"0.0.0.0".to_string()
}
pub(crate) fn default_weight() -> u16 {
1
}
pub(crate) fn default_metrics_whitelist() -> Vec<IpAddr> {
vec!["127.0.0.1".parse().unwrap(), "::1".parse().unwrap()]
}
pub(crate) fn default_prefer_4() -> u8 {
4
}
pub(crate) fn default_unknown_dc_log_path() -> Option<String> {
Some("unknown-dc.txt".to_string())
}
// Custom deserializer helpers
#[derive(Deserialize)]
#[serde(untagged)]
pub(crate) enum OneOrMany {
One(String),
Many(Vec<String>),
}
pub(crate) fn deserialize_dc_overrides<'de, D>(
deserializer: D,
) -> std::result::Result<HashMap<String, Vec<String>>, D::Error>
where
D: serde::de::Deserializer<'de>,
{
let raw: HashMap<String, OneOrMany> = HashMap::deserialize(deserializer)?;
let mut out = HashMap::new();
for (dc, val) in raw {
let mut addrs = match val {
OneOrMany::One(s) => vec![s],
OneOrMany::Many(v) => v,
};
addrs.retain(|s| !s.trim().is_empty());
if !addrs.is_empty() {
out.insert(dc, addrs);
}
}
Ok(out)
}

295
src/config/load.rs Normal file
View File

@@ -0,0 +1,295 @@
use std::collections::HashMap;
use std::net::IpAddr;
use std::path::Path;
use rand::Rng;
use tracing::warn;
use serde::{Serialize, Deserialize};
use crate::error::{ProxyError, Result};
use super::defaults::*;
use super::types::*;
fn validate_network_cfg(net: &mut NetworkConfig) -> Result<()> {
if !net.ipv4 && matches!(net.ipv6, Some(false)) {
return Err(ProxyError::Config(
"Both ipv4 and ipv6 are disabled in [network]".to_string(),
));
}
if net.prefer != 4 && net.prefer != 6 {
return Err(ProxyError::Config(
"network.prefer must be 4 or 6".to_string(),
));
}
if !net.ipv4 && net.prefer == 4 {
warn!("prefer=4 but ipv4=false; forcing prefer=6");
net.prefer = 6;
}
if matches!(net.ipv6, Some(false)) && net.prefer == 6 {
warn!("prefer=6 but ipv6=false; forcing prefer=4");
net.prefer = 4;
}
Ok(())
}
// ============= Main Config =============
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ProxyConfig {
#[serde(default)]
pub general: GeneralConfig,
#[serde(default)]
pub network: NetworkConfig,
#[serde(default)]
pub server: ServerConfig,
#[serde(default)]
pub timeouts: TimeoutsConfig,
#[serde(default)]
pub censorship: AntiCensorshipConfig,
#[serde(default)]
pub access: AccessConfig,
#[serde(default)]
pub upstreams: Vec<UpstreamConfig>,
#[serde(default)]
pub show_link: ShowLink,
/// DC address overrides for non-standard DCs (CDN, media, test, etc.)
/// Keys are DC indices as strings, values are one or more "ip:port" addresses.
/// Matches the C implementation's `proxy_for <dc_id> <ip>:<port>` config directive.
/// Example in config.toml:
/// [dc_overrides]
/// "203" = ["149.154.175.100:443", "91.105.192.100:443"]
#[serde(default, deserialize_with = "deserialize_dc_overrides")]
pub dc_overrides: HashMap<String, Vec<String>>,
/// Default DC index (1-5) for unmapped non-standard DCs.
/// Matches the C implementation's `default <dc_id>` config directive.
/// If not set, defaults to 2 (matching Telegram's official `default 2;` in proxy-multi.conf).
#[serde(default)]
pub default_dc: Option<u8>,
}
impl ProxyConfig {
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
let content =
std::fs::read_to_string(path).map_err(|e| ProxyError::Config(e.to_string()))?;
let mut config: ProxyConfig =
toml::from_str(&content).map_err(|e| ProxyError::Config(e.to_string()))?;
// Validate secrets.
for (user, secret) in &config.access.users {
if !secret.chars().all(|c| c.is_ascii_hexdigit()) || secret.len() != 32 {
return Err(ProxyError::InvalidSecret {
user: user.clone(),
reason: "Must be 32 hex characters".to_string(),
});
}
}
// Validate tls_domain.
if config.censorship.tls_domain.is_empty() {
return Err(ProxyError::Config("tls_domain cannot be empty".to_string()));
}
// Validate mask_unix_sock.
if let Some(ref sock_path) = config.censorship.mask_unix_sock {
if sock_path.is_empty() {
return Err(ProxyError::Config(
"mask_unix_sock cannot be empty".to_string(),
));
}
#[cfg(unix)]
if sock_path.len() > 107 {
return Err(ProxyError::Config(format!(
"mask_unix_sock path too long: {} bytes (max 107)",
sock_path.len()
)));
}
#[cfg(not(unix))]
return Err(ProxyError::Config(
"mask_unix_sock is only supported on Unix platforms".to_string(),
));
if config.censorship.mask_host.is_some() {
return Err(ProxyError::Config(
"mask_unix_sock and mask_host are mutually exclusive".to_string(),
));
}
}
// Default mask_host to tls_domain if not set and no unix socket configured.
if config.censorship.mask_host.is_none() && config.censorship.mask_unix_sock.is_none() {
config.censorship.mask_host = Some(config.censorship.tls_domain.clone());
}
// Migration: prefer_ipv6 -> network.prefer.
if config.general.prefer_ipv6 {
if config.network.prefer == 4 {
config.network.prefer = 6;
}
warn!("prefer_ipv6 is deprecated, use [network].prefer = 6");
}
// Auto-enable NAT probe when Middle Proxy is requested.
if config.general.use_middle_proxy && !config.general.middle_proxy_nat_probe {
config.general.middle_proxy_nat_probe = true;
warn!("Auto-enabled middle_proxy_nat_probe for middle proxy mode");
}
validate_network_cfg(&mut config.network)?;
// Random fake_cert_len.
config.censorship.fake_cert_len = rand::rng().gen_range(1024..4096);
// Resolve listen_tcp: explicit value wins, otherwise auto-detect.
// If unix socket is set → TCP only when listen_addr_ipv4 or listeners are explicitly provided.
// If no unix socket → TCP always (backward compat).
let listen_tcp = config.server.listen_tcp.unwrap_or_else(|| {
if config.server.listen_unix_sock.is_some() {
// Unix socket present: TCP only if user explicitly set addresses or listeners.
config.server.listen_addr_ipv4.is_some()
|| !config.server.listeners.is_empty()
} else {
true
}
});
// Migration: Populate listeners if empty (skip when listen_tcp = false).
if config.server.listeners.is_empty() && listen_tcp {
let ipv4_str = config.server.listen_addr_ipv4
.as_deref()
.unwrap_or("0.0.0.0");
if let Ok(ipv4) = ipv4_str.parse::<IpAddr>() {
config.server.listeners.push(ListenerConfig {
ip: ipv4,
announce: None,
announce_ip: None,
});
}
if let Some(ipv6_str) = &config.server.listen_addr_ipv6 {
if let Ok(ipv6) = ipv6_str.parse::<IpAddr>() {
config.server.listeners.push(ListenerConfig {
ip: ipv6,
announce: None,
announce_ip: None,
});
}
}
}
// Migration: announce_ip → announce for each listener.
for listener in &mut config.server.listeners {
if listener.announce.is_none() && listener.announce_ip.is_some() {
listener.announce = Some(listener.announce_ip.unwrap().to_string());
}
}
// Migration: show_link (top-level) → general.links.show.
if !config.show_link.is_empty() && config.general.links.show.is_empty() {
config.general.links.show = config.show_link.clone();
}
// Migration: Populate upstreams if empty (Default Direct).
if config.upstreams.is_empty() {
config.upstreams.push(UpstreamConfig {
upstream_type: UpstreamType::Direct { interface: None },
weight: 1,
enabled: true,
});
}
// Ensure default DC203 override is present.
config
.dc_overrides
.entry("203".to_string())
.or_insert_with(|| vec!["91.105.192.100:443".to_string()]);
Ok(config)
}
pub fn validate(&self) -> Result<()> {
if self.access.users.is_empty() {
return Err(ProxyError::Config("No users configured".to_string()));
}
if !self.general.modes.classic && !self.general.modes.secure && !self.general.modes.tls {
return Err(ProxyError::Config("No modes enabled".to_string()));
}
if self.censorship.tls_domain.contains(' ') || self.censorship.tls_domain.contains('/') {
return Err(ProxyError::Config(format!(
"Invalid tls_domain: '{}'. Must be a valid domain name",
self.censorship.tls_domain
)));
}
if let Some(tag) = &self.general.ad_tag {
let zeros = "00000000000000000000000000000000";
if tag == zeros {
warn!("ad_tag is all zeros; register a valid proxy tag via @MTProxybot to enable sponsored channel");
}
if tag.len() != 32 || tag.chars().any(|c| !c.is_ascii_hexdigit()) {
warn!("ad_tag is not a 32-char hex string; ensure you use value issued by @MTProxybot");
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn dc_overrides_allow_string_and_array() {
let toml = r#"
[dc_overrides]
"201" = "149.154.175.50:443"
"202" = ["149.154.167.51:443", "149.154.175.100:443"]
"#;
let cfg: ProxyConfig = toml::from_str(toml).unwrap();
assert_eq!(cfg.dc_overrides["201"], vec!["149.154.175.50:443"]);
assert_eq!(
cfg.dc_overrides["202"],
vec!["149.154.167.51:443", "149.154.175.100:443"]
);
}
#[test]
fn dc_overrides_inject_dc203_default() {
let toml = r#"
[general]
use_middle_proxy = false
[censorship]
tls_domain = "example.com"
[access.users]
user = "00000000000000000000000000000000"
"#;
let dir = std::env::temp_dir();
let path = dir.join("telemt_dc_override_test.toml");
std::fs::write(&path, toml).unwrap();
let cfg = ProxyConfig::load(&path).unwrap();
assert!(cfg
.dc_overrides
.get("203")
.map(|v| v.contains(&"91.105.192.100:443".to_string()))
.unwrap_or(false));
let _ = std::fs::remove_file(path);
}
}

View File

@@ -1,365 +1,8 @@
//! Configuration
//! Configuration.
use std::collections::HashMap;
use std::net::IpAddr;
use std::path::Path;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use crate::error::{ProxyError, Result};
pub(crate) mod defaults;
mod types;
mod load;
// ============= Helper Defaults =============
fn default_true() -> bool { true }
fn default_port() -> u16 { 443 }
fn default_tls_domain() -> String { "www.google.com".to_string() }
fn default_mask_port() -> u16 { 443 }
fn default_replay_check_len() -> usize { 65536 }
fn default_handshake_timeout() -> u64 { 15 }
fn default_connect_timeout() -> u64 { 10 }
fn default_keepalive() -> u64 { 60 }
fn default_ack_timeout() -> u64 { 300 }
fn default_listen_addr() -> String { "0.0.0.0".to_string() }
fn default_fake_cert_len() -> usize { 2048 }
fn default_weight() -> u16 { 1 }
fn default_metrics_whitelist() -> Vec<IpAddr> {
vec![
"127.0.0.1".parse().unwrap(),
"::1".parse().unwrap(),
]
}
// ============= Sub-Configs =============
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProxyModes {
#[serde(default)]
pub classic: bool,
#[serde(default)]
pub secure: bool,
#[serde(default = "default_true")]
pub tls: bool,
}
impl Default for ProxyModes {
fn default() -> Self {
Self { classic: true, secure: true, tls: true }
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneralConfig {
#[serde(default)]
pub modes: ProxyModes,
#[serde(default)]
pub prefer_ipv6: bool,
#[serde(default = "default_true")]
pub fast_mode: bool,
#[serde(default)]
pub use_middle_proxy: bool,
#[serde(default)]
pub ad_tag: Option<String>,
}
impl Default for GeneralConfig {
fn default() -> Self {
Self {
modes: ProxyModes::default(),
prefer_ipv6: false,
fast_mode: true,
use_middle_proxy: false,
ad_tag: None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
#[serde(default = "default_port")]
pub port: u16,
#[serde(default = "default_listen_addr")]
pub listen_addr_ipv4: String,
#[serde(default)]
pub listen_addr_ipv6: Option<String>,
#[serde(default)]
pub listen_unix_sock: Option<String>,
#[serde(default)]
pub metrics_port: Option<u16>,
#[serde(default = "default_metrics_whitelist")]
pub metrics_whitelist: Vec<IpAddr>,
#[serde(default)]
pub listeners: Vec<ListenerConfig>,
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
port: default_port(),
listen_addr_ipv4: default_listen_addr(),
listen_addr_ipv6: Some("::".to_string()),
listen_unix_sock: None,
metrics_port: None,
metrics_whitelist: default_metrics_whitelist(),
listeners: Vec::new(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimeoutsConfig {
#[serde(default = "default_handshake_timeout")]
pub client_handshake: u64,
#[serde(default = "default_connect_timeout")]
pub tg_connect: u64,
#[serde(default = "default_keepalive")]
pub client_keepalive: u64,
#[serde(default = "default_ack_timeout")]
pub client_ack: u64,
}
impl Default for TimeoutsConfig {
fn default() -> Self {
Self {
client_handshake: default_handshake_timeout(),
tg_connect: default_connect_timeout(),
client_keepalive: default_keepalive(),
client_ack: default_ack_timeout(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AntiCensorshipConfig {
#[serde(default = "default_tls_domain")]
pub tls_domain: String,
#[serde(default = "default_true")]
pub mask: bool,
#[serde(default)]
pub mask_host: Option<String>,
#[serde(default = "default_mask_port")]
pub mask_port: u16,
#[serde(default = "default_fake_cert_len")]
pub fake_cert_len: usize,
}
impl Default for AntiCensorshipConfig {
fn default() -> Self {
Self {
tls_domain: default_tls_domain(),
mask: true,
mask_host: None,
mask_port: default_mask_port(),
fake_cert_len: default_fake_cert_len(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessConfig {
#[serde(default)]
pub users: HashMap<String, String>,
#[serde(default)]
pub user_max_tcp_conns: HashMap<String, usize>,
#[serde(default)]
pub user_expirations: HashMap<String, DateTime<Utc>>,
#[serde(default)]
pub user_data_quota: HashMap<String, u64>,
#[serde(default = "default_replay_check_len")]
pub replay_check_len: usize,
#[serde(default)]
pub ignore_time_skew: bool,
}
impl Default for AccessConfig {
fn default() -> Self {
let mut users = HashMap::new();
users.insert("default".to_string(), "00000000000000000000000000000000".to_string());
Self {
users,
user_max_tcp_conns: HashMap::new(),
user_expirations: HashMap::new(),
user_data_quota: HashMap::new(),
replay_check_len: default_replay_check_len(),
ignore_time_skew: false,
}
}
}
// ============= Aux Structures =============
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UpstreamType {
Direct {
#[serde(default)]
interface: Option<String>,
},
Socks4 {
address: String,
#[serde(default)]
interface: Option<String>,
#[serde(default)]
user_id: Option<String>,
},
Socks5 {
address: String,
#[serde(default)]
interface: Option<String>,
#[serde(default)]
username: Option<String>,
#[serde(default)]
password: Option<String>,
},
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpstreamConfig {
#[serde(flatten)]
pub upstream_type: UpstreamType,
#[serde(default = "default_weight")]
pub weight: u16,
#[serde(default = "default_true")]
pub enabled: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ListenerConfig {
pub ip: IpAddr,
#[serde(default)]
pub announce_ip: Option<IpAddr>,
}
// ============= Main Config =============
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ProxyConfig {
#[serde(default)]
pub general: GeneralConfig,
#[serde(default)]
pub server: ServerConfig,
#[serde(default)]
pub timeouts: TimeoutsConfig,
#[serde(default)]
pub censorship: AntiCensorshipConfig,
#[serde(default)]
pub access: AccessConfig,
#[serde(default)]
pub upstreams: Vec<UpstreamConfig>,
#[serde(default)]
pub show_link: Vec<String>,
}
impl ProxyConfig {
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
let content = std::fs::read_to_string(path)
.map_err(|e| ProxyError::Config(e.to_string()))?;
let mut config: ProxyConfig = toml::from_str(&content)
.map_err(|e| ProxyError::Config(e.to_string()))?;
// Validate secrets
for (user, secret) in &config.access.users {
if !secret.chars().all(|c| c.is_ascii_hexdigit()) || secret.len() != 32 {
return Err(ProxyError::InvalidSecret {
user: user.clone(),
reason: "Must be 32 hex characters".to_string(),
});
}
}
// Validate tls_domain
if config.censorship.tls_domain.is_empty() {
return Err(ProxyError::Config("tls_domain cannot be empty".to_string()));
}
// Warn if using default tls_domain
if config.censorship.tls_domain == "www.google.com" {
tracing::warn!("Using default tls_domain (www.google.com). Consider setting a custom domain in config.toml");
}
// Default mask_host to tls_domain if not set
if config.censorship.mask_host.is_none() {
tracing::info!("mask_host not set, using tls_domain ({}) for masking", config.censorship.tls_domain);
config.censorship.mask_host = Some(config.censorship.tls_domain.clone());
}
// Random fake_cert_len
use rand::Rng;
config.censorship.fake_cert_len = rand::thread_rng().gen_range(1024..4096);
// Migration: Populate listeners if empty
if config.server.listeners.is_empty() {
if let Ok(ipv4) = config.server.listen_addr_ipv4.parse::<IpAddr>() {
config.server.listeners.push(ListenerConfig {
ip: ipv4,
announce_ip: None,
});
}
if let Some(ipv6_str) = &config.server.listen_addr_ipv6 {
if let Ok(ipv6) = ipv6_str.parse::<IpAddr>() {
config.server.listeners.push(ListenerConfig {
ip: ipv6,
announce_ip: None,
});
}
}
}
// Migration: Populate upstreams if empty (Default Direct)
if config.upstreams.is_empty() {
config.upstreams.push(UpstreamConfig {
upstream_type: UpstreamType::Direct { interface: None },
weight: 1,
enabled: true,
});
}
Ok(config)
}
pub fn validate(&self) -> Result<()> {
if self.access.users.is_empty() {
return Err(ProxyError::Config("No users configured".to_string()));
}
if !self.general.modes.classic && !self.general.modes.secure && !self.general.modes.tls {
return Err(ProxyError::Config("No modes enabled".to_string()));
}
// Validate tls_domain format (basic check)
if self.censorship.tls_domain.contains(' ') || self.censorship.tls_domain.contains('/') {
return Err(ProxyError::Config(
format!("Invalid tls_domain: '{}'. Must be a valid domain name", self.censorship.tls_domain)
));
}
Ok(())
}
}
pub use load::ProxyConfig;
pub use types::*;

514
src/config/types.rs Normal file
View File

@@ -0,0 +1,514 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::IpAddr;
use super::defaults::*;
// ============= Log Level =============
/// Logging verbosity level.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum LogLevel {
/// All messages including trace (trace + debug + info + warn + error).
Debug,
/// Detailed operational logs (debug + info + warn + error).
Verbose,
/// Standard operational logs (info + warn + error).
#[default]
Normal,
/// Minimal output: only warnings and errors (warn + error).
/// Startup messages (config, DC connectivity, proxy links) are always shown
/// via info! before the filter is applied.
Silent,
}
impl LogLevel {
/// Convert to tracing EnvFilter directive string.
pub fn to_filter_str(&self) -> &'static str {
match self {
LogLevel::Debug => "trace",
LogLevel::Verbose => "debug",
LogLevel::Normal => "info",
LogLevel::Silent => "warn",
}
}
/// Parse from a loose string (CLI argument).
pub fn from_str_loose(s: &str) -> Self {
match s.to_lowercase().as_str() {
"debug" | "trace" => LogLevel::Debug,
"verbose" => LogLevel::Verbose,
"normal" | "info" => LogLevel::Normal,
"silent" | "quiet" | "error" | "warn" => LogLevel::Silent,
_ => LogLevel::Normal,
}
}
}
impl std::fmt::Display for LogLevel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
LogLevel::Debug => write!(f, "debug"),
LogLevel::Verbose => write!(f, "verbose"),
LogLevel::Normal => write!(f, "normal"),
LogLevel::Silent => write!(f, "silent"),
}
}
}
// ============= Sub-Configs =============
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProxyModes {
#[serde(default)]
pub classic: bool,
#[serde(default)]
pub secure: bool,
#[serde(default = "default_true")]
pub tls: bool,
}
impl Default for ProxyModes {
fn default() -> Self {
Self {
classic: true,
secure: true,
tls: true,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkConfig {
#[serde(default = "default_true")]
pub ipv4: bool,
/// None = auto-detect IPv6 availability.
#[serde(default)]
pub ipv6: Option<bool>,
/// 4 or 6.
#[serde(default = "default_prefer_4")]
pub prefer: u8,
#[serde(default)]
pub multipath: bool,
}
impl Default for NetworkConfig {
fn default() -> Self {
Self {
ipv4: true,
ipv6: None,
prefer: 4,
multipath: false,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneralConfig {
#[serde(default)]
pub modes: ProxyModes,
#[serde(default)]
pub prefer_ipv6: bool,
#[serde(default = "default_true")]
pub fast_mode: bool,
#[serde(default)]
pub use_middle_proxy: bool,
#[serde(default)]
pub ad_tag: Option<String>,
/// Path to proxy-secret binary file (auto-downloaded if absent).
/// Infrastructure secret from https://core.telegram.org/getProxySecret.
#[serde(default)]
pub proxy_secret_path: Option<String>,
/// Public IP override for middle-proxy NAT environments.
/// When set, this IP is used in ME key derivation and RPC_PROXY_REQ "our_addr".
#[serde(default)]
pub middle_proxy_nat_ip: Option<IpAddr>,
/// Enable STUN-based NAT probing to discover public IP:port for ME KDF.
#[serde(default)]
pub middle_proxy_nat_probe: bool,
/// Optional STUN server address (host:port) for NAT probing.
#[serde(default)]
pub middle_proxy_nat_stun: Option<String>,
/// Ignore STUN/interface IP mismatch (keep using Middle Proxy even if NAT detected).
#[serde(default)]
pub stun_iface_mismatch_ignore: bool,
/// Log unknown (non-standard) DC requests to a file (default: unknown-dc.txt). Set to null to disable.
#[serde(default = "default_unknown_dc_log_path")]
pub unknown_dc_log_path: Option<String>,
#[serde(default)]
pub log_level: LogLevel,
/// Disable colored output in logs (useful for files/systemd).
#[serde(default)]
pub disable_colors: bool,
/// [general.links] — proxy link generation overrides.
#[serde(default)]
pub links: LinksConfig,
}
impl Default for GeneralConfig {
fn default() -> Self {
Self {
modes: ProxyModes::default(),
prefer_ipv6: false,
fast_mode: true,
use_middle_proxy: false,
ad_tag: None,
proxy_secret_path: None,
middle_proxy_nat_ip: None,
middle_proxy_nat_probe: false,
middle_proxy_nat_stun: None,
stun_iface_mismatch_ignore: false,
unknown_dc_log_path: default_unknown_dc_log_path(),
log_level: LogLevel::Normal,
disable_colors: false,
links: LinksConfig::default(),
}
}
}
/// `[general.links]` — proxy link generation settings.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct LinksConfig {
/// List of usernames whose tg:// links to display at startup.
/// `"*"` = all users, `["alice", "bob"]` = specific users.
#[serde(default)]
pub show: ShowLink,
/// Public hostname/IP for tg:// link generation (overrides detected IP).
#[serde(default)]
pub public_host: Option<String>,
/// Public port for tg:// link generation (overrides server.port).
#[serde(default)]
pub public_port: Option<u16>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
#[serde(default = "default_port")]
pub port: u16,
#[serde(default)]
pub listen_addr_ipv4: Option<String>,
#[serde(default)]
pub listen_addr_ipv6: Option<String>,
#[serde(default)]
pub listen_unix_sock: Option<String>,
/// Unix socket file permissions (octal, e.g. "0666" or "0777").
/// Applied via chmod after bind. Default: no change (inherits umask).
#[serde(default)]
pub listen_unix_sock_perm: Option<String>,
/// Enable TCP listening. Default: true when no unix socket, false when
/// listen_unix_sock is set. Set explicitly to override auto-detection.
#[serde(default)]
pub listen_tcp: Option<bool>,
#[serde(default)]
pub metrics_port: Option<u16>,
#[serde(default = "default_metrics_whitelist")]
pub metrics_whitelist: Vec<IpAddr>,
#[serde(default)]
pub listeners: Vec<ListenerConfig>,
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
port: default_port(),
listen_addr_ipv4: Some(default_listen_addr()),
listen_addr_ipv6: Some("::".to_string()),
listen_unix_sock: None,
listen_unix_sock_perm: None,
listen_tcp: None,
metrics_port: None,
metrics_whitelist: default_metrics_whitelist(),
listeners: Vec::new(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimeoutsConfig {
#[serde(default = "default_handshake_timeout")]
pub client_handshake: u64,
#[serde(default = "default_connect_timeout")]
pub tg_connect: u64,
#[serde(default = "default_keepalive")]
pub client_keepalive: u64,
#[serde(default = "default_ack_timeout")]
pub client_ack: u64,
/// Number of quick ME reconnect attempts for single-address DC.
#[serde(default = "default_me_one_retry")]
pub me_one_retry: u8,
/// Timeout per quick attempt in milliseconds for single-address DC.
#[serde(default = "default_me_one_timeout")]
pub me_one_timeout_ms: u64,
}
impl Default for TimeoutsConfig {
fn default() -> Self {
Self {
client_handshake: default_handshake_timeout(),
tg_connect: default_connect_timeout(),
client_keepalive: default_keepalive(),
client_ack: default_ack_timeout(),
me_one_retry: default_me_one_retry(),
me_one_timeout_ms: default_me_one_timeout(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AntiCensorshipConfig {
#[serde(default = "default_tls_domain")]
pub tls_domain: String,
#[serde(default = "default_true")]
pub mask: bool,
#[serde(default)]
pub mask_host: Option<String>,
#[serde(default = "default_mask_port")]
pub mask_port: u16,
#[serde(default)]
pub mask_unix_sock: Option<String>,
#[serde(default = "default_fake_cert_len")]
pub fake_cert_len: usize,
}
impl Default for AntiCensorshipConfig {
fn default() -> Self {
Self {
tls_domain: default_tls_domain(),
mask: true,
mask_host: None,
mask_port: default_mask_port(),
mask_unix_sock: None,
fake_cert_len: default_fake_cert_len(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessConfig {
#[serde(default)]
pub users: HashMap<String, String>,
#[serde(default)]
pub user_max_tcp_conns: HashMap<String, usize>,
#[serde(default)]
pub user_expirations: HashMap<String, DateTime<Utc>>,
#[serde(default)]
pub user_data_quota: HashMap<String, u64>,
#[serde(default)]
pub user_max_unique_ips: HashMap<String, usize>,
#[serde(default = "default_replay_check_len")]
pub replay_check_len: usize,
#[serde(default = "default_replay_window_secs")]
pub replay_window_secs: u64,
#[serde(default)]
pub ignore_time_skew: bool,
}
impl Default for AccessConfig {
fn default() -> Self {
let mut users = HashMap::new();
users.insert(
"default".to_string(),
"00000000000000000000000000000000".to_string(),
);
Self {
users,
user_max_tcp_conns: HashMap::new(),
user_expirations: HashMap::new(),
user_data_quota: HashMap::new(),
user_max_unique_ips: HashMap::new(),
replay_check_len: default_replay_check_len(),
replay_window_secs: default_replay_window_secs(),
ignore_time_skew: false,
}
}
}
// ============= Aux Structures =============
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UpstreamType {
Direct {
#[serde(default)]
interface: Option<String>,
},
Socks4 {
address: String,
#[serde(default)]
interface: Option<String>,
#[serde(default)]
user_id: Option<String>,
},
Socks5 {
address: String,
#[serde(default)]
interface: Option<String>,
#[serde(default)]
username: Option<String>,
#[serde(default)]
password: Option<String>,
},
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpstreamConfig {
#[serde(flatten)]
pub upstream_type: UpstreamType,
#[serde(default = "default_weight")]
pub weight: u16,
#[serde(default = "default_true")]
pub enabled: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ListenerConfig {
pub ip: IpAddr,
/// IP address or hostname to announce in proxy links.
/// Takes precedence over `announce_ip` if both are set.
#[serde(default)]
pub announce: Option<String>,
/// Deprecated: Use `announce` instead. IP address to announce in proxy links.
/// Migrated to `announce` automatically if `announce` is not set.
#[serde(default)]
pub announce_ip: Option<IpAddr>,
}
// ============= ShowLink =============
/// Controls which users' proxy links are displayed at startup.
///
/// In TOML, this can be:
/// - `show_link = "*"` — show links for all users
/// - `show_link = ["a", "b"]` — show links for specific users
/// - omitted — show no links (default)
#[derive(Debug, Clone)]
pub enum ShowLink {
/// Don't show any links (default when omitted).
None,
/// Show links for all configured users.
All,
/// Show links for specific users.
Specific(Vec<String>),
}
impl Default for ShowLink {
fn default() -> Self {
ShowLink::None
}
}
impl ShowLink {
/// Returns true if no links should be shown.
pub fn is_empty(&self) -> bool {
matches!(self, ShowLink::None) || matches!(self, ShowLink::Specific(v) if v.is_empty())
}
/// Resolve the list of user names to display, given all configured users.
pub fn resolve_users<'a>(&'a self, all_users: &'a HashMap<String, String>) -> Vec<&'a String> {
match self {
ShowLink::None => vec![],
ShowLink::All => {
let mut names: Vec<&String> = all_users.keys().collect();
names.sort();
names
}
ShowLink::Specific(names) => names.iter().collect(),
}
}
}
impl Serialize for ShowLink {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
match self {
ShowLink::None => Vec::<String>::new().serialize(serializer),
ShowLink::All => serializer.serialize_str("*"),
ShowLink::Specific(v) => v.serialize(serializer),
}
}
}
impl<'de> Deserialize<'de> for ShowLink {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> std::result::Result<Self, D::Error> {
use serde::de;
struct ShowLinkVisitor;
impl<'de> de::Visitor<'de> for ShowLinkVisitor {
type Value = ShowLink;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str(r#""*" or an array of user names"#)
}
fn visit_str<E: de::Error>(self, v: &str) -> std::result::Result<ShowLink, E> {
if v == "*" {
Ok(ShowLink::All)
} else {
Err(de::Error::invalid_value(
de::Unexpected::Str(v),
&r#""*""#,
))
}
}
fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> std::result::Result<ShowLink, A::Error> {
let mut names = Vec::new();
while let Some(name) = seq.next_element::<String>()? {
names.push(name);
}
if names.is_empty() {
Ok(ShowLink::None)
} else {
Ok(ShowLink::Specific(names))
}
}
}
deserializer.deserialize_any(ShowLinkVisitor)
}
}

View File

@@ -1,9 +1,19 @@
//! AES encryption implementations
//!
//! Provides AES-256-CTR and AES-256-CBC modes for MTProto encryption.
//!
//! ## Zeroize policy
//!
//! - `AesCbc` stores raw key/IV bytes and zeroizes them on drop.
//! - `AesCtr` wraps an opaque `Aes256Ctr` cipher from the `ctr` crate.
//! The expanded key schedule lives inside that type and cannot be
//! zeroized from outside. Callers that hold raw key material (e.g.
//! `HandshakeSuccess`, `ObfuscationParams`) are responsible for
//! zeroizing their own copies.
use aes::Aes256;
use ctr::{Ctr128BE, cipher::{KeyIvInit, StreamCipher}};
use zeroize::Zeroize;
use crate::error::{ProxyError, Result};
type Aes256Ctr = Ctr128BE<Aes256>;
@@ -12,7 +22,12 @@ type Aes256Ctr = Ctr128BE<Aes256>;
/// AES-256-CTR encryptor/decryptor
///
/// CTR mode is symmetric - encryption and decryption are the same operation.
/// CTR mode is symmetric encryption and decryption are the same operation.
///
/// **Zeroize note:** The inner `Aes256Ctr` cipher state (expanded key schedule
/// + counter) is opaque and cannot be zeroized. If you need to protect key
/// material, zeroize the `[u8; 32]` key and `u128` IV at the call site
/// before dropping them.
pub struct AesCtr {
cipher: Aes256Ctr,
}
@@ -62,14 +77,23 @@ impl AesCtr {
/// AES-256-CBC cipher with proper chaining
///
/// Unlike CTR mode, CBC is NOT symmetric - encryption and decryption
/// Unlike CTR mode, CBC is NOT symmetric encryption and decryption
/// are different operations. This implementation handles CBC chaining
/// correctly across multiple blocks.
///
/// Key and IV are zeroized on drop.
pub struct AesCbc {
key: [u8; 32],
iv: [u8; 16],
}
impl Drop for AesCbc {
fn drop(&mut self) {
self.key.zeroize();
self.iv.zeroize();
}
}
impl AesCbc {
/// AES block size
const BLOCK_SIZE: usize = 16;
@@ -141,17 +165,9 @@ impl AesCbc {
for chunk in data.chunks(Self::BLOCK_SIZE) {
let plaintext: [u8; 16] = chunk.try_into().unwrap();
// XOR plaintext with previous ciphertext (or IV for first block)
let xored = Self::xor_blocks(&plaintext, &prev_ciphertext);
// Encrypt the XORed block
let ciphertext = self.encrypt_block(&xored, &key_schedule);
// Save for next iteration
prev_ciphertext = ciphertext;
// Append to result
result.extend_from_slice(&ciphertext);
}
@@ -180,17 +196,9 @@ impl AesCbc {
for chunk in data.chunks(Self::BLOCK_SIZE) {
let ciphertext: [u8; 16] = chunk.try_into().unwrap();
// Decrypt the block
let decrypted = self.decrypt_block(&ciphertext, &key_schedule);
// XOR with previous ciphertext (or IV for first block)
let plaintext = Self::xor_blocks(&decrypted, &prev_ciphertext);
// Save current ciphertext for next iteration
prev_ciphertext = ciphertext;
// Append to result
result.extend_from_slice(&plaintext);
}
@@ -217,16 +225,13 @@ impl AesCbc {
for i in (0..data.len()).step_by(Self::BLOCK_SIZE) {
let block = &mut data[i..i + Self::BLOCK_SIZE];
// XOR with previous ciphertext
for j in 0..Self::BLOCK_SIZE {
block[j] ^= prev_ciphertext[j];
}
// Encrypt in-place
let block_array: &mut [u8; 16] = block.try_into().unwrap();
*block_array = self.encrypt_block(block_array, &key_schedule);
// Save for next iteration
prev_ciphertext = *block_array;
}
@@ -248,26 +253,20 @@ impl AesCbc {
use aes::cipher::KeyInit;
let key_schedule = aes::Aes256::new((&self.key).into());
// For in-place decryption, we need to save ciphertext blocks
// before we overwrite them
let mut prev_ciphertext = self.iv;
for i in (0..data.len()).step_by(Self::BLOCK_SIZE) {
let block = &mut data[i..i + Self::BLOCK_SIZE];
// Save current ciphertext before modifying
let current_ciphertext: [u8; 16] = block.try_into().unwrap();
// Decrypt in-place
let block_array: &mut [u8; 16] = block.try_into().unwrap();
*block_array = self.decrypt_block(block_array, &key_schedule);
// XOR with previous ciphertext
for j in 0..Self::BLOCK_SIZE {
block[j] ^= prev_ciphertext[j];
}
// Save for next iteration
prev_ciphertext = current_ciphertext;
}
@@ -347,10 +346,8 @@ mod tests {
let mut cipher = AesCtr::new(&key, iv);
cipher.apply(&mut data);
// Encrypted should be different
assert_ne!(&data[..], original);
// Decrypt with fresh cipher
let mut cipher = AesCtr::new(&key, iv);
cipher.apply(&mut data);
@@ -364,7 +361,7 @@ mod tests {
let key = [0u8; 32];
let iv = [0u8; 16];
let original = [0u8; 32]; // 2 blocks
let original = [0u8; 32];
let cipher = AesCbc::new(key, iv);
let encrypted = cipher.encrypt(&original).unwrap();
@@ -375,31 +372,25 @@ mod tests {
#[test]
fn test_aes_cbc_chaining_works() {
// This is the key test - verify CBC chaining is correct
let key = [0x42u8; 32];
let iv = [0x00u8; 16];
// Two IDENTICAL plaintext blocks
let plaintext = [0xAAu8; 32];
let cipher = AesCbc::new(key, iv);
let ciphertext = cipher.encrypt(&plaintext).unwrap();
// With proper CBC, identical plaintext blocks produce DIFFERENT ciphertext
let block1 = &ciphertext[0..16];
let block2 = &ciphertext[16..32];
assert_ne!(
block1, block2,
"CBC chaining broken: identical plaintext blocks produced identical ciphertext. \
This indicates ECB mode, not CBC!"
"CBC chaining broken: identical plaintext blocks produced identical ciphertext"
);
}
#[test]
fn test_aes_cbc_known_vector() {
// Test with known NIST test vector
// AES-256-CBC with zero key and zero IV
let key = [0u8; 32];
let iv = [0u8; 16];
let plaintext = [0u8; 16];
@@ -407,11 +398,9 @@ mod tests {
let cipher = AesCbc::new(key, iv);
let ciphertext = cipher.encrypt(&plaintext).unwrap();
// Decrypt and verify roundtrip
let decrypted = cipher.decrypt(&ciphertext).unwrap();
assert_eq!(plaintext.as_slice(), decrypted.as_slice());
// Ciphertext should not be all zeros
assert_ne!(ciphertext.as_slice(), plaintext.as_slice());
}
@@ -420,7 +409,6 @@ mod tests {
let key = [0x12u8; 32];
let iv = [0x34u8; 16];
// 5 blocks = 80 bytes
let plaintext: Vec<u8> = (0..80).collect();
let cipher = AesCbc::new(key, iv);
@@ -435,7 +423,7 @@ mod tests {
let key = [0x12u8; 32];
let iv = [0x34u8; 16];
let original = [0x56u8; 48]; // 3 blocks
let original = [0x56u8; 48];
let mut buffer = original;
let cipher = AesCbc::new(key, iv);
@@ -462,41 +450,33 @@ mod tests {
fn test_aes_cbc_unaligned_error() {
let cipher = AesCbc::new([0u8; 32], [0u8; 16]);
// 15 bytes - not aligned to block size
let result = cipher.encrypt(&[0u8; 15]);
assert!(result.is_err());
// 17 bytes - not aligned
let result = cipher.encrypt(&[0u8; 17]);
assert!(result.is_err());
}
#[test]
fn test_aes_cbc_avalanche_effect() {
// Changing one bit in plaintext should change entire ciphertext block
// and all subsequent blocks (due to chaining)
let key = [0xAB; 32];
let iv = [0xCD; 16];
let mut plaintext1 = [0u8; 32];
let plaintext1 = [0u8; 32];
let mut plaintext2 = [0u8; 32];
plaintext2[0] = 0x01; // Single bit difference in first block
plaintext2[0] = 0x01;
let cipher = AesCbc::new(key, iv);
let ciphertext1 = cipher.encrypt(&plaintext1).unwrap();
let ciphertext2 = cipher.encrypt(&plaintext2).unwrap();
// First blocks should be different
assert_ne!(&ciphertext1[0..16], &ciphertext2[0..16]);
// Second blocks should ALSO be different (chaining effect)
assert_ne!(&ciphertext1[16..32], &ciphertext2[16..32]);
}
#[test]
fn test_aes_cbc_iv_matters() {
// Same plaintext with different IVs should produce different ciphertext
let key = [0x55; 32];
let plaintext = [0x77u8; 16];
@@ -511,7 +491,6 @@ mod tests {
#[test]
fn test_aes_cbc_deterministic() {
// Same key, IV, plaintext should always produce same ciphertext
let key = [0x99; 32];
let iv = [0x88; 16];
let plaintext = [0x77u8; 32];
@@ -524,6 +503,23 @@ mod tests {
assert_eq!(ciphertext1, ciphertext2);
}
// ============= Zeroize Tests =============
#[test]
fn test_aes_cbc_zeroize_on_drop() {
let key = [0xAA; 32];
let iv = [0xBB; 16];
let cipher = AesCbc::new(key, iv);
// Verify key/iv are set
assert_eq!(cipher.key, [0xAA; 32]);
assert_eq!(cipher.iv, [0xBB; 16]);
drop(cipher);
// After drop, key/iv are zeroized (can't observe directly,
// but the Drop impl runs without panic)
}
// ============= Error Handling Tests =============
#[test]

View File

@@ -1,3 +1,16 @@
//! Cryptographic hash functions
//!
//! ## Protocol-required algorithms
//!
//! This module exposes MD5 and SHA-1 alongside SHA-256. These weaker
//! hash functions are **required by the Telegram Middle Proxy protocol**
//! (`derive_middleproxy_keys`) and cannot be replaced without breaking
//! compatibility. They are NOT used for any security-sensitive purpose
//! outside of that specific key derivation scheme mandated by Telegram.
//!
//! Static analysis tools (CodeQL, cargo-audit) may flag them — the
//! usages are intentional and protocol-mandated.
use hmac::{Hmac, Mac};
use sha2::Sha256;
use md5::Md5;
@@ -21,14 +34,16 @@ pub fn sha256_hmac(key: &[u8], data: &[u8]) -> [u8; 32] {
mac.finalize().into_bytes().into()
}
/// SHA-1
/// SHA-1 — **protocol-required** by Telegram Middle Proxy key derivation.
/// Not used for general-purpose hashing.
pub fn sha1(data: &[u8]) -> [u8; 20] {
let mut hasher = Sha1::new();
hasher.update(data);
hasher.finalize().into()
}
/// MD5
/// MD5 — **protocol-required** by Telegram Middle Proxy key derivation.
/// Not used for general-purpose hashing.
pub fn md5(data: &[u8]) -> [u8; 16] {
let mut hasher = Md5::new();
hasher.update(data);
@@ -40,7 +55,54 @@ pub fn crc32(data: &[u8]) -> u32 {
crc32fast::hash(data)
}
/// Middle Proxy Keygen
/// Build the exact prekey buffer used by Telegram Middle Proxy KDF.
///
/// Returned buffer layout (IPv4):
/// nonce_srv | nonce_clt | clt_ts | srv_ip | clt_port | purpose | clt_ip | srv_port | secret | nonce_srv | [clt_v6 | srv_v6] | nonce_clt
pub fn build_middleproxy_prekey(
nonce_srv: &[u8; 16],
nonce_clt: &[u8; 16],
clt_ts: &[u8; 4],
srv_ip: Option<&[u8]>,
clt_port: &[u8; 2],
purpose: &[u8],
clt_ip: Option<&[u8]>,
srv_port: &[u8; 2],
secret: &[u8],
clt_ipv6: Option<&[u8; 16]>,
srv_ipv6: Option<&[u8; 16]>,
) -> Vec<u8> {
const EMPTY_IP: [u8; 4] = [0, 0, 0, 0];
let srv_ip = srv_ip.unwrap_or(&EMPTY_IP);
let clt_ip = clt_ip.unwrap_or(&EMPTY_IP);
let mut s = Vec::with_capacity(256);
s.extend_from_slice(nonce_srv);
s.extend_from_slice(nonce_clt);
s.extend_from_slice(clt_ts);
s.extend_from_slice(srv_ip);
s.extend_from_slice(clt_port);
s.extend_from_slice(purpose);
s.extend_from_slice(clt_ip);
s.extend_from_slice(srv_port);
s.extend_from_slice(secret);
s.extend_from_slice(nonce_srv);
if let (Some(clt_v6), Some(srv_v6)) = (clt_ipv6, srv_ipv6) {
s.extend_from_slice(clt_v6);
s.extend_from_slice(srv_v6);
}
s.extend_from_slice(nonce_clt);
s
}
/// Middle Proxy key derivation
///
/// Uses MD5 + SHA-1 as mandated by the Telegram Middle Proxy protocol.
/// These algorithms are NOT replaceable here — changing them would break
/// interoperability with Telegram's middle proxy infrastructure.
pub fn derive_middleproxy_keys(
nonce_srv: &[u8; 16],
nonce_clt: &[u8; 16],
@@ -54,30 +116,20 @@ pub fn derive_middleproxy_keys(
clt_ipv6: Option<&[u8; 16]>,
srv_ipv6: Option<&[u8; 16]>,
) -> ([u8; 32], [u8; 16]) {
const EMPTY_IP: [u8; 4] = [0, 0, 0, 0];
let srv_ip = srv_ip.unwrap_or(&EMPTY_IP);
let clt_ip = clt_ip.unwrap_or(&EMPTY_IP);
let mut s = Vec::with_capacity(256);
s.extend_from_slice(nonce_srv);
s.extend_from_slice(nonce_clt);
s.extend_from_slice(clt_ts);
s.extend_from_slice(srv_ip);
s.extend_from_slice(clt_port);
s.extend_from_slice(purpose);
s.extend_from_slice(clt_ip);
s.extend_from_slice(srv_port);
s.extend_from_slice(secret);
s.extend_from_slice(nonce_srv);
if let (Some(clt_v6), Some(srv_v6)) = (clt_ipv6, srv_ipv6) {
s.extend_from_slice(clt_v6);
s.extend_from_slice(srv_v6);
}
s.extend_from_slice(nonce_clt);
let s = build_middleproxy_prekey(
nonce_srv,
nonce_clt,
clt_ts,
srv_ip,
clt_port,
purpose,
clt_ip,
srv_port,
secret,
clt_ipv6,
srv_ipv6,
);
let md5_1 = md5(&s[1..]);
let sha1_sum = sha1(&s);
let md5_2 = md5(&s[2..]);
@@ -87,4 +139,40 @@ pub fn derive_middleproxy_keys(
key[12..].copy_from_slice(&sha1_sum);
(key, md5_2)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn middleproxy_prekey_sha_is_stable() {
let nonce_srv = [0x11u8; 16];
let nonce_clt = [0x22u8; 16];
let clt_ts = 0x44332211u32.to_le_bytes();
let srv_ip = Some([149u8, 154, 175, 50].as_ref());
let clt_ip = Some([10u8, 0, 0, 1].as_ref());
let clt_port = 0x1f90u16.to_le_bytes(); // 8080
let srv_port = 0x22b8u16.to_le_bytes(); // 8888
let secret = vec![0x55u8; 128];
let prekey = build_middleproxy_prekey(
&nonce_srv,
&nonce_clt,
&clt_ts,
srv_ip,
&clt_port,
b"CLIENT",
clt_ip,
&srv_port,
&secret,
None,
None,
);
let digest = sha256(&prekey);
assert_eq!(
hex::encode(digest),
"934f5facdafd65a44d5c2df90d2f35ddc81faaaeb337949dfeef817c8a7c1e00"
);
}
}

View File

@@ -5,5 +5,5 @@ pub mod hash;
pub mod random;
pub use aes::{AesCtr, AesCbc};
pub use hash::{sha256, sha256_hmac, sha1, md5, crc32};
pub use random::{SecureRandom, SECURE_RANDOM};
pub use hash::{sha256, sha256_hmac, sha1, md5, crc32, derive_middleproxy_keys, build_middleproxy_prekey};
pub use random::SecureRandom;

View File

@@ -3,11 +3,8 @@
use rand::{Rng, RngCore, SeedableRng};
use rand::rngs::StdRng;
use parking_lot::Mutex;
use zeroize::Zeroize;
use crate::crypto::AesCtr;
use once_cell::sync::Lazy;
/// Global secure random instance
pub static SECURE_RANDOM: Lazy<SecureRandom> = Lazy::new(SecureRandom::new);
/// Cryptographically secure PRNG with AES-CTR
pub struct SecureRandom {
@@ -20,18 +17,30 @@ struct SecureRandomInner {
buffer: Vec<u8>,
}
impl Drop for SecureRandomInner {
fn drop(&mut self) {
self.buffer.zeroize();
}
}
impl SecureRandom {
pub fn new() -> Self {
let mut rng = StdRng::from_entropy();
let mut seed_source = rand::rng();
let mut rng = StdRng::from_rng(&mut seed_source);
let mut key = [0u8; 32];
rng.fill_bytes(&mut key);
let iv: u128 = rng.gen();
let iv: u128 = rng.random();
let cipher = AesCtr::new(&key, iv);
// Zeroize local key copy — cipher already consumed it
key.zeroize();
Self {
inner: Mutex::new(SecureRandomInner {
rng,
cipher: AesCtr::new(&key, iv),
cipher,
buffer: Vec::with_capacity(1024),
}),
}
@@ -78,7 +87,6 @@ impl SecureRandom {
result |= (b as u64) << (i * 8);
}
// Mask extra bits
if k < 64 {
result &= (1u64 << k) - 1;
}
@@ -107,13 +115,13 @@ impl SecureRandom {
/// Generate random u32
pub fn u32(&self) -> u32 {
let mut inner = self.inner.lock();
inner.rng.gen()
inner.rng.random()
}
/// Generate random u64
pub fn u64(&self) -> u64 {
let mut inner = self.inner.lock();
inner.rng.gen()
inner.rng.random()
}
}
@@ -162,12 +170,10 @@ mod tests {
fn test_bits() {
let rng = SecureRandom::new();
// Single bit should be 0 or 1
for _ in 0..100 {
assert!(rng.bits(1) <= 1);
}
// 8 bits should be 0-255
for _ in 0..100 {
assert!(rng.bits(8) <= 255);
}
@@ -185,10 +191,8 @@ mod tests {
}
}
// Should have seen all items
assert_eq!(seen.len(), 5);
// Empty slice should return None
let empty: Vec<i32> = vec![];
assert!(rng.choose(&empty).is_none());
}
@@ -201,12 +205,10 @@ mod tests {
let mut shuffled = original.clone();
rng.shuffle(&mut shuffled);
// Should contain same elements
let mut sorted = shuffled.clone();
sorted.sort();
assert_eq!(sorted, original);
// Should be different order (with very high probability)
assert_ne!(shuffled, original);
}
}

View File

@@ -118,16 +118,13 @@ pub trait Recoverable {
impl Recoverable for StreamError {
fn is_recoverable(&self) -> bool {
match self {
// Partial operations can be retried
Self::PartialRead { .. } | Self::PartialWrite { .. } => true,
// I/O errors depend on kind
Self::Io(e) => matches!(
e.kind(),
std::io::ErrorKind::WouldBlock
| std::io::ErrorKind::Interrupted
| std::io::ErrorKind::TimedOut
),
// These are not recoverable
Self::Poisoned { .. }
| Self::BufferOverflow { .. }
| Self::InvalidFrame { .. }
@@ -137,13 +134,9 @@ impl Recoverable for StreamError {
fn can_continue(&self) -> bool {
match self {
// Poisoned stream cannot be used
Self::Poisoned { .. } => false,
// EOF means stream is done
Self::UnexpectedEof => false,
// Buffer overflow is fatal
Self::BufferOverflow { .. } => false,
// Others might allow continuation
_ => true,
}
}
@@ -383,18 +376,18 @@ mod tests {
#[test]
fn test_handshake_result() {
let success: HandshakeResult<i32> = HandshakeResult::Success(42);
let success: HandshakeResult<i32, (), ()> = HandshakeResult::Success(42);
assert!(success.is_success());
assert!(!success.is_bad_client());
let bad: HandshakeResult<i32> = HandshakeResult::BadClient;
let bad: HandshakeResult<i32, (), ()> = HandshakeResult::BadClient { reader: (), writer: () };
assert!(!bad.is_success());
assert!(bad.is_bad_client());
}
#[test]
fn test_handshake_result_map() {
let success: HandshakeResult<i32> = HandshakeResult::Success(42);
let success: HandshakeResult<i32, (), ()> = HandshakeResult::Success(42);
let mapped = success.map(|x| x * 2);
match mapped {

462
src/ip_tracker.rs Normal file
View File

@@ -0,0 +1,462 @@
// src/ip_tracker.rs
// Модуль для отслеживания и ограничения уникальных IP-адресов пользователей
use std::collections::{HashMap, HashSet};
use std::net::IpAddr;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Трекер уникальных IP-адресов для каждого пользователя MTProxy
///
/// Предоставляет thread-safe механизм для:
/// - Отслеживания активных IP-адресов каждого пользователя
/// - Ограничения количества уникальных IP на пользователя
/// - Автоматической очистки при отключении клиентов
#[derive(Debug, Clone)]
pub struct UserIpTracker {
/// Маппинг: Имя пользователя -> Множество активных IP-адресов
active_ips: Arc<RwLock<HashMap<String, HashSet<IpAddr>>>>,
/// Маппинг: Имя пользователя -> Максимально разрешенное количество уникальных IP
max_ips: Arc<RwLock<HashMap<String, usize>>>,
}
impl UserIpTracker {
/// Создать новый пустой трекер
pub fn new() -> Self {
Self {
active_ips: Arc::new(RwLock::new(HashMap::new())),
max_ips: Arc::new(RwLock::new(HashMap::new())),
}
}
/// Установить лимит уникальных IP для конкретного пользователя
///
/// # Arguments
/// * `username` - Имя пользователя
/// * `max_ips` - Максимальное количество одновременно активных IP-адресов
pub async fn set_user_limit(&self, username: &str, max_ips: usize) {
let mut limits = self.max_ips.write().await;
limits.insert(username.to_string(), max_ips);
}
/// Загрузить лимиты из конфигурации
///
/// # Arguments
/// * `limits` - HashMap с лимитами из config.toml
pub async fn load_limits(&self, limits: &HashMap<String, usize>) {
let mut max_ips = self.max_ips.write().await;
for (user, limit) in limits {
max_ips.insert(user.clone(), *limit);
}
}
/// Проверить, может ли пользователь подключиться с данного IP-адреса
/// и добавить IP в список активных, если проверка успешна
///
/// # Arguments
/// * `username` - Имя пользователя
/// * `ip` - IP-адрес клиента
///
/// # Returns
/// * `Ok(())` - Подключение разрешено, IP добавлен в активные
/// * `Err(String)` - Подключение отклонено с описанием причины
pub async fn check_and_add(&self, username: &str, ip: IpAddr) -> Result<(), String> {
// Получаем лимит для пользователя
let max_ips = self.max_ips.read().await;
let limit = match max_ips.get(username) {
Some(limit) => *limit,
None => {
// Если лимит не задан - разрешаем безлимитный доступ
drop(max_ips);
let mut active_ips = self.active_ips.write().await;
let user_ips = active_ips
.entry(username.to_string())
.or_insert_with(HashSet::new);
user_ips.insert(ip);
return Ok(());
}
};
drop(max_ips);
// Проверяем и обновляем активные IP
let mut active_ips = self.active_ips.write().await;
let user_ips = active_ips
.entry(username.to_string())
.or_insert_with(HashSet::new);
// Если IP уже есть в списке - это повторное подключение, разрешаем
if user_ips.contains(&ip) {
return Ok(());
}
// Проверяем, не превышен ли лимит
if user_ips.len() >= limit {
return Err(format!(
"IP limit reached for user '{}': {}/{} unique IPs already connected",
username,
user_ips.len(),
limit
));
}
// Лимит не превышен - добавляем новый IP
user_ips.insert(ip);
Ok(())
}
/// Удалить IP-адрес из списка активных при отключении клиента
///
/// # Arguments
/// * `username` - Имя пользователя
/// * `ip` - IP-адрес отключившегося клиента
pub async fn remove_ip(&self, username: &str, ip: IpAddr) {
let mut active_ips = self.active_ips.write().await;
if let Some(user_ips) = active_ips.get_mut(username) {
user_ips.remove(&ip);
// Если у пользователя не осталось активных IP - удаляем запись
// для экономии памяти
if user_ips.is_empty() {
active_ips.remove(username);
}
}
}
/// Получить текущее количество активных IP-адресов для пользователя
///
/// # Arguments
/// * `username` - Имя пользователя
///
/// # Returns
/// Количество уникальных активных IP-адресов
pub async fn get_active_ip_count(&self, username: &str) -> usize {
let active_ips = self.active_ips.read().await;
active_ips
.get(username)
.map(|ips| ips.len())
.unwrap_or(0)
}
/// Получить список всех активных IP-адресов для пользователя
///
/// # Arguments
/// * `username` - Имя пользователя
///
/// # Returns
/// Вектор с активными IP-адресами
pub async fn get_active_ips(&self, username: &str) -> Vec<IpAddr> {
let active_ips = self.active_ips.read().await;
active_ips
.get(username)
.map(|ips| ips.iter().copied().collect())
.unwrap_or_else(Vec::new)
}
/// Получить статистику по всем пользователям
///
/// # Returns
/// Вектор кортежей: (имя_пользователя, количество_активных_IP, лимит)
pub async fn get_stats(&self) -> Vec<(String, usize, usize)> {
let active_ips = self.active_ips.read().await;
let max_ips = self.max_ips.read().await;
let mut stats = Vec::new();
// Собираем статистику по пользователям с активными подключениями
for (username, user_ips) in active_ips.iter() {
let limit = max_ips.get(username).copied().unwrap_or(0);
stats.push((username.clone(), user_ips.len(), limit));
}
stats.sort_by(|a, b| a.0.cmp(&b.0)); // Сортируем по имени пользователя
stats
}
/// Очистить все активные IP для пользователя (при необходимости)
///
/// # Arguments
/// * `username` - Имя пользователя
pub async fn clear_user_ips(&self, username: &str) {
let mut active_ips = self.active_ips.write().await;
active_ips.remove(username);
}
/// Очистить всю статистику (использовать с осторожностью!)
pub async fn clear_all(&self) {
let mut active_ips = self.active_ips.write().await;
active_ips.clear();
}
/// Проверить, подключен ли пользователь с данного IP
///
/// # Arguments
/// * `username` - Имя пользователя
/// * `ip` - IP-адрес для проверки
///
/// # Returns
/// `true` если IP активен, `false` если нет
pub async fn is_ip_active(&self, username: &str, ip: IpAddr) -> bool {
let active_ips = self.active_ips.read().await;
active_ips
.get(username)
.map(|ips| ips.contains(&ip))
.unwrap_or(false)
}
/// Получить лимит для пользователя
///
/// # Arguments
/// * `username` - Имя пользователя
///
/// # Returns
/// Лимит IP-адресов или None, если лимит не установлен
pub async fn get_user_limit(&self, username: &str) -> Option<usize> {
let max_ips = self.max_ips.read().await;
max_ips.get(username).copied()
}
/// Форматировать статистику в читаемый текст
///
/// # Returns
/// Строка со статистикой для логов или мониторинга
pub async fn format_stats(&self) -> String {
let stats = self.get_stats().await;
if stats.is_empty() {
return String::from("No active users");
}
let mut output = String::from("User IP Statistics:\n");
output.push_str("==================\n");
for (username, active_count, limit) in stats {
output.push_str(&format!(
"User: {:<20} Active IPs: {}/{}\n",
username,
active_count,
if limit > 0 { limit.to_string() } else { "unlimited".to_string() }
));
let ips = self.get_active_ips(&username).await;
for ip in ips {
output.push_str(&format!(" └─ {}\n", ip));
}
}
output
}
}
impl Default for UserIpTracker {
fn default() -> Self {
Self::new()
}
}
// ============================================================================
// ТЕСТЫ
// ============================================================================
#[cfg(test)]
mod tests {
use super::*;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
fn test_ipv4(oct1: u8, oct2: u8, oct3: u8, oct4: u8) -> IpAddr {
IpAddr::V4(Ipv4Addr::new(oct1, oct2, oct3, oct4))
}
fn test_ipv6() -> IpAddr {
IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1))
}
#[tokio::test]
async fn test_basic_ip_limit() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("test_user", 2).await;
let ip1 = test_ipv4(192, 168, 1, 1);
let ip2 = test_ipv4(192, 168, 1, 2);
let ip3 = test_ipv4(192, 168, 1, 3);
// Первые два IP должны быть приняты
assert!(tracker.check_and_add("test_user", ip1).await.is_ok());
assert!(tracker.check_and_add("test_user", ip2).await.is_ok());
// Третий IP должен быть отклонен
assert!(tracker.check_and_add("test_user", ip3).await.is_err());
// Проверяем счетчик
assert_eq!(tracker.get_active_ip_count("test_user").await, 2);
}
#[tokio::test]
async fn test_reconnection_from_same_ip() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("test_user", 2).await;
let ip1 = test_ipv4(192, 168, 1, 1);
// Первое подключение
assert!(tracker.check_and_add("test_user", ip1).await.is_ok());
// Повторное подключение с того же IP должно пройти
assert!(tracker.check_and_add("test_user", ip1).await.is_ok());
// Счетчик не должен увеличиться
assert_eq!(tracker.get_active_ip_count("test_user").await, 1);
}
#[tokio::test]
async fn test_ip_removal() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("test_user", 2).await;
let ip1 = test_ipv4(192, 168, 1, 1);
let ip2 = test_ipv4(192, 168, 1, 2);
let ip3 = test_ipv4(192, 168, 1, 3);
// Добавляем два IP
assert!(tracker.check_and_add("test_user", ip1).await.is_ok());
assert!(tracker.check_and_add("test_user", ip2).await.is_ok());
// Третий не должен пройти
assert!(tracker.check_and_add("test_user", ip3).await.is_err());
// Удаляем первый IP
tracker.remove_ip("test_user", ip1).await;
// Теперь третий должен пройти
assert!(tracker.check_and_add("test_user", ip3).await.is_ok());
assert_eq!(tracker.get_active_ip_count("test_user").await, 2);
}
#[tokio::test]
async fn test_no_limit() {
let tracker = UserIpTracker::new();
// Не устанавливаем лимит для test_user
let ip1 = test_ipv4(192, 168, 1, 1);
let ip2 = test_ipv4(192, 168, 1, 2);
let ip3 = test_ipv4(192, 168, 1, 3);
// Без лимита все IP должны проходить
assert!(tracker.check_and_add("test_user", ip1).await.is_ok());
assert!(tracker.check_and_add("test_user", ip2).await.is_ok());
assert!(tracker.check_and_add("test_user", ip3).await.is_ok());
assert_eq!(tracker.get_active_ip_count("test_user").await, 3);
}
#[tokio::test]
async fn test_multiple_users() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("user1", 2).await;
tracker.set_user_limit("user2", 1).await;
let ip1 = test_ipv4(192, 168, 1, 1);
let ip2 = test_ipv4(192, 168, 1, 2);
// user1 может использовать 2 IP
assert!(tracker.check_and_add("user1", ip1).await.is_ok());
assert!(tracker.check_and_add("user1", ip2).await.is_ok());
// user2 может использовать только 1 IP
assert!(tracker.check_and_add("user2", ip1).await.is_ok());
assert!(tracker.check_and_add("user2", ip2).await.is_err());
}
#[tokio::test]
async fn test_ipv6_support() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("test_user", 2).await;
let ipv4 = test_ipv4(192, 168, 1, 1);
let ipv6 = test_ipv6();
// Должны работать оба типа адресов
assert!(tracker.check_and_add("test_user", ipv4).await.is_ok());
assert!(tracker.check_and_add("test_user", ipv6).await.is_ok());
assert_eq!(tracker.get_active_ip_count("test_user").await, 2);
}
#[tokio::test]
async fn test_get_active_ips() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("test_user", 3).await;
let ip1 = test_ipv4(192, 168, 1, 1);
let ip2 = test_ipv4(192, 168, 1, 2);
tracker.check_and_add("test_user", ip1).await.unwrap();
tracker.check_and_add("test_user", ip2).await.unwrap();
let active_ips = tracker.get_active_ips("test_user").await;
assert_eq!(active_ips.len(), 2);
assert!(active_ips.contains(&ip1));
assert!(active_ips.contains(&ip2));
}
#[tokio::test]
async fn test_stats() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("user1", 3).await;
tracker.set_user_limit("user2", 2).await;
let ip1 = test_ipv4(192, 168, 1, 1);
let ip2 = test_ipv4(192, 168, 1, 2);
tracker.check_and_add("user1", ip1).await.unwrap();
tracker.check_and_add("user2", ip2).await.unwrap();
let stats = tracker.get_stats().await;
assert_eq!(stats.len(), 2);
// Проверяем наличие обоих пользователей в статистике
assert!(stats.iter().any(|(name, _, _)| name == "user1"));
assert!(stats.iter().any(|(name, _, _)| name == "user2"));
}
#[tokio::test]
async fn test_clear_user_ips() {
let tracker = UserIpTracker::new();
let ip1 = test_ipv4(192, 168, 1, 1);
tracker.check_and_add("test_user", ip1).await.unwrap();
assert_eq!(tracker.get_active_ip_count("test_user").await, 1);
tracker.clear_user_ips("test_user").await;
assert_eq!(tracker.get_active_ip_count("test_user").await, 0);
}
#[tokio::test]
async fn test_is_ip_active() {
let tracker = UserIpTracker::new();
let ip1 = test_ipv4(192, 168, 1, 1);
let ip2 = test_ipv4(192, 168, 1, 2);
tracker.check_and_add("test_user", ip1).await.unwrap();
assert!(tracker.is_ip_active("test_user", ip1).await);
assert!(!tracker.is_ip_active("test_user", ip2).await);
}
#[tokio::test]
async fn test_load_limits_from_config() {
let tracker = UserIpTracker::new();
let mut config_limits = HashMap::new();
config_limits.insert("user1".to_string(), 5);
config_limits.insert("user2".to_string(), 3);
tracker.load_limits(&config_limits).await;
assert_eq!(tracker.get_user_limit("user1").await, Some(5));
assert_eq!(tracker.get_user_limit("user2").await, Some(3));
assert_eq!(tracker.get_user_limit("user3").await, None);
}
}

View File

@@ -1,16 +1,23 @@
//! Telemt - MTProxy on Rust
//! telemt — Telegram MTProto Proxy
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::signal;
use tracing::{info, error, warn};
use tracing_subscriber::{fmt, EnvFilter};
use tokio::sync::Semaphore;
use tracing::{debug, error, info, warn};
use tracing_subscriber::{EnvFilter, fmt, prelude::*, reload};
#[cfg(unix)]
use tokio::net::UnixListener;
mod cli;
mod config;
mod crypto;
mod error;
mod ip_tracker;
mod network;
mod metrics;
mod protocol;
mod proxy;
mod stats;
@@ -18,155 +25,746 @@ mod stream;
mod transport;
mod util;
use crate::config::ProxyConfig;
use crate::config::{LogLevel, ProxyConfig};
use crate::crypto::SecureRandom;
use crate::ip_tracker::UserIpTracker;
use crate::network::probe::{decide_network_capabilities, log_probe_result, run_probe};
use crate::proxy::ClientHandler;
use crate::stats::{Stats, ReplayChecker};
use crate::transport::{create_listener, ListenOptions, UpstreamManager};
use crate::util::ip::detect_ip;
use crate::stats::{ReplayChecker, Stats};
use crate::stream::BufferPool;
use crate::transport::middle_proxy::{
MePool, fetch_proxy_config, run_me_ping, MePingFamily, MePingSample, format_sample_line,
};
use crate::transport::{ListenOptions, UpstreamManager, create_listener};
fn parse_cli() -> (String, bool, Option<String>) {
let mut config_path = "config.toml".to_string();
let mut silent = false;
let mut log_level: Option<String> = None;
let args: Vec<String> = std::env::args().skip(1).collect();
// Check for --init first (handled before tokio)
if let Some(init_opts) = cli::parse_init_args(&args) {
if let Err(e) = cli::run_init(init_opts) {
eprintln!("[telemt] Init failed: {}", e);
std::process::exit(1);
}
std::process::exit(0);
}
let mut i = 0;
while i < args.len() {
match args[i].as_str() {
"--silent" | "-s" => {
silent = true;
}
"--log-level" => {
i += 1;
if i < args.len() {
log_level = Some(args[i].clone());
}
}
s if s.starts_with("--log-level=") => {
log_level = Some(s.trim_start_matches("--log-level=").to_string());
}
"--help" | "-h" => {
eprintln!("Usage: telemt [config.toml] [OPTIONS]");
eprintln!();
eprintln!("Options:");
eprintln!(" --silent, -s Suppress info logs");
eprintln!(" --log-level <LEVEL> debug|verbose|normal|silent");
eprintln!(" --help, -h Show this help");
eprintln!();
eprintln!("Setup (fire-and-forget):");
eprintln!(
" --init Generate config, install systemd service, start"
);
eprintln!(" --port <PORT> Listen port (default: 443)");
eprintln!(
" --domain <DOMAIN> TLS domain for masking (default: www.google.com)"
);
eprintln!(
" --secret <HEX> 32-char hex secret (auto-generated if omitted)"
);
eprintln!(" --user <NAME> Username (default: user)");
eprintln!(" --config-dir <DIR> Config directory (default: /etc/telemt)");
eprintln!(" --no-start Don't start the service after install");
std::process::exit(0);
}
s if !s.starts_with('-') => {
config_path = s.to_string();
}
other => {
eprintln!("Unknown option: {}", other);
}
}
i += 1;
}
(config_path, silent, log_level)
}
fn print_proxy_links(host: &str, port: u16, config: &ProxyConfig) {
info!("--- Proxy Links ({}) ---", host);
for user_name in config.general.links.show.resolve_users(&config.access.users) {
if let Some(secret) = config.access.users.get(user_name) {
info!("User: {}", user_name);
if config.general.modes.classic {
info!(
" Classic: tg://proxy?server={}&port={}&secret={}",
host, port, secret
);
}
if config.general.modes.secure {
info!(
" DD: tg://proxy?server={}&port={}&secret=dd{}",
host, port, secret
);
}
if config.general.modes.tls {
let domain_hex = hex::encode(&config.censorship.tls_domain);
info!(
" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}",
host, port, secret, domain_hex
);
}
} else {
warn!("User '{}' in show_link not found", user_name);
}
}
info!("------------------------");
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Initialize logging
fmt()
.with_env_filter(EnvFilter::from_default_env().add_directive("info".parse().unwrap()))
.init();
async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let (config_path, cli_silent, cli_log_level) = parse_cli();
// Load config
let config_path = std::env::args().nth(1).unwrap_or_else(|| "config.toml".to_string());
let config = match ProxyConfig::load(&config_path) {
let mut config = match ProxyConfig::load(&config_path) {
Ok(c) => c,
Err(e) => {
// If config doesn't exist, try to create default
if std::path::Path::new(&config_path).exists() {
error!("Failed to load config: {}", e);
eprintln!("[telemt] Error: {}", e);
std::process::exit(1);
} else {
let default = ProxyConfig::default();
let toml = toml::to_string_pretty(&default).unwrap();
std::fs::write(&config_path, toml).unwrap();
info!("Created default config at {}", config_path);
std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap();
eprintln!("[telemt] Created default config at {}", config_path);
default
}
}
};
if let Err(e) = config.validate() {
eprintln!("[telemt] Invalid config: {}", e);
std::process::exit(1);
}
let has_rust_log = std::env::var("RUST_LOG").is_ok();
let effective_log_level = if cli_silent {
LogLevel::Silent
} else if let Some(ref s) = cli_log_level {
LogLevel::from_str_loose(s)
} else {
config.general.log_level.clone()
};
let (filter_layer, filter_handle) = reload::Layer::new(EnvFilter::new("info"));
config.validate()?;
// Configure color output based on config
let fmt_layer = if config.general.disable_colors {
fmt::Layer::default().with_ansi(false)
} else {
fmt::Layer::default().with_ansi(true)
};
// Log loaded configuration for debugging
info!("=== Configuration Loaded ===");
info!("TLS Domain: {}", config.censorship.tls_domain);
info!("Mask enabled: {}", config.censorship.mask);
info!("Mask host: {}", config.censorship.mask_host.as_deref().unwrap_or(&config.censorship.tls_domain));
info!("Mask port: {}", config.censorship.mask_port);
info!("Modes: classic={}, secure={}, tls={}",
config.general.modes.classic,
config.general.modes.secure,
config.general.modes.tls
tracing_subscriber::registry()
.with(filter_layer)
.with(fmt_layer)
.init();
info!("Telemt MTProxy v{}", env!("CARGO_PKG_VERSION"));
info!("Log level: {}", effective_log_level);
if config.general.disable_colors {
info!("Colors: disabled");
}
info!(
"Modes: classic={} secure={} tls={}",
config.general.modes.classic, config.general.modes.secure, config.general.modes.tls
);
info!("============================");
let config = Arc::new(config);
info!("TLS domain: {}", config.censorship.tls_domain);
if let Some(ref sock) = config.censorship.mask_unix_sock {
info!("Mask: {} -> unix:{}", config.censorship.mask, sock);
if !std::path::Path::new(sock).exists() {
warn!(
"Unix socket '{}' does not exist yet. Masking will fail until it appears.",
sock
);
}
} else {
info!(
"Mask: {} -> {}:{}",
config.censorship.mask,
config
.censorship
.mask_host
.as_deref()
.unwrap_or(&config.censorship.tls_domain),
config.censorship.mask_port
);
}
if config.censorship.tls_domain == "www.google.com" {
warn!("Using default tls_domain. Consider setting a custom domain.");
}
let probe = run_probe(
&config.network,
config.general.middle_proxy_nat_stun.clone(),
config.general.middle_proxy_nat_probe,
)
.await?;
let decision = decide_network_capabilities(&config.network, &probe);
log_probe_result(&probe, &decision);
let prefer_ipv6 = decision.prefer_ipv6();
let mut use_middle_proxy = config.general.use_middle_proxy && (decision.ipv4_me || decision.ipv6_me);
let stats = Arc::new(Stats::new());
let rng = Arc::new(SecureRandom::new());
// IP Tracker initialization
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.load_limits(&config.access.user_max_unique_ips).await;
// Initialize global ReplayChecker
// Using sharded implementation for better concurrency
let replay_checker = Arc::new(ReplayChecker::new(config.access.replay_check_len));
// Initialize Upstream Manager
if !config.access.user_max_unique_ips.is_empty() {
info!("IP limits configured for {} users", config.access.user_max_unique_ips.len());
}
// Connection concurrency limit
let _max_connections = Arc::new(Semaphore::new(10_000));
if use_middle_proxy && !decision.ipv4_me && !decision.ipv6_me {
warn!("No usable IP family for Middle Proxy detected; falling back to direct DC");
use_middle_proxy = false;
}
// =====================================================================
// Middle Proxy initialization (if enabled)
// =====================================================================
let me_pool: Option<Arc<MePool>> = if use_middle_proxy {
info!("=== Middle Proxy Mode ===");
// ad_tag (proxy_tag) for advertising
let proxy_tag = config.general.ad_tag.as_ref().map(|tag| {
hex::decode(tag).unwrap_or_else(|_| {
warn!("Invalid ad_tag hex, middle proxy ad_tag will be empty");
Vec::new()
})
});
// =============================================================
// CRITICAL: Download Telegram proxy-secret (NOT user secret!)
//
// C MTProxy uses TWO separate secrets:
// -S flag = 16-byte user secret for client obfuscation
// --aes-pwd = 32-512 byte binary file for ME RPC auth
//
// proxy-secret is from: https://core.telegram.org/getProxySecret
// =============================================================
let proxy_secret_path = config.general.proxy_secret_path.as_deref();
match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).await {
Ok(proxy_secret) => {
info!(
secret_len = proxy_secret.len() as usize, // ← ЯВНЫЙ ТИП usize
key_sig = format_args!(
"0x{:08x}",
if proxy_secret.len() >= 4 {
u32::from_le_bytes([
proxy_secret[0],
proxy_secret[1],
proxy_secret[2],
proxy_secret[3],
])
} else {
0
}
),
"Proxy-secret loaded"
);
// Load ME config (v4/v6) + default DC
let mut cfg_v4 = fetch_proxy_config(
"https://core.telegram.org/getProxyConfig",
)
.await
.unwrap_or_default();
let mut cfg_v6 = fetch_proxy_config(
"https://core.telegram.org/getProxyConfigV6",
)
.await
.unwrap_or_default();
if cfg_v4.map.is_empty() {
cfg_v4.map = crate::protocol::constants::TG_MIDDLE_PROXIES_V4.clone();
}
if cfg_v6.map.is_empty() {
cfg_v6.map = crate::protocol::constants::TG_MIDDLE_PROXIES_V6.clone();
}
let pool = MePool::new(
proxy_tag,
proxy_secret,
config.general.middle_proxy_nat_ip,
config.general.middle_proxy_nat_probe,
config.general.middle_proxy_nat_stun.clone(),
probe.detected_ipv6,
config.timeouts.me_one_retry,
config.timeouts.me_one_timeout_ms,
cfg_v4.map.clone(),
cfg_v6.map.clone(),
cfg_v4.default_dc.or(cfg_v6.default_dc),
decision.clone(),
rng.clone(),
);
match pool.init(2, &rng).await {
Ok(()) => {
info!("Middle-End pool initialized successfully");
// Phase 4: Start health monitor
let pool_clone = pool.clone();
let rng_clone = rng.clone();
tokio::spawn(async move {
crate::transport::middle_proxy::me_health_monitor(
pool_clone, rng_clone, 2,
)
.await;
});
// Periodic ME connection rotation
let pool_clone_rot = pool.clone();
let rng_clone_rot = rng.clone();
tokio::spawn(async move {
crate::transport::middle_proxy::me_rotation_task(
pool_clone_rot,
rng_clone_rot,
std::time::Duration::from_secs(1800),
)
.await;
});
// Periodic updater: getProxyConfig + proxy-secret
let pool_clone2 = pool.clone();
let rng_clone2 = rng.clone();
tokio::spawn(async move {
crate::transport::middle_proxy::me_config_updater(
pool_clone2,
rng_clone2,
std::time::Duration::from_secs(12 * 3600),
)
.await;
});
Some(pool)
}
Err(e) => {
error!(error = %e, "Failed to initialize ME pool. Falling back to direct mode.");
None
}
}
}
Err(e) => {
error!(error = %e, "Failed to fetch proxy-secret. Falling back to direct mode.");
None
}
}
} else {
None
};
// If ME failed to initialize, force direct-only mode.
if me_pool.is_some() {
info!("Transport: Middle-End Proxy - all DC-over-RPC");
} else {
use_middle_proxy = false;
// Make runtime config reflect direct-only mode for handlers.
config.general.use_middle_proxy = false;
info!("Transport: Direct DC - TCP - standard DC-over-TCP");
}
// Freeze config after possible fallback decision
let config = Arc::new(config);
let replay_checker = Arc::new(ReplayChecker::new(
config.access.replay_check_len,
Duration::from_secs(config.access.replay_window_secs),
));
let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone()));
// Initialize Buffer Pool
// 16KB buffers, max 4096 buffers (~64MB total cached)
let buffer_pool = Arc::new(BufferPool::with_config(16 * 1024, 4096));
// Start Health Checks
// Middle-End ping before DC connectivity
if let Some(ref pool) = me_pool {
let me_results = run_me_ping(pool, &rng).await;
let v4_ok = me_results.iter().any(|r| {
matches!(r.family, MePingFamily::V4)
&& r.samples.iter().any(|s| s.error.is_none() && s.handshake_ms.is_some())
});
let v6_ok = me_results.iter().any(|r| {
matches!(r.family, MePingFamily::V6)
&& r.samples.iter().any(|s| s.error.is_none() && s.handshake_ms.is_some())
});
info!("================= Telegram ME Connectivity =================");
if v4_ok && v6_ok {
info!(" IPv4 and IPv6 available");
} else if v4_ok {
info!(" IPv4 only / IPv6 unavailable");
} else if v6_ok {
info!(" IPv6 only / IPv4 unavailable");
} else {
info!(" No ME connectivity");
}
info!(" via direct");
info!("============================================================");
use std::collections::BTreeMap;
let mut grouped: BTreeMap<i32, Vec<MePingSample>> = BTreeMap::new();
for report in me_results {
for s in report.samples {
let key = s.dc.abs();
grouped.entry(key).or_default().push(s);
}
}
let family_order = if prefer_ipv6 {
vec![(MePingFamily::V6, true), (MePingFamily::V6, false), (MePingFamily::V4, true), (MePingFamily::V4, false)]
} else {
vec![(MePingFamily::V4, true), (MePingFamily::V4, false), (MePingFamily::V6, true), (MePingFamily::V6, false)]
};
for (dc_abs, samples) in grouped {
for (family, is_pos) in &family_order {
let fam_samples: Vec<&MePingSample> = samples
.iter()
.filter(|s| matches!(s.family, f if &f == family) && (s.dc >= 0) == *is_pos)
.collect();
if fam_samples.is_empty() {
continue;
}
let fam_label = match family {
MePingFamily::V4 => "IPv4",
MePingFamily::V6 => "IPv6",
};
info!(" DC{} [{}]", dc_abs, fam_label);
for sample in fam_samples {
let line = format_sample_line(sample);
info!("{}", line);
}
}
}
info!("============================================================");
}
info!("================= Telegram DC Connectivity =================");
let ping_results = upstream_manager
.ping_all_dcs(
prefer_ipv6,
&config.dc_overrides,
decision.ipv4_dc,
decision.ipv6_dc,
)
.await;
for upstream_result in &ping_results {
let v6_works = upstream_result
.v6_results
.iter()
.any(|r| r.rtt_ms.is_some());
let v4_works = upstream_result
.v4_results
.iter()
.any(|r| r.rtt_ms.is_some());
if upstream_result.both_available {
if prefer_ipv6 {
info!(" IPv6 in use / IPv4 is fallback");
} else {
info!(" IPv4 in use / IPv6 is fallback");
}
} else {
if v6_works && !v4_works {
info!(" IPv6 only / IPv4 unavailable)");
} else if v4_works && !v6_works {
info!(" IPv4 only / IPv6 unavailable)");
} else if !v6_works && !v4_works {
info!(" No DC connectivity");
}
}
info!(" via {}", upstream_result.upstream_name);
info!("============================================================");
// Print IPv6 results first (only if IPv6 is available)
if v6_works {
for dc in &upstream_result.v6_results {
let addr_str = format!("{}:{}", dc.dc_addr.ip(), dc.dc_addr.port());
match &dc.rtt_ms {
Some(rtt) => {
info!(" DC{} [IPv6] {} - {:.0} ms", dc.dc_idx, addr_str, rtt);
}
None => {
let err = dc.error.as_deref().unwrap_or("fail");
info!(" DC{} [IPv6] {} - FAIL ({})", dc.dc_idx, addr_str, err);
}
}
}
info!("============================================================");
}
// Print IPv4 results (only if IPv4 is available)
if v4_works {
for dc in &upstream_result.v4_results {
let addr_str = format!("{}:{}", dc.dc_addr.ip(), dc.dc_addr.port());
match &dc.rtt_ms {
Some(rtt) => {
info!(
" DC{} [IPv4] {}\t\t\t\t{:.0} ms",
dc.dc_idx, addr_str, rtt
);
}
None => {
let err = dc.error.as_deref().unwrap_or("fail");
info!(
" DC{} [IPv4] {}:\t\t\t\tFAIL ({})",
dc.dc_idx, addr_str, err
);
}
}
}
info!("============================================================");
}
}
// Background tasks
let um_clone = upstream_manager.clone();
let decision_clone = decision.clone();
tokio::spawn(async move {
um_clone.run_health_checks().await;
um_clone
.run_health_checks(
prefer_ipv6,
decision_clone.ipv4_dc,
decision_clone.ipv6_dc,
)
.await;
});
// Detect public IP if needed (once at startup)
let detected_ip = detect_ip().await;
let rc_clone = replay_checker.clone();
tokio::spawn(async move {
rc_clone.run_periodic_cleanup().await;
});
let detected_ip_v4: Option<std::net::IpAddr> = probe
.reflected_ipv4
.map(|s| s.ip())
.or_else(|| probe.detected_ipv4.map(std::net::IpAddr::V4));
let detected_ip_v6: Option<std::net::IpAddr> = probe
.reflected_ipv6
.map(|s| s.ip())
.or_else(|| probe.detected_ipv6.map(std::net::IpAddr::V6));
debug!(
"Detected IPs: v4={:?} v6={:?}",
detected_ip_v4, detected_ip_v6
);
// Start Listeners
let mut listeners = Vec::new();
for listener_conf in &config.server.listeners {
let addr = SocketAddr::new(listener_conf.ip, config.server.port);
if addr.is_ipv4() && !decision.ipv4_dc {
warn!(%addr, "Skipping IPv4 listener: IPv4 disabled by [network]");
continue;
}
if addr.is_ipv6() && !decision.ipv6_dc {
warn!(%addr, "Skipping IPv6 listener: IPv6 disabled by [network]");
continue;
}
let options = ListenOptions {
ipv6_only: listener_conf.ip.is_ipv6(),
..Default::default()
};
match create_listener(addr, &options) {
Ok(socket) => {
let listener = TcpListener::from_std(socket.into())?;
info!("Listening on {}", addr);
// Determine public IP for tg:// links
let public_ip = if let Some(ip) = listener_conf.announce_ip {
ip
// Resolve the public host for link generation
let public_host = if let Some(ref announce) = listener_conf.announce {
announce.clone() // Use announce (IP or hostname) if explicitly set
} else if listener_conf.ip.is_unspecified() {
// Auto-detect for unspecified addresses
if listener_conf.ip.is_ipv4() {
detected_ip.ipv4.unwrap_or(listener_conf.ip)
detected_ip_v4
.map(|ip| ip.to_string())
.unwrap_or_else(|| listener_conf.ip.to_string())
} else {
detected_ip.ipv6.unwrap_or(listener_conf.ip)
detected_ip_v6
.map(|ip| ip.to_string())
.unwrap_or_else(|| listener_conf.ip.to_string())
}
} else {
listener_conf.ip
listener_conf.ip.to_string()
};
// Show links for configured users
if !config.show_link.is_empty() {
info!("--- Proxy Links for {} ---", public_ip);
for user_name in &config.show_link {
if let Some(secret) = config.access.users.get(user_name) {
info!("User: {}", user_name);
if config.general.modes.classic {
info!(" Classic: tg://proxy?server={}&port={}&secret={}",
public_ip, config.server.port, secret);
}
if config.general.modes.secure {
info!(" DD: tg://proxy?server={}&port={}&secret=dd{}",
public_ip, config.server.port, secret);
}
if config.general.modes.tls {
let domain_hex = hex::encode(&config.censorship.tls_domain);
info!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}",
public_ip, config.server.port, secret, domain_hex);
}
} else {
warn!("User '{}' specified in show_link not found in users list", user_name);
}
}
info!("-----------------------------------");
// Show per-listener proxy links only when public_host is not set
if config.general.links.public_host.is_none() && !config.general.links.show.is_empty() {
let link_port = config.general.links.public_port.unwrap_or(config.server.port);
print_proxy_links(&public_host, link_port, &config);
}
listeners.push(listener);
},
}
Err(e) => {
error!("Failed to bind to {}: {}", addr, e);
}
}
}
if listeners.is_empty() {
error!("No listeners could be started. Exiting.");
// Show proxy links once when public_host is set, OR when there are no TCP listeners
// (unix-only mode) — use detected IP as fallback
if !config.general.links.show.is_empty() && (config.general.links.public_host.is_some() || listeners.is_empty()) {
let (host, port) = if let Some(ref h) = config.general.links.public_host {
(h.clone(), config.general.links.public_port.unwrap_or(config.server.port))
} else {
let ip = detected_ip_v4
.or(detected_ip_v6)
.map(|ip| ip.to_string());
if ip.is_none() {
warn!("show_link is configured but public IP could not be detected. Set public_host in config.");
}
(ip.unwrap_or_else(|| "UNKNOWN".to_string()), config.general.links.public_port.unwrap_or(config.server.port))
};
print_proxy_links(&host, port, &config);
}
// Unix socket setup (before listeners check so unix-only config works)
let mut has_unix_listener = false;
#[cfg(unix)]
if let Some(ref unix_path) = config.server.listen_unix_sock {
// Remove stale socket file if present (standard practice)
let _ = tokio::fs::remove_file(unix_path).await;
let unix_listener = UnixListener::bind(unix_path)?;
// Apply socket permissions if configured
if let Some(ref perm_str) = config.server.listen_unix_sock_perm {
match u32::from_str_radix(perm_str.trim_start_matches('0'), 8) {
Ok(mode) => {
use std::os::unix::fs::PermissionsExt;
let perms = std::fs::Permissions::from_mode(mode);
if let Err(e) = std::fs::set_permissions(unix_path, perms) {
error!("Failed to set unix socket permissions to {}: {}", perm_str, e);
} else {
info!("Listening on unix:{} (mode {})", unix_path, perm_str);
}
}
Err(e) => {
warn!("Invalid listen_unix_sock_perm '{}': {}. Ignoring.", perm_str, e);
info!("Listening on unix:{}", unix_path);
}
}
} else {
info!("Listening on unix:{}", unix_path);
}
has_unix_listener = true;
let config = config.clone();
let stats = stats.clone();
let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone();
let rng = rng.clone();
let me_pool = me_pool.clone();
let ip_tracker = ip_tracker.clone();
tokio::spawn(async move {
let unix_conn_counter = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(1));
loop {
match unix_listener.accept().await {
Ok((stream, _)) => {
let conn_id = unix_conn_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let fake_peer = SocketAddr::from(([127, 0, 0, 1], (conn_id % 65535) as u16));
let config = config.clone();
let stats = stats.clone();
let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone();
let rng = rng.clone();
let me_pool = me_pool.clone();
let ip_tracker = ip_tracker.clone();
tokio::spawn(async move {
if let Err(e) = crate::proxy::client::handle_client_stream(
stream, fake_peer, config, stats,
upstream_manager, replay_checker, buffer_pool, rng,
me_pool, ip_tracker,
).await {
debug!(error = %e, "Unix socket connection error");
}
});
}
Err(e) => {
error!("Unix socket accept error: {}", e);
tokio::time::sleep(Duration::from_millis(100)).await;
}
}
}
});
}
if listeners.is_empty() && !has_unix_listener {
error!("No listeners. Exiting.");
std::process::exit(1);
}
// Accept loop
// Switch to user-configured log level after startup
let runtime_filter = if has_rust_log {
EnvFilter::from_default_env()
} else {
EnvFilter::new(effective_log_level.to_filter_str())
};
filter_handle
.reload(runtime_filter)
.expect("Failed to switch log filter");
if let Some(port) = config.server.metrics_port {
let stats = stats.clone();
let whitelist = config.server.metrics_whitelist.clone();
tokio::spawn(async move {
metrics::serve(port, stats, whitelist).await;
});
}
for listener in listeners {
let config = config.clone();
let stats = stats.clone();
let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone();
let rng = rng.clone();
let me_pool = me_pool.clone();
let ip_tracker = ip_tracker.clone();
tokio::spawn(async move {
loop {
match listener.accept().await {
@@ -176,18 +774,27 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone();
let rng = rng.clone();
let me_pool = me_pool.clone();
let ip_tracker = ip_tracker.clone();
tokio::spawn(async move {
if let Err(e) = ClientHandler::new(
stream,
peer_addr,
config,
stream,
peer_addr,
config,
stats,
upstream_manager,
replay_checker,
buffer_pool
).run().await {
// Log only relevant errors
buffer_pool,
rng,
me_pool,
ip_tracker,
)
.run()
.await
{
debug!(peer = %peer_addr, error = %e, "Connection error");
}
});
}
@@ -200,11 +807,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
});
}
// Wait for signal
match signal::ctrl_c().await {
Ok(()) => info!("Shutting down..."),
Err(e) => error!("Signal error: {}", e),
}
Ok(())
}
}

197
src/metrics.rs Normal file
View File

@@ -0,0 +1,197 @@
use std::convert::Infallible;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use http_body_util::Full;
use hyper::body::Bytes;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Request, Response, StatusCode};
use tokio::net::TcpListener;
use tracing::{info, warn, debug};
use crate::stats::Stats;
pub async fn serve(port: u16, stats: Arc<Stats>, whitelist: Vec<IpAddr>) {
let addr = SocketAddr::from(([0, 0, 0, 0], port));
let listener = match TcpListener::bind(addr).await {
Ok(l) => l,
Err(e) => {
warn!(error = %e, "Failed to bind metrics on {}", addr);
return;
}
};
info!("Metrics endpoint: http://{}/metrics", addr);
loop {
let (stream, peer) = match listener.accept().await {
Ok(v) => v,
Err(e) => {
warn!(error = %e, "Metrics accept error");
continue;
}
};
if !whitelist.is_empty() && !whitelist.contains(&peer.ip()) {
debug!(peer = %peer, "Metrics request denied by whitelist");
continue;
}
let stats = stats.clone();
tokio::spawn(async move {
let svc = service_fn(move |req| {
let stats = stats.clone();
async move { handle(req, &stats) }
});
if let Err(e) = http1::Builder::new()
.serve_connection(hyper_util::rt::TokioIo::new(stream), svc)
.await
{
debug!(error = %e, "Metrics connection error");
}
});
}
}
fn handle(req: Request<hyper::body::Incoming>, stats: &Stats) -> Result<Response<Full<Bytes>>, Infallible> {
if req.uri().path() != "/metrics" {
let resp = Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Full::new(Bytes::from("Not Found\n")))
.unwrap();
return Ok(resp);
}
let body = render_metrics(stats);
let resp = Response::builder()
.status(StatusCode::OK)
.header("content-type", "text/plain; version=0.0.4; charset=utf-8")
.body(Full::new(Bytes::from(body)))
.unwrap();
Ok(resp)
}
fn render_metrics(stats: &Stats) -> String {
use std::fmt::Write;
let mut out = String::with_capacity(4096);
let _ = writeln!(out, "# HELP telemt_uptime_seconds Proxy uptime");
let _ = writeln!(out, "# TYPE telemt_uptime_seconds gauge");
let _ = writeln!(out, "telemt_uptime_seconds {:.1}", stats.uptime_secs());
let _ = writeln!(out, "# HELP telemt_connections_total Total accepted connections");
let _ = writeln!(out, "# TYPE telemt_connections_total counter");
let _ = writeln!(out, "telemt_connections_total {}", stats.get_connects_all());
let _ = writeln!(out, "# HELP telemt_connections_bad_total Bad/rejected connections");
let _ = writeln!(out, "# TYPE telemt_connections_bad_total counter");
let _ = writeln!(out, "telemt_connections_bad_total {}", stats.get_connects_bad());
let _ = writeln!(out, "# HELP telemt_handshake_timeouts_total Handshake timeouts");
let _ = writeln!(out, "# TYPE telemt_handshake_timeouts_total counter");
let _ = writeln!(out, "telemt_handshake_timeouts_total {}", stats.get_handshake_timeouts());
let _ = writeln!(out, "# HELP telemt_user_connections_total Per-user total connections");
let _ = writeln!(out, "# TYPE telemt_user_connections_total counter");
let _ = writeln!(out, "# HELP telemt_user_connections_current Per-user active connections");
let _ = writeln!(out, "# TYPE telemt_user_connections_current gauge");
let _ = writeln!(out, "# HELP telemt_user_octets_from_client Per-user bytes received");
let _ = writeln!(out, "# TYPE telemt_user_octets_from_client counter");
let _ = writeln!(out, "# HELP telemt_user_octets_to_client Per-user bytes sent");
let _ = writeln!(out, "# TYPE telemt_user_octets_to_client counter");
let _ = writeln!(out, "# HELP telemt_user_msgs_from_client Per-user messages received");
let _ = writeln!(out, "# TYPE telemt_user_msgs_from_client counter");
let _ = writeln!(out, "# HELP telemt_user_msgs_to_client Per-user messages sent");
let _ = writeln!(out, "# TYPE telemt_user_msgs_to_client counter");
for entry in stats.iter_user_stats() {
let user = entry.key();
let s = entry.value();
let _ = writeln!(out, "telemt_user_connections_total{{user=\"{}\"}} {}", user, s.connects.load(std::sync::atomic::Ordering::Relaxed));
let _ = writeln!(out, "telemt_user_connections_current{{user=\"{}\"}} {}", user, s.curr_connects.load(std::sync::atomic::Ordering::Relaxed));
let _ = writeln!(out, "telemt_user_octets_from_client{{user=\"{}\"}} {}", user, s.octets_from_client.load(std::sync::atomic::Ordering::Relaxed));
let _ = writeln!(out, "telemt_user_octets_to_client{{user=\"{}\"}} {}", user, s.octets_to_client.load(std::sync::atomic::Ordering::Relaxed));
let _ = writeln!(out, "telemt_user_msgs_from_client{{user=\"{}\"}} {}", user, s.msgs_from_client.load(std::sync::atomic::Ordering::Relaxed));
let _ = writeln!(out, "telemt_user_msgs_to_client{{user=\"{}\"}} {}", user, s.msgs_to_client.load(std::sync::atomic::Ordering::Relaxed));
}
out
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_render_metrics_format() {
let stats = Arc::new(Stats::new());
stats.increment_connects_all();
stats.increment_connects_all();
stats.increment_connects_bad();
stats.increment_handshake_timeouts();
stats.increment_user_connects("alice");
stats.increment_user_curr_connects("alice");
stats.add_user_octets_from("alice", 1024);
stats.add_user_octets_to("alice", 2048);
stats.increment_user_msgs_from("alice");
stats.increment_user_msgs_to("alice");
stats.increment_user_msgs_to("alice");
let output = render_metrics(&stats);
assert!(output.contains("telemt_connections_total 2"));
assert!(output.contains("telemt_connections_bad_total 1"));
assert!(output.contains("telemt_handshake_timeouts_total 1"));
assert!(output.contains("telemt_user_connections_total{user=\"alice\"} 1"));
assert!(output.contains("telemt_user_connections_current{user=\"alice\"} 1"));
assert!(output.contains("telemt_user_octets_from_client{user=\"alice\"} 1024"));
assert!(output.contains("telemt_user_octets_to_client{user=\"alice\"} 2048"));
assert!(output.contains("telemt_user_msgs_from_client{user=\"alice\"} 1"));
assert!(output.contains("telemt_user_msgs_to_client{user=\"alice\"} 2"));
}
#[test]
fn test_render_empty_stats() {
let stats = Stats::new();
let output = render_metrics(&stats);
assert!(output.contains("telemt_connections_total 0"));
assert!(output.contains("telemt_connections_bad_total 0"));
assert!(output.contains("telemt_handshake_timeouts_total 0"));
assert!(!output.contains("user="));
}
#[test]
fn test_render_has_type_annotations() {
let stats = Stats::new();
let output = render_metrics(&stats);
assert!(output.contains("# TYPE telemt_uptime_seconds gauge"));
assert!(output.contains("# TYPE telemt_connections_total counter"));
assert!(output.contains("# TYPE telemt_connections_bad_total counter"));
assert!(output.contains("# TYPE telemt_handshake_timeouts_total counter"));
}
#[tokio::test]
async fn test_endpoint_integration() {
let stats = Arc::new(Stats::new());
stats.increment_connects_all();
stats.increment_connects_all();
stats.increment_connects_all();
let port = 19091u16;
let s = stats.clone();
tokio::spawn(async move {
serve(port, s, vec![]).await;
});
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let resp = reqwest::get(format!("http://127.0.0.1:{}/metrics", port))
.await.unwrap();
assert_eq!(resp.status(), 200);
let body = resp.text().await.unwrap();
assert!(body.contains("telemt_connections_total 3"));
let resp404 = reqwest::get(format!("http://127.0.0.1:{}/other", port))
.await.unwrap();
assert_eq!(resp404.status(), 404);
}
}

4
src/network/mod.rs Normal file
View File

@@ -0,0 +1,4 @@
pub mod probe;
pub mod stun;
pub use stun::IpFamily;

231
src/network/probe.rs Normal file
View File

@@ -0,0 +1,231 @@
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket};
use tracing::{info, warn};
use crate::config::NetworkConfig;
use crate::error::Result;
use crate::network::stun::{stun_probe_dual, DualStunResult, IpFamily};
#[derive(Debug, Clone, Default)]
pub struct NetworkProbe {
pub detected_ipv4: Option<Ipv4Addr>,
pub detected_ipv6: Option<Ipv6Addr>,
pub reflected_ipv4: Option<SocketAddr>,
pub reflected_ipv6: Option<SocketAddr>,
pub ipv4_is_bogon: bool,
pub ipv6_is_bogon: bool,
pub ipv4_nat_detected: bool,
pub ipv6_nat_detected: bool,
pub ipv4_usable: bool,
pub ipv6_usable: bool,
}
#[derive(Debug, Clone, Default)]
pub struct NetworkDecision {
pub ipv4_dc: bool,
pub ipv6_dc: bool,
pub ipv4_me: bool,
pub ipv6_me: bool,
pub effective_prefer: u8,
pub effective_multipath: bool,
}
impl NetworkDecision {
pub fn prefer_ipv6(&self) -> bool {
self.effective_prefer == 6
}
pub fn me_families(&self) -> Vec<IpFamily> {
let mut res = Vec::new();
if self.ipv4_me {
res.push(IpFamily::V4);
}
if self.ipv6_me {
res.push(IpFamily::V6);
}
res
}
}
pub async fn run_probe(config: &NetworkConfig, stun_addr: Option<String>, nat_probe: bool) -> Result<NetworkProbe> {
let mut probe = NetworkProbe::default();
probe.detected_ipv4 = detect_local_ip_v4();
probe.detected_ipv6 = detect_local_ip_v6();
probe.ipv4_is_bogon = probe.detected_ipv4.map(is_bogon_v4).unwrap_or(false);
probe.ipv6_is_bogon = probe.detected_ipv6.map(is_bogon_v6).unwrap_or(false);
let stun_server = stun_addr.unwrap_or_else(|| "stun.l.google.com:19302".to_string());
let stun_res = if nat_probe {
match stun_probe_dual(&stun_server).await {
Ok(res) => res,
Err(e) => {
warn!(error = %e, "STUN probe failed, continuing without reflection");
DualStunResult::default()
}
}
} else {
DualStunResult::default()
};
probe.reflected_ipv4 = stun_res.v4.map(|r| r.reflected_addr);
probe.reflected_ipv6 = stun_res.v6.map(|r| r.reflected_addr);
probe.ipv4_nat_detected = match (probe.detected_ipv4, probe.reflected_ipv4) {
(Some(det), Some(reflected)) => det != reflected.ip(),
_ => false,
};
probe.ipv6_nat_detected = match (probe.detected_ipv6, probe.reflected_ipv6) {
(Some(det), Some(reflected)) => det != reflected.ip(),
_ => false,
};
probe.ipv4_usable = config.ipv4
&& probe.detected_ipv4.is_some()
&& (!probe.ipv4_is_bogon || probe.reflected_ipv4.map(|r| !is_bogon(r.ip())).unwrap_or(false));
let ipv6_enabled = config.ipv6.unwrap_or(probe.detected_ipv6.is_some());
probe.ipv6_usable = ipv6_enabled
&& probe.detected_ipv6.is_some()
&& (!probe.ipv6_is_bogon || probe.reflected_ipv6.map(|r| !is_bogon(r.ip())).unwrap_or(false));
Ok(probe)
}
pub fn decide_network_capabilities(config: &NetworkConfig, probe: &NetworkProbe) -> NetworkDecision {
let mut decision = NetworkDecision::default();
decision.ipv4_dc = config.ipv4 && probe.detected_ipv4.is_some();
decision.ipv6_dc = config.ipv6.unwrap_or(probe.detected_ipv6.is_some()) && probe.detected_ipv6.is_some();
decision.ipv4_me = config.ipv4
&& probe.detected_ipv4.is_some()
&& (!probe.ipv4_is_bogon || probe.reflected_ipv4.is_some());
let ipv6_enabled = config.ipv6.unwrap_or(probe.detected_ipv6.is_some());
decision.ipv6_me = ipv6_enabled
&& probe.detected_ipv6.is_some()
&& (!probe.ipv6_is_bogon || probe.reflected_ipv6.is_some());
decision.effective_prefer = match config.prefer {
6 if decision.ipv6_me || decision.ipv6_dc => 6,
4 if decision.ipv4_me || decision.ipv4_dc => 4,
6 => {
warn!("prefer=6 requested but IPv6 unavailable; falling back to IPv4");
4
}
_ => 4,
};
let me_families = decision.ipv4_me as u8 + decision.ipv6_me as u8;
decision.effective_multipath = config.multipath && me_families >= 2;
decision
}
fn detect_local_ip_v4() -> Option<Ipv4Addr> {
let socket = UdpSocket::bind("0.0.0.0:0").ok()?;
socket.connect("8.8.8.8:80").ok()?;
match socket.local_addr().ok()?.ip() {
IpAddr::V4(v4) => Some(v4),
_ => None,
}
}
fn detect_local_ip_v6() -> Option<Ipv6Addr> {
let socket = UdpSocket::bind("[::]:0").ok()?;
socket.connect("[2001:4860:4860::8888]:80").ok()?;
match socket.local_addr().ok()?.ip() {
IpAddr::V6(v6) => Some(v6),
_ => None,
}
}
pub fn is_bogon(ip: IpAddr) -> bool {
match ip {
IpAddr::V4(v4) => is_bogon_v4(v4),
IpAddr::V6(v6) => is_bogon_v6(v6),
}
}
pub fn is_bogon_v4(ip: Ipv4Addr) -> bool {
let octets = ip.octets();
if ip.is_private() || ip.is_loopback() || ip.is_link_local() {
return true;
}
if octets[0] == 0 {
return true;
}
if octets[0] == 100 && (octets[1] & 0xC0) == 64 {
return true;
}
if octets[0] == 192 && octets[1] == 0 && octets[2] == 0 {
return true;
}
if octets[0] == 192 && octets[1] == 0 && octets[2] == 2 {
return true;
}
if octets[0] == 198 && (octets[1] & 0xFE) == 18 {
return true;
}
if octets[0] == 198 && octets[1] == 51 && octets[2] == 100 {
return true;
}
if octets[0] == 203 && octets[1] == 0 && octets[2] == 113 {
return true;
}
if ip.is_multicast() {
return true;
}
if octets[0] >= 240 {
return true;
}
if ip.is_broadcast() {
return true;
}
false
}
pub fn is_bogon_v6(ip: Ipv6Addr) -> bool {
if ip.is_unspecified() || ip.is_loopback() || ip.is_unique_local() {
return true;
}
let segs = ip.segments();
if (segs[0] & 0xFFC0) == 0xFE80 {
return true;
}
if segs[0..5] == [0, 0, 0, 0, 0] && segs[5] == 0xFFFF {
return true;
}
if segs[0] == 0x0100 && segs[1..4] == [0, 0, 0] {
return true;
}
if segs[0] == 0x2001 && segs[1] == 0x0db8 {
return true;
}
if segs[0] == 0x2002 {
return true;
}
if ip.is_multicast() {
return true;
}
false
}
pub fn log_probe_result(probe: &NetworkProbe, decision: &NetworkDecision) {
info!(
ipv4 = probe.detected_ipv4.as_ref().map(|v| v.to_string()).unwrap_or_else(|| "-".into()),
ipv6 = probe.detected_ipv6.as_ref().map(|v| v.to_string()).unwrap_or_else(|| "-".into()),
reflected_v4 = probe.reflected_ipv4.as_ref().map(|v| v.ip().to_string()).unwrap_or_else(|| "-".into()),
reflected_v6 = probe.reflected_ipv6.as_ref().map(|v| v.ip().to_string()).unwrap_or_else(|| "-".into()),
ipv4_bogon = probe.ipv4_is_bogon,
ipv6_bogon = probe.ipv6_is_bogon,
ipv4_me = decision.ipv4_me,
ipv6_me = decision.ipv6_me,
ipv4_dc = decision.ipv4_dc,
ipv6_dc = decision.ipv6_dc,
prefer = decision.effective_prefer,
multipath = decision.effective_multipath,
"Network capabilities resolved"
);
}

203
src/network/stun.rs Normal file
View File

@@ -0,0 +1,203 @@
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use tokio::net::{lookup_host, UdpSocket};
use tokio::time::{timeout, Duration, sleep};
use crate::error::{ProxyError, Result};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IpFamily {
V4,
V6,
}
#[derive(Debug, Clone, Copy)]
pub struct StunProbeResult {
pub local_addr: SocketAddr,
pub reflected_addr: SocketAddr,
pub family: IpFamily,
}
#[derive(Debug, Default, Clone)]
pub struct DualStunResult {
pub v4: Option<StunProbeResult>,
pub v6: Option<StunProbeResult>,
}
pub async fn stun_probe_dual(stun_addr: &str) -> Result<DualStunResult> {
let (v4, v6) = tokio::join!(
stun_probe_family(stun_addr, IpFamily::V4),
stun_probe_family(stun_addr, IpFamily::V6),
);
Ok(DualStunResult {
v4: v4?,
v6: v6?,
})
}
pub async fn stun_probe_family(stun_addr: &str, family: IpFamily) -> Result<Option<StunProbeResult>> {
use rand::RngCore;
let bind_addr = match family {
IpFamily::V4 => "0.0.0.0:0",
IpFamily::V6 => "[::]:0",
};
let socket = UdpSocket::bind(bind_addr)
.await
.map_err(|e| ProxyError::Proxy(format!("STUN bind failed: {e}")))?;
let target_addr = resolve_stun_addr(stun_addr, family).await?;
if let Some(addr) = target_addr {
socket
.connect(addr)
.await
.map_err(|e| ProxyError::Proxy(format!("STUN connect failed: {e}")))?;
} else {
return Ok(None);
}
let mut req = [0u8; 20];
req[0..2].copy_from_slice(&0x0001u16.to_be_bytes()); // Binding Request
req[2..4].copy_from_slice(&0u16.to_be_bytes()); // length
req[4..8].copy_from_slice(&0x2112A442u32.to_be_bytes()); // magic cookie
rand::rng().fill_bytes(&mut req[8..20]); // transaction ID
let mut buf = [0u8; 256];
let mut attempt = 0;
let mut backoff = Duration::from_secs(1);
loop {
socket
.send(&req)
.await
.map_err(|e| ProxyError::Proxy(format!("STUN send failed: {e}")))?;
let recv_res = timeout(Duration::from_secs(3), socket.recv(&mut buf)).await;
let n = match recv_res {
Ok(Ok(n)) => n,
Ok(Err(e)) => return Err(ProxyError::Proxy(format!("STUN recv failed: {e}"))),
Err(_) => {
attempt += 1;
if attempt >= 3 {
return Ok(None);
}
sleep(backoff).await;
backoff *= 2;
continue;
}
};
if n < 20 {
return Ok(None);
}
let magic = 0x2112A442u32.to_be_bytes();
let txid = &req[8..20];
let mut idx = 20;
while idx + 4 <= n {
let atype = u16::from_be_bytes(buf[idx..idx + 2].try_into().unwrap());
let alen = u16::from_be_bytes(buf[idx + 2..idx + 4].try_into().unwrap()) as usize;
idx += 4;
if idx + alen > n {
break;
}
match atype {
0x0020 /* XOR-MAPPED-ADDRESS */ | 0x0001 /* MAPPED-ADDRESS */ => {
if alen < 8 {
break;
}
let family_byte = buf[idx + 1];
let port_bytes = [buf[idx + 2], buf[idx + 3]];
let len_check = match family_byte {
0x01 => 4,
0x02 => 16,
_ => 0,
};
if len_check == 0 || alen < 4 + len_check {
break;
}
let raw_ip = &buf[idx + 4..idx + 4 + len_check];
let mut port = u16::from_be_bytes(port_bytes);
let reflected_ip = if atype == 0x0020 {
port ^= ((magic[0] as u16) << 8) | magic[1] as u16;
match family_byte {
0x01 => {
let ip = [
raw_ip[0] ^ magic[0],
raw_ip[1] ^ magic[1],
raw_ip[2] ^ magic[2],
raw_ip[3] ^ magic[3],
];
IpAddr::V4(Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3]))
}
0x02 => {
let mut ip = [0u8; 16];
let xor_key = [magic.as_slice(), txid].concat();
for (i, b) in raw_ip.iter().enumerate().take(16) {
ip[i] = *b ^ xor_key[i];
}
IpAddr::V6(Ipv6Addr::from(ip))
}
_ => {
idx += (alen + 3) & !3;
continue;
}
}
} else {
match family_byte {
0x01 => IpAddr::V4(Ipv4Addr::new(raw_ip[0], raw_ip[1], raw_ip[2], raw_ip[3])),
0x02 => IpAddr::V6(Ipv6Addr::from(<[u8; 16]>::try_from(raw_ip).unwrap())),
_ => {
idx += (alen + 3) & !3;
continue;
}
}
};
let reflected_addr = SocketAddr::new(reflected_ip, port);
let local_addr = socket
.local_addr()
.map_err(|e| ProxyError::Proxy(format!("STUN local_addr failed: {e}")))?;
return Ok(Some(StunProbeResult {
local_addr,
reflected_addr,
family,
}));
}
_ => {}
}
idx += (alen + 3) & !3;
}
}
Ok(None)
}
async fn resolve_stun_addr(stun_addr: &str, family: IpFamily) -> Result<Option<SocketAddr>> {
if let Ok(addr) = stun_addr.parse::<SocketAddr>() {
return Ok(match (addr.is_ipv4(), family) {
(true, IpFamily::V4) | (false, IpFamily::V6) => Some(addr),
_ => None,
});
}
let addrs = lookup_host(stun_addr)
.await
.map_err(|e| ProxyError::Proxy(format!("STUN resolve failed: {e}")))?;
let target = addrs
.filter(|a| match (a.is_ipv4(), family) {
(true, IpFamily::V4) => true,
(false, IpFamily::V6) => true,
_ => false,
})
.next();
Ok(target)
}

View File

@@ -1,13 +1,13 @@
//! Protocol constants and datacenter addresses
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use once_cell::sync::Lazy;
use std::sync::LazyLock;
// ============= Telegram Datacenters =============
pub const TG_DATACENTER_PORT: u16 = 443;
pub static TG_DATACENTERS_V4: Lazy<Vec<IpAddr>> = Lazy::new(|| {
pub static TG_DATACENTERS_V4: LazyLock<Vec<IpAddr>> = LazyLock::new(|| {
vec![
IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)),
IpAddr::V4(Ipv4Addr::new(149, 154, 167, 51)),
@@ -17,7 +17,7 @@ pub static TG_DATACENTERS_V4: Lazy<Vec<IpAddr>> = Lazy::new(|| {
]
});
pub static TG_DATACENTERS_V6: Lazy<Vec<IpAddr>> = Lazy::new(|| {
pub static TG_DATACENTERS_V6: LazyLock<Vec<IpAddr>> = LazyLock::new(|| {
vec![
IpAddr::V6("2001:b28:f23d:f001::a".parse().unwrap()),
IpAddr::V6("2001:67c:04e8:f002::a".parse().unwrap()),
@@ -29,8 +29,8 @@ pub static TG_DATACENTERS_V6: Lazy<Vec<IpAddr>> = Lazy::new(|| {
// ============= Middle Proxies (for advertising) =============
pub static TG_MIDDLE_PROXIES_V4: Lazy<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
Lazy::new(|| {
pub static TG_MIDDLE_PROXIES_V4: LazyLock<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
LazyLock::new(|| {
let mut m = std::collections::HashMap::new();
m.insert(1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]);
m.insert(-1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]);
@@ -45,8 +45,8 @@ pub static TG_MIDDLE_PROXIES_V4: Lazy<std::collections::HashMap<i32, Vec<(IpAddr
m
});
pub static TG_MIDDLE_PROXIES_V6: Lazy<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
Lazy::new(|| {
pub static TG_MIDDLE_PROXIES_V6: LazyLock<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
LazyLock::new(|| {
let mut m = std::collections::HashMap::new();
m.insert(1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]);
m.insert(-1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]);
@@ -167,8 +167,6 @@ pub const DEFAULT_ACK_TIMEOUT_SECS: u64 = 300;
// ============= Buffer Sizes =============
/// Default buffer size
/// CHANGED: Reduced from 64KB to 16KB to match TLS record size and align with
/// the new buffering strategy for better iOS upload performance.
pub const DEFAULT_BUFFER_SIZE: usize = 16384;
/// Small buffer size for bad client handling
@@ -204,6 +202,17 @@ pub static RESERVED_NONCE_CONTINUES: &[[u8; 4]] = &[
// ============= RPC Constants (for Middle Proxy) =============
/// RPC Proxy Request
/// RPC Flags (from Erlang mtp_rpc.erl)
pub const RPC_FLAG_NOT_ENCRYPTED: u32 = 0x2;
pub const RPC_FLAG_HAS_AD_TAG: u32 = 0x8;
pub const RPC_FLAG_MAGIC: u32 = 0x1000;
pub const RPC_FLAG_EXTMODE2: u32 = 0x20000;
pub const RPC_FLAG_PAD: u32 = 0x8000000;
pub const RPC_FLAG_INTERMEDIATE: u32 = 0x20000000;
pub const RPC_FLAG_ABRIDGED: u32 = 0x40000000;
pub const RPC_FLAG_QUICKACK: u32 = 0x80000000;
pub const RPC_PROXY_REQ: [u8; 4] = [0xee, 0xf1, 0xce, 0x36];
/// RPC Proxy Answer
pub const RPC_PROXY_ANS: [u8; 4] = [0x0d, 0xda, 0x03, 0x44];
@@ -230,7 +239,56 @@ pub mod rpc_flags {
pub const FLAG_QUICKACK: u32 = 0x80000000;
}
#[cfg(test)]
// ============= Middle-End Proxy Servers =============
pub const ME_PROXY_PORT: u16 = 8888;
pub static TG_MIDDLE_PROXIES_FLAT_V4: LazyLock<Vec<(IpAddr, u16)>> = LazyLock::new(|| {
vec![
(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888),
(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888),
(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888),
(IpAddr::V4(Ipv4Addr::new(91, 108, 4, 136)), 8888),
(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888),
]
});
// ============= RPC Constants (u32 native endian) =============
// From mtproto-common.h + net-tcp-rpc-common.h + mtproto-proxy.c
pub const RPC_NONCE_U32: u32 = 0x7acb87aa;
pub const RPC_HANDSHAKE_U32: u32 = 0x7682eef5;
pub const RPC_HANDSHAKE_ERROR_U32: u32 = 0x6a27beda;
pub const TL_PROXY_TAG_U32: u32 = 0xdb1e26ae; // mtproto-proxy.c:121
// mtproto-common.h
pub const RPC_PROXY_REQ_U32: u32 = 0x36cef1ee;
pub const RPC_PROXY_ANS_U32: u32 = 0x4403da0d;
pub const RPC_CLOSE_CONN_U32: u32 = 0x1fcf425d;
pub const RPC_CLOSE_EXT_U32: u32 = 0x5eb634a2;
pub const RPC_SIMPLE_ACK_U32: u32 = 0x3bac409b;
pub const RPC_PING_U32: u32 = 0x5730a2df;
pub const RPC_PONG_U32: u32 = 0x8430eaa7;
pub const RPC_CRYPTO_NONE_U32: u32 = 0;
pub const RPC_CRYPTO_AES_U32: u32 = 1;
pub mod proxy_flags {
pub const FLAG_HAS_AD_TAG: u32 = 1;
pub const FLAG_NOT_ENCRYPTED: u32 = 0x2;
pub const FLAG_HAS_AD_TAG2: u32 = 0x8;
pub const FLAG_MAGIC: u32 = 0x1000;
pub const FLAG_EXTMODE2: u32 = 0x20000;
pub const FLAG_PAD: u32 = 0x8000000;
pub const FLAG_INTERMEDIATE: u32 = 0x20000000;
pub const FLAG_ABRIDGED: u32 = 0x40000000;
pub const FLAG_QUICKACK: u32 = 0x80000000;
}
pub const ME_CONNECT_TIMEOUT_SECS: u64 = 5;
pub const ME_HANDSHAKE_TIMEOUT_SECS: u64 = 10;
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -1,10 +1,13 @@
//! MTProto Obfuscation
use zeroize::Zeroize;
use crate::crypto::{sha256, AesCtr};
use crate::error::Result;
use super::constants::*;
/// Obfuscation parameters from handshake
///
/// Key material is zeroized on drop.
#[derive(Debug, Clone)]
pub struct ObfuscationParams {
/// Key for decrypting client -> proxy traffic
@@ -21,25 +24,31 @@ pub struct ObfuscationParams {
pub dc_idx: i16,
}
impl Drop for ObfuscationParams {
fn drop(&mut self) {
self.decrypt_key.zeroize();
self.decrypt_iv.zeroize();
self.encrypt_key.zeroize();
self.encrypt_iv.zeroize();
}
}
impl ObfuscationParams {
/// Parse obfuscation parameters from handshake bytes
/// Returns None if handshake doesn't match any user secret
pub fn from_handshake(
handshake: &[u8; HANDSHAKE_LEN],
secrets: &[(String, Vec<u8>)], // (username, secret_bytes)
secrets: &[(String, Vec<u8>)],
) -> Option<(Self, String)> {
// Extract prekey and IV for decryption
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
// Reversed for encryption direction
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
for (username, secret) in secrets {
// Derive decryption key
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(secret);
@@ -47,26 +56,22 @@ impl ObfuscationParams {
let decrypt_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap());
// Create decryptor and decrypt handshake
let mut decryptor = AesCtr::new(&decrypt_key, decrypt_iv);
let decrypted = decryptor.decrypt(handshake);
// Check protocol tag
let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4]
.try_into()
.unwrap();
let proto_tag = match ProtoTag::from_bytes(tag_bytes) {
Some(tag) => tag,
None => continue, // Try next secret
None => continue,
};
// Extract DC index
let dc_idx = i16::from_le_bytes(
decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()
);
// Derive encryption key
let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
enc_key_input.extend_from_slice(enc_prekey);
enc_key_input.extend_from_slice(secret);
@@ -123,18 +128,15 @@ pub fn generate_nonce<R: FnMut(usize) -> Vec<u8>>(mut random_bytes: R) -> [u8; H
/// Check if nonce is valid (not matching reserved patterns)
pub fn is_valid_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> bool {
// Check first byte
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) {
return false;
}
// Check first 4 bytes
let first_four: [u8; 4] = nonce[..4].try_into().unwrap();
if RESERVED_NONCE_BEGINNINGS.contains(&first_four) {
return false;
}
// Check bytes 4-7
let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap();
if RESERVED_NONCE_CONTINUES.contains(&continue_four) {
return false;
@@ -147,12 +149,10 @@ pub fn is_valid_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> bool {
pub fn prepare_tg_nonce(
nonce: &mut [u8; HANDSHAKE_LEN],
proto_tag: ProtoTag,
enc_key_iv: Option<&[u8]>, // For fast mode
enc_key_iv: Option<&[u8]>,
) {
// Set protocol tag
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
// For fast mode, copy the reversed enc_key_iv
if let Some(key_iv) = enc_key_iv {
let reversed: Vec<u8> = key_iv.iter().rev().copied().collect();
nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN].copy_from_slice(&reversed);
@@ -160,15 +160,19 @@ pub fn prepare_tg_nonce(
}
/// Encrypt the outgoing nonce for Telegram
/// Legacy helper — **do not use**.
/// WARNING: logic diverges from Python/C reference (SHA256 of 48 bytes, IV from head).
/// Kept only to avoid breaking external callers; prefer `encrypt_tg_nonce_with_ciphers`.
#[deprecated(
note = "Incorrect MTProto obfuscation KDF; use proxy::handshake::encrypt_tg_nonce_with_ciphers"
)]
pub fn encrypt_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
// Derive encryption key from the nonce itself
let key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
let enc_key = sha256(key_iv);
let enc_iv = u128::from_be_bytes(key_iv[..IV_LEN].try_into().unwrap());
let mut encryptor = AesCtr::new(&enc_key, enc_iv);
// Only encrypt from PROTO_TAG_POS onwards
let mut result = nonce.to_vec();
let encrypted_part = encryptor.encrypt(&nonce[PROTO_TAG_POS..]);
result[PROTO_TAG_POS..].copy_from_slice(&encrypted_part);
@@ -182,22 +186,18 @@ mod tests {
#[test]
fn test_is_valid_nonce() {
// Valid nonce
let mut valid = [0x42u8; HANDSHAKE_LEN];
valid[4..8].copy_from_slice(&[1, 2, 3, 4]);
assert!(is_valid_nonce(&valid));
// Invalid: starts with 0xef
let mut invalid = [0x00u8; HANDSHAKE_LEN];
invalid[0] = 0xef;
assert!(!is_valid_nonce(&invalid));
// Invalid: starts with HEAD
let mut invalid = [0x00u8; HANDSHAKE_LEN];
invalid[..4].copy_from_slice(b"HEAD");
assert!(!is_valid_nonce(&invalid));
// Invalid: bytes 4-7 are zeros
let mut invalid = [0x42u8; HANDSHAKE_LEN];
invalid[4..8].copy_from_slice(&[0, 0, 0, 0]);
assert!(!is_valid_nonce(&invalid));
@@ -214,4 +214,4 @@ mod tests {
assert!(is_valid_nonce(&nonce));
assert_eq!(nonce.len(), HANDSHAKE_LEN);
}
}
}

View File

@@ -4,10 +4,12 @@
//! for domain fronting. The handshake looks like valid TLS 1.3 but
//! actually carries MTProto authentication data.
use crate::crypto::{sha256_hmac, random::SECURE_RANDOM};
use crate::crypto::{sha256_hmac, SecureRandom};
use crate::error::{ProxyError, Result};
use super::constants::*;
use std::time::{SystemTime, UNIX_EPOCH};
use num_bigint::BigUint;
use num_traits::One;
// ============= Public Constants =============
@@ -311,13 +313,27 @@ pub fn validate_tls_handshake(
None
}
fn curve25519_prime() -> BigUint {
(BigUint::one() << 255) - BigUint::from(19u32)
}
/// Generate a fake X25519 public key for TLS
///
/// This generates random bytes that look like a valid X25519 public key.
/// Since we're not doing real TLS, the actual cryptographic properties don't matter.
pub fn gen_fake_x25519_key() -> [u8; 32] {
let bytes = SECURE_RANDOM.bytes(32);
bytes.try_into().unwrap()
/// Produces a quadratic residue mod p = 2^255 - 19 by computing n² mod p,
/// which matches Python/C behavior and avoids DPI fingerprinting.
pub fn gen_fake_x25519_key(rng: &SecureRandom) -> [u8; 32] {
let mut n_bytes = [0u8; 32];
n_bytes.copy_from_slice(&rng.bytes(32));
let n = BigUint::from_bytes_le(&n_bytes);
let p = curve25519_prime();
let pk = (&n * &n) % &p;
let mut out = pk.to_bytes_le();
out.resize(32, 0);
let mut result = [0u8; 32];
result.copy_from_slice(&out[..32]);
result
}
/// Build TLS ServerHello response
@@ -333,8 +349,9 @@ pub fn build_server_hello(
client_digest: &[u8; TLS_DIGEST_LEN],
session_id: &[u8],
fake_cert_len: usize,
rng: &SecureRandom,
) -> Vec<u8> {
let x25519_key = gen_fake_x25519_key();
let x25519_key = gen_fake_x25519_key(rng);
// Build ServerHello
let server_hello = ServerHelloBuilder::new(session_id.to_vec())
@@ -351,7 +368,7 @@ pub fn build_server_hello(
];
// Build fake certificate (Application Data record)
let fake_cert = SECURE_RANDOM.bytes(fake_cert_len);
let fake_cert = rng.bytes(fake_cert_len);
let mut app_data_record = Vec::with_capacity(5 + fake_cert_len);
app_data_record.push(TLS_RECORD_APPLICATION);
app_data_record.extend_from_slice(&TLS_VERSION);
@@ -489,13 +506,25 @@ mod tests {
#[test]
fn test_gen_fake_x25519_key() {
let key1 = gen_fake_x25519_key();
let key2 = gen_fake_x25519_key();
let rng = SecureRandom::new();
let key1 = gen_fake_x25519_key(&rng);
let key2 = gen_fake_x25519_key(&rng);
assert_eq!(key1.len(), 32);
assert_eq!(key2.len(), 32);
assert_ne!(key1, key2); // Should be random
}
#[test]
fn test_fake_x25519_key_is_quadratic_residue() {
let rng = SecureRandom::new();
let key = gen_fake_x25519_key(&rng);
let p = curve25519_prime();
let k_num = BigUint::from_bytes_le(&key);
let exponent = (&p - BigUint::one()) >> 1;
let legendre = k_num.modpow(&exponent, &p);
assert_eq!(legendre, BigUint::one());
}
#[test]
fn test_tls_extension_builder() {
@@ -545,7 +574,8 @@ mod tests {
let client_digest = [0x42u8; 32];
let session_id = vec![0xAA; 32];
let response = build_server_hello(secret, &client_digest, &session_id, 2048);
let rng = SecureRandom::new();
let response = build_server_hello(secret, &client_digest, &session_id, 2048, &rng);
// Should have at least 3 records
assert!(response.len() > 100);
@@ -577,8 +607,9 @@ mod tests {
let client_digest = [0x42u8; 32];
let session_id = vec![0xAA; 32];
let response1 = build_server_hello(secret, &client_digest, &session_id, 1024);
let response2 = build_server_hello(secret, &client_digest, &session_id, 1024);
let rng = SecureRandom::new();
let response1 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng);
let response2 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng);
// Digest position should have non-zero data
let digest1 = &response1[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN];
@@ -637,4 +668,4 @@ mod tests {
// Should return None (no match) but not panic
assert!(result.is_none());
}
}
}

View File

@@ -1,34 +1,198 @@
//! Client Handler
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::net::TcpStream;
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use tokio::time::timeout;
use tracing::{debug, info, warn, error, trace};
use tracing::{debug, warn};
/// Post-handshake future (relay phase, runs outside handshake timeout)
type PostHandshakeFuture = Pin<Box<dyn Future<Output = Result<()>> + Send>>;
/// Result of the handshake phase
enum HandshakeOutcome {
/// Handshake succeeded, relay work to do (outside timeout)
NeedsRelay(PostHandshakeFuture),
/// Already fully handled (bad client masking, etc.)
Handled,
}
use crate::config::ProxyConfig;
use crate::error::{ProxyError, Result, HandshakeResult};
use crate::crypto::SecureRandom;
use crate::error::{HandshakeResult, ProxyError, Result};
use crate::ip_tracker::UserIpTracker;
use crate::protocol::constants::*;
use crate::protocol::tls;
use crate::stats::{Stats, ReplayChecker};
use crate::transport::{configure_client_socket, UpstreamManager};
use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter, BufferPool};
use crate::crypto::AesCtr;
use crate::stats::{ReplayChecker, Stats};
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
use crate::transport::middle_proxy::MePool;
use crate::transport::{UpstreamManager, configure_client_socket};
// Use absolute paths to avoid confusion
use crate::proxy::handshake::{
handle_tls_handshake, handle_mtproto_handshake,
HandshakeSuccess, generate_tg_nonce, encrypt_tg_nonce,
};
use crate::proxy::relay::relay_bidirectional;
use crate::proxy::direct_relay::handle_via_direct;
use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle_tls_handshake};
use crate::proxy::masking::handle_bad_client;
use crate::proxy::middle_relay::handle_via_middle_proxy;
pub async fn handle_client_stream<S>(
mut stream: S,
peer: SocketAddr,
config: Arc<ProxyConfig>,
stats: Arc<Stats>,
upstream_manager: Arc<UpstreamManager>,
replay_checker: Arc<ReplayChecker>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>,
ip_tracker: Arc<UserIpTracker>,
) -> Result<()>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
stats.increment_connects_all();
debug!(peer = %peer, "New connection (generic stream)");
let handshake_timeout = Duration::from_secs(config.timeouts.client_handshake);
let stats_for_timeout = stats.clone();
// For non-TCP streams, use a synthetic local address
let local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port)
.parse()
.unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap());
// Phase 1: handshake (with timeout)
let outcome = match timeout(handshake_timeout, async {
let mut first_bytes = [0u8; 5];
stream.read_exact(&mut first_bytes).await?;
let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
debug!(peer = %peer, is_tls = is_tls, "Handshake type detected");
if is_tls {
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
if tls_len < 512 {
debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short");
stats.increment_connects_bad();
let (reader, writer) = tokio::io::split(stream);
handle_bad_client(reader, writer, &first_bytes, &config).await;
return Ok(HandshakeOutcome::Handled);
}
let mut handshake = vec![0u8; 5 + tls_len];
handshake[..5].copy_from_slice(&first_bytes);
stream.read_exact(&mut handshake[5..]).await?;
let (read_half, write_half) = tokio::io::split(stream);
let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake(
&handshake, read_half, write_half, peer,
&config, &replay_checker, &rng,
).await {
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &config).await;
return Ok(HandshakeOutcome::Handled);
}
HandshakeResult::Error(e) => return Err(e),
};
debug!(peer = %peer, "Reading MTProto handshake through TLS");
let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?;
let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into()
.map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?;
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
&mtproto_handshake, tls_reader, tls_writer, peer,
&config, &replay_checker, true,
).await {
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader: _, writer: _ } => {
stats.increment_connects_bad();
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake");
return Ok(HandshakeOutcome::Handled);
}
HandshakeResult::Error(e) => return Err(e),
};
Ok(HandshakeOutcome::NeedsRelay(Box::pin(
RunningClientHandler::handle_authenticated_static(
crypto_reader, crypto_writer, success,
upstream_manager, stats, config, buffer_pool, rng, me_pool,
local_addr, peer, ip_tracker.clone(),
),
)))
} else {
if !config.general.modes.classic && !config.general.modes.secure {
debug!(peer = %peer, "Non-TLS modes disabled");
stats.increment_connects_bad();
let (reader, writer) = tokio::io::split(stream);
handle_bad_client(reader, writer, &first_bytes, &config).await;
return Ok(HandshakeOutcome::Handled);
}
let mut handshake = [0u8; HANDSHAKE_LEN];
handshake[..5].copy_from_slice(&first_bytes);
stream.read_exact(&mut handshake[5..]).await?;
let (read_half, write_half) = tokio::io::split(stream);
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
&handshake, read_half, write_half, peer,
&config, &replay_checker, false,
).await {
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &config).await;
return Ok(HandshakeOutcome::Handled);
}
HandshakeResult::Error(e) => return Err(e),
};
Ok(HandshakeOutcome::NeedsRelay(Box::pin(
RunningClientHandler::handle_authenticated_static(
crypto_reader,
crypto_writer,
success,
upstream_manager,
stats,
config,
buffer_pool,
rng,
me_pool,
local_addr,
peer,
ip_tracker.clone(),
)
)))
}
}).await {
Ok(Ok(outcome)) => outcome,
Ok(Err(e)) => {
debug!(peer = %peer, error = %e, "Handshake failed");
return Err(e);
}
Err(_) => {
stats_for_timeout.increment_handshake_timeouts();
debug!(peer = %peer, "Handshake timeout");
return Err(ProxyError::TgHandshakeTimeout);
}
};
// Phase 2: relay (WITHOUT handshake timeout — relay has its own activity timeouts)
match outcome {
HandshakeOutcome::NeedsRelay(fut) => fut.await,
HandshakeOutcome::Handled => Ok(()),
}
}
/// Client connection handler (builder struct)
pub struct ClientHandler;
/// Running client handler with stream and context
pub struct RunningClientHandler {
stream: TcpStream,
peer: SocketAddr,
@@ -37,10 +201,12 @@ pub struct RunningClientHandler {
replay_checker: Arc<ReplayChecker>,
upstream_manager: Arc<UpstreamManager>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>,
ip_tracker: Arc<UserIpTracker>,
}
impl ClientHandler {
/// Create new client handler instance
pub fn new(
stream: TcpStream,
peer: SocketAddr,
@@ -49,6 +215,9 @@ impl ClientHandler {
upstream_manager: Arc<UpstreamManager>,
replay_checker: Arc<ReplayChecker>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>,
ip_tracker: Arc<UserIpTracker>,
) -> RunningClientHandler {
RunningClientHandler {
stream,
@@ -58,19 +227,21 @@ impl ClientHandler {
replay_checker,
upstream_manager,
buffer_pool,
rng,
me_pool,
ip_tracker,
}
}
}
impl RunningClientHandler {
/// Run the client handler
pub async fn run(mut self) -> Result<()> {
self.stats.increment_connects_all();
let peer = self.peer;
let ip_tracker = self.ip_tracker.clone();
debug!(peer = %peer, "New connection");
// Configure socket
if let Err(e) = configure_client_socket(
&self.stream,
self.config.timeouts.client_keepalive,
@@ -78,89 +249,76 @@ impl RunningClientHandler {
) {
debug!(peer = %peer, error = %e, "Failed to configure client socket");
}
// Perform handshake with timeout
let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake);
// Clone stats for error handling block
let stats = self.stats.clone();
let result = timeout(
handshake_timeout,
self.do_handshake()
).await;
match result {
Ok(Ok(())) => {
debug!(peer = %peer, "Connection handled successfully");
Ok(())
}
// Phase 1: handshake (with timeout)
let outcome = match timeout(handshake_timeout, self.do_handshake()).await {
Ok(Ok(outcome)) => outcome,
Ok(Err(e)) => {
debug!(peer = %peer, error = %e, "Handshake failed");
Err(e)
return Err(e);
}
Err(_) => {
stats.increment_handshake_timeouts();
debug!(peer = %peer, "Handshake timeout");
Err(ProxyError::TgHandshakeTimeout)
return Err(ProxyError::TgHandshakeTimeout);
}
};
// Phase 2: relay (WITHOUT handshake timeout — relay has its own activity timeouts)
match outcome {
HandshakeOutcome::NeedsRelay(fut) => fut.await,
HandshakeOutcome::Handled => Ok(()),
}
}
/// Perform handshake and relay
async fn do_handshake(mut self) -> Result<()> {
// Read first bytes to determine handshake type
async fn do_handshake(mut self) -> Result<HandshakeOutcome> {
let mut first_bytes = [0u8; 5];
self.stream.read_exact(&mut first_bytes).await?;
let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
let peer = self.peer;
debug!(peer = %peer, is_tls = is_tls, first_bytes = %hex::encode(&first_bytes), "Handshake type detected");
let ip_tracker = self.ip_tracker.clone();
debug!(peer = %peer, is_tls = is_tls, "Handshake type detected");
if is_tls {
self.handle_tls_client(first_bytes).await
} else {
self.handle_direct_client(first_bytes).await
}
}
/// Handle TLS-wrapped client
async fn handle_tls_client(
mut self,
first_bytes: [u8; 5],
) -> Result<()> {
async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<HandshakeOutcome> {
let peer = self.peer;
// Read TLS handshake length
let ip_tracker = self.ip_tracker.clone();
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake");
if tls_len < 512 {
debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short");
self.stats.increment_connects_bad();
// FIX: Split stream into reader/writer for handle_bad_client
let (reader, writer) = self.stream.into_split();
handle_bad_client(reader, writer, &first_bytes, &self.config).await;
return Ok(());
return Ok(HandshakeOutcome::Handled);
}
// Read full TLS handshake
let mut handshake = vec![0u8; 5 + tls_len];
handshake[..5].copy_from_slice(&first_bytes);
self.stream.read_exact(&mut handshake[5..]).await?;
// Extract fields before consuming self.stream
let config = self.config.clone();
let replay_checker = self.replay_checker.clone();
let stats = self.stats.clone();
let buffer_pool = self.buffer_pool.clone();
// Split stream for reading/writing
let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?;
let (read_half, write_half) = self.stream.into_split();
// Handle TLS handshake
let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake(
&handshake,
read_half,
@@ -168,23 +326,25 @@ impl RunningClientHandler {
peer,
&config,
&replay_checker,
).await {
&self.rng,
)
.await
{
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &config).await;
return Ok(());
return Ok(HandshakeOutcome::Handled);
}
HandshakeResult::Error(e) => return Err(e),
};
// Read MTProto handshake through TLS
debug!(peer = %peer, "Reading MTProto handshake through TLS");
let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?;
let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into()
let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..]
.try_into()
.map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?;
// Handle MTProto handshake
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
&mtproto_handshake,
tls_reader,
@@ -193,60 +353,63 @@ impl RunningClientHandler {
&config,
&replay_checker,
true,
).await {
)
.await
{
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => {
HandshakeResult::BadClient {
reader: _,
writer: _,
} => {
stats.increment_connects_bad();
// Valid TLS but invalid MTProto - drop
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake - dropping");
return Ok(());
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake");
return Ok(HandshakeOutcome::Handled);
}
HandshakeResult::Error(e) => return Err(e),
};
Self::handle_authenticated_static(
crypto_reader,
crypto_writer,
success,
self.upstream_manager,
self.stats,
self.config,
buffer_pool
).await
Ok(HandshakeOutcome::NeedsRelay(Box::pin(
Self::handle_authenticated_static(
crypto_reader,
crypto_writer,
success,
self.upstream_manager,
self.stats,
self.config,
buffer_pool,
self.rng,
self.me_pool,
local_addr,
peer,
self.ip_tracker,
),
)))
}
/// Handle direct (non-TLS) client
async fn handle_direct_client(
mut self,
first_bytes: [u8; 5],
) -> Result<()> {
async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<HandshakeOutcome> {
let peer = self.peer;
// Check if non-TLS modes are enabled
let ip_tracker = self.ip_tracker.clone();
if !self.config.general.modes.classic && !self.config.general.modes.secure {
debug!(peer = %peer, "Non-TLS modes disabled");
self.stats.increment_connects_bad();
// FIX: Split stream into reader/writer for handle_bad_client
let (reader, writer) = self.stream.into_split();
handle_bad_client(reader, writer, &first_bytes, &self.config).await;
return Ok(());
return Ok(HandshakeOutcome::Handled);
}
// Read rest of handshake
let mut handshake = [0u8; HANDSHAKE_LEN];
handshake[..5].copy_from_slice(&first_bytes);
self.stream.read_exact(&mut handshake[5..]).await?;
// Extract fields
let config = self.config.clone();
let replay_checker = self.replay_checker.clone();
let stats = self.stats.clone();
let buffer_pool = self.buffer_pool.clone();
// Split stream
let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?;
let (read_half, write_half) = self.stream.into_split();
// Handle MTProto handshake
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
&handshake,
read_half,
@@ -255,28 +418,40 @@ impl RunningClientHandler {
&config,
&replay_checker,
false,
).await {
)
.await
{
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &config).await;
return Ok(());
return Ok(HandshakeOutcome::Handled);
}
HandshakeResult::Error(e) => return Err(e),
};
Self::handle_authenticated_static(
crypto_reader,
crypto_writer,
success,
self.upstream_manager,
self.stats,
self.config,
buffer_pool
).await
Ok(HandshakeOutcome::NeedsRelay(Box::pin(
Self::handle_authenticated_static(
crypto_reader,
crypto_writer,
success,
self.upstream_manager,
self.stats,
self.config,
buffer_pool,
self.rng,
self.me_pool,
local_addr,
peer,
self.ip_tracker,
),
)))
}
/// Static version of handle_authenticated_inner
/// Main dispatch after successful handshake.
/// Two modes:
/// - Direct: TCP relay to TG DC (existing behavior)
/// - Middle Proxy: RPC multiplex through ME pool (new — supports CDN DCs)
async fn handle_authenticated_static<R, W>(
client_reader: CryptoReader<R>,
client_writer: CryptoWriter<W>,
@@ -285,156 +460,125 @@ impl RunningClientHandler {
stats: Arc<Stats>,
config: Arc<ProxyConfig>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>,
local_addr: SocketAddr,
peer_addr: SocketAddr,
ip_tracker: Arc<UserIpTracker>,
) -> Result<()>
where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{
let user = &success.user;
// Check user limits
if let Err(e) = Self::check_user_limits_static(user, &config, &stats) {
if let Err(e) = Self::check_user_limits_static(user, &config, &stats, peer_addr, &ip_tracker).await {
warn!(user = %user, error = %e, "User limit exceeded");
return Err(e);
}
// IP Cleanup Guard: автоматически удаляет IP при выходе из scope
struct IpCleanupGuard {
tracker: Arc<UserIpTracker>,
user: String,
ip: std::net::IpAddr,
}
// Get datacenter address
let dc_addr = Self::get_dc_addr_static(success.dc_idx, &config)?;
impl Drop for IpCleanupGuard {
fn drop(&mut self) {
let tracker = self.tracker.clone();
let user = self.user.clone();
let ip = self.ip;
tokio::spawn(async move {
tracker.remove_ip(&user, ip).await;
debug!(user = %user, ip = %ip, "IP cleaned up on disconnect");
});
}
}
info!(
user = %user,
peer = %success.peer,
dc = success.dc_idx,
dc_addr = %dc_addr,
proto = ?success.proto_tag,
fast_mode = config.general.fast_mode,
"Connecting to Telegram"
);
// Connect to Telegram via UpstreamManager
let tg_stream = upstream_manager.connect(dc_addr).await?;
debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected to Telegram, performing handshake");
// Perform Telegram handshake and get crypto streams
let (tg_reader, tg_writer) = Self::do_tg_handshake_static(
tg_stream,
&success,
&config,
).await?;
debug!(peer = %success.peer, "Telegram handshake complete, starting relay");
// Update stats
stats.increment_user_connects(user);
stats.increment_user_curr_connects(user);
// Relay traffic using buffer pool
let relay_result = relay_bidirectional(
let _cleanup = IpCleanupGuard {
tracker: ip_tracker,
user: user.clone(),
ip: peer_addr.ip(),
};
// Decide: middle proxy or direct
if config.general.use_middle_proxy {
if let Some(ref pool) = me_pool {
return handle_via_middle_proxy(
client_reader,
client_writer,
success,
pool.clone(),
stats,
config,
buffer_pool,
local_addr,
rng,
)
.await;
}
warn!("use_middle_proxy=true but MePool not initialized, falling back to direct");
}
// Direct mode (original behavior)
handle_via_direct(
client_reader,
client_writer,
tg_reader,
tg_writer,
user,
Arc::clone(&stats),
success,
upstream_manager,
stats,
config,
buffer_pool,
).await;
// Update stats
stats.decrement_user_curr_connects(user);
match &relay_result {
Ok(()) => debug!(user = %user, peer = %success.peer, "Relay completed normally"),
Err(e) => debug!(user = %user, peer = %success.peer, error = %e, "Relay ended with error"),
}
relay_result
rng,
)
.await
}
/// Check user limits (static version)
fn check_user_limits_static(user: &str, config: &ProxyConfig, stats: &Stats) -> Result<()> {
// Check expiration
async fn check_user_limits_static(
user: &str,
config: &ProxyConfig,
stats: &Stats,
peer_addr: SocketAddr,
ip_tracker: &UserIpTracker,
) -> Result<()> {
if let Some(expiration) = config.access.user_expirations.get(user) {
if chrono::Utc::now() > *expiration {
return Err(ProxyError::UserExpired { user: user.to_string() });
return Err(ProxyError::UserExpired {
user: user.to_string(),
});
}
}
// Check connection limit
// IP limit check
if let Err(reason) = ip_tracker.check_and_add(user, peer_addr.ip()).await {
warn!(
user = %user,
ip = %peer_addr.ip(),
reason = %reason,
"IP limit exceeded"
);
return Err(ProxyError::ConnectionLimitExceeded {
user: user.to_string(),
});
}
if let Some(limit) = config.access.user_max_tcp_conns.get(user) {
let current = stats.get_user_curr_connects(user);
if current >= *limit as u64 {
return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string() });
if stats.get_user_curr_connects(user) >= *limit as u64 {
return Err(ProxyError::ConnectionLimitExceeded {
user: user.to_string(),
});
}
}
// Check data quota
if let Some(quota) = config.access.user_data_quota.get(user) {
let used = stats.get_user_total_octets(user);
if used >= *quota {
return Err(ProxyError::DataQuotaExceeded { user: user.to_string() });
if stats.get_user_total_octets(user) >= *quota {
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
}
Ok(())
}
/// Get datacenter address by index (static version)
fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
let idx = (dc_idx.abs() - 1) as usize;
let datacenters = if config.general.prefer_ipv6 {
&*TG_DATACENTERS_V6
} else {
&*TG_DATACENTERS_V4
};
datacenters.get(idx)
.map(|ip| SocketAddr::new(*ip, TG_DATACENTER_PORT))
.ok_or_else(|| ProxyError::InvalidHandshake(
format!("Invalid DC index: {}", dc_idx)
))
}
/// Perform handshake with Telegram server (static version)
async fn do_tg_handshake_static(
mut stream: TcpStream,
success: &HandshakeSuccess,
config: &ProxyConfig,
) -> Result<(CryptoReader<tokio::net::tcp::OwnedReadHalf>, CryptoWriter<tokio::net::tcp::OwnedWriteHalf>)> {
// Generate nonce with keys for TG
let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce(
success.proto_tag,
&success.dec_key, // Client's dec key
success.dec_iv,
config.general.fast_mode,
);
// Encrypt nonce
let encrypted_nonce = encrypt_tg_nonce(&nonce);
debug!(
peer = %success.peer,
nonce_head = %hex::encode(&nonce[..16]),
encrypted_head = %hex::encode(&encrypted_nonce[..16]),
"Sending nonce to Telegram"
);
// Send to Telegram
stream.write_all(&encrypted_nonce).await?;
stream.flush().await?;
debug!(peer = %success.peer, "Nonce sent to Telegram");
// Split stream and wrap with crypto
let (read_half, write_half) = stream.into_split();
let decryptor = AesCtr::new(&tg_dec_key, tg_dec_iv);
let encryptor = AesCtr::new(&tg_enc_key, tg_enc_iv);
let tg_reader = CryptoReader::new(read_half, decryptor);
let tg_writer = CryptoWriter::new(write_half, encryptor);
Ok((tg_reader, tg_writer))
}
}
}

185
src/proxy/direct_relay.rs Normal file
View File

@@ -0,0 +1,185 @@
use std::fs::OpenOptions;
use std::io::Write;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tracing::{debug, info, warn};
use crate::config::ProxyConfig;
use crate::crypto::SecureRandom;
use crate::error::Result;
use crate::protocol::constants::*;
use crate::proxy::handshake::{HandshakeSuccess, encrypt_tg_nonce_with_ciphers, generate_tg_nonce};
use crate::proxy::relay::relay_bidirectional;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
use crate::transport::UpstreamManager;
pub(crate) async fn handle_via_direct<R, W>(
client_reader: CryptoReader<R>,
client_writer: CryptoWriter<W>,
success: HandshakeSuccess,
upstream_manager: Arc<UpstreamManager>,
stats: Arc<Stats>,
config: Arc<ProxyConfig>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
) -> Result<()>
where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{
let user = &success.user;
let dc_addr = get_dc_addr_static(success.dc_idx, &config)?;
info!(
user = %user,
peer = %success.peer,
dc = success.dc_idx,
dc_addr = %dc_addr,
proto = ?success.proto_tag,
mode = "direct",
"Connecting to Telegram DC"
);
let tg_stream = upstream_manager
.connect(dc_addr, Some(success.dc_idx))
.await?;
debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake");
let (tg_reader, tg_writer) =
do_tg_handshake_static(tg_stream, &success, &config, rng.as_ref()).await?;
debug!(peer = %success.peer, "TG handshake complete, starting relay");
stats.increment_user_connects(user);
stats.increment_user_curr_connects(user);
let relay_result = relay_bidirectional(
client_reader,
client_writer,
tg_reader,
tg_writer,
user,
Arc::clone(&stats),
buffer_pool,
)
.await;
stats.decrement_user_curr_connects(user);
match &relay_result {
Ok(()) => debug!(user = %user, "Direct relay completed"),
Err(e) => debug!(user = %user, error = %e, "Direct relay ended with error"),
}
relay_result
}
fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
let prefer_v6 = config.network.prefer == 6 && config.network.ipv6.unwrap_or(true);
let datacenters = if prefer_v6 {
&*TG_DATACENTERS_V6
} else {
&*TG_DATACENTERS_V4
};
let num_dcs = datacenters.len();
let dc_key = dc_idx.to_string();
if let Some(addrs) = config.dc_overrides.get(&dc_key) {
let mut parsed = Vec::new();
for addr_str in addrs {
match addr_str.parse::<SocketAddr>() {
Ok(addr) => parsed.push(addr),
Err(_) => warn!(dc_idx = dc_idx, addr_str = %addr_str, "Invalid DC override address in config, ignoring"),
}
}
if let Some(addr) = parsed
.iter()
.find(|a| a.is_ipv6() == prefer_v6)
.or_else(|| parsed.first())
.copied()
{
debug!(dc_idx = dc_idx, addr = %addr, count = parsed.len(), "Using DC override from config");
return Ok(addr);
}
}
let abs_dc = dc_idx.unsigned_abs() as usize;
if abs_dc >= 1 && abs_dc <= num_dcs {
return Ok(SocketAddr::new(datacenters[abs_dc - 1], TG_DATACENTER_PORT));
}
// Unknown DC requested by client without override: log and fall back.
if !config.dc_overrides.contains_key(&dc_key) {
warn!(dc_idx = dc_idx, "Requested non-standard DC with no override; falling back to default cluster");
if let Some(path) = &config.general.unknown_dc_log_path {
if let Ok(mut file) = OpenOptions::new().create(true).append(true).open(path) {
let _ = writeln!(file, "dc_idx={dc_idx}");
}
}
}
let default_dc = config.default_dc.unwrap_or(2) as usize;
let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs {
default_dc - 1
} else {
1
};
info!(
original_dc = dc_idx,
fallback_dc = (fallback_idx + 1) as u16,
fallback_addr = %datacenters[fallback_idx],
"Special DC ---> default_cluster"
);
Ok(SocketAddr::new(
datacenters[fallback_idx],
TG_DATACENTER_PORT,
))
}
async fn do_tg_handshake_static(
mut stream: TcpStream,
success: &HandshakeSuccess,
config: &ProxyConfig,
rng: &SecureRandom,
) -> Result<(
CryptoReader<tokio::net::tcp::OwnedReadHalf>,
CryptoWriter<tokio::net::tcp::OwnedWriteHalf>,
)> {
let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = generate_tg_nonce(
success.proto_tag,
success.dc_idx,
&success.dec_key,
success.dec_iv,
&success.enc_key,
success.enc_iv,
rng,
config.general.fast_mode,
);
let (encrypted_nonce, tg_encryptor, tg_decryptor) = encrypt_tg_nonce_with_ciphers(&nonce);
debug!(
peer = %success.peer,
nonce_head = %hex::encode(&nonce[..16]),
"Sending nonce to Telegram"
);
stream.write_all(&encrypted_nonce).await?;
stream.flush().await?;
let (read_half, write_half) = stream.into_split();
Ok((
CryptoReader::new(read_half, tg_decryptor),
CryptoWriter::new(write_half, tg_encryptor),
))
}

View File

@@ -1,11 +1,11 @@
//! MTProto Handshake Magics
//! MTProto Handshake
use std::net::SocketAddr;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tracing::{debug, warn, trace, info};
use zeroize::Zeroize;
use crate::crypto::{sha256, AesCtr};
use crate::crypto::random::SECURE_RANDOM;
use crate::crypto::{sha256, AesCtr, SecureRandom};
use crate::protocol::constants::*;
use crate::protocol::tls;
use crate::stream::{FakeTlsReader, FakeTlsWriter, CryptoReader, CryptoWriter};
@@ -14,6 +14,9 @@ use crate::stats::ReplayChecker;
use crate::config::ProxyConfig;
/// Result of successful handshake
///
/// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is
/// zeroized on drop.
#[derive(Debug, Clone)]
pub struct HandshakeSuccess {
/// Authenticated user name
@@ -34,6 +37,15 @@ pub struct HandshakeSuccess {
pub is_tls: bool,
}
impl Drop for HandshakeSuccess {
fn drop(&mut self) {
self.dec_key.zeroize();
self.dec_iv.zeroize();
self.enc_key.zeroize();
self.enc_iv.zeroize();
}
}
/// Handle fake TLS handshake
pub async fn handle_tls_handshake<R, W>(
handshake: &[u8],
@@ -42,37 +54,33 @@ pub async fn handle_tls_handshake<R, W>(
peer: SocketAddr,
config: &ProxyConfig,
replay_checker: &ReplayChecker,
rng: &SecureRandom,
) -> HandshakeResult<(FakeTlsReader<R>, FakeTlsWriter<W>, String), R, W>
where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
{
debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake");
// Check minimum length
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
debug!(peer = %peer, "TLS handshake too short");
return HandshakeResult::BadClient { reader, writer };
}
// Extract digest for replay check
let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN];
let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN];
// Check for replay
if replay_checker.check_tls_digest(digest_half) {
if replay_checker.check_and_add_tls_digest(digest_half) {
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
return HandshakeResult::BadClient { reader, writer };
}
// Build secrets list
let secrets: Vec<(String, Vec<u8>)> = config.access.users.iter()
.filter_map(|(name, hex)| {
hex::decode(hex).ok().map(|bytes| (name.clone(), bytes))
})
.collect();
// Validate handshake
let validation = match tls::validate_tls_handshake(
handshake,
&secrets,
@@ -88,42 +96,38 @@ where
return HandshakeResult::BadClient { reader, writer };
}
};
// Get secret for response
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
Some((_, s)) => s,
None => return HandshakeResult::BadClient { reader, writer },
};
// Build and send response
let response = tls::build_server_hello(
secret,
&validation.digest,
&validation.session_id,
config.censorship.fake_cert_len,
rng,
);
debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello");
if let Err(e) = writer.write_all(&response).await {
warn!(peer = %peer, error = %e, "Failed to write TLS ServerHello");
return HandshakeResult::Error(ProxyError::Io(e));
}
if let Err(e) = writer.flush().await {
warn!(peer = %peer, error = %e, "Failed to flush TLS ServerHello");
return HandshakeResult::Error(ProxyError::Io(e));
}
// Record for replay protection only after successful handshake
replay_checker.add_tls_digest(digest_half);
info!(
peer = %peer,
user = %validation.user,
"TLS handshake successful"
);
HandshakeResult::Success((
FakeTlsReader::new(reader),
FakeTlsWriter::new(writer),
@@ -146,87 +150,72 @@ where
W: AsyncWrite + Unpin + Send,
{
trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes");
// Extract prekey and IV
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
// Check for replay
if replay_checker.check_handshake(dec_prekey_iv) {
if replay_checker.check_and_add_handshake(dec_prekey_iv) {
warn!(peer = %peer, "MTProto replay attack detected");
return HandshakeResult::BadClient { reader, writer };
}
// Reversed for encryption direction
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
// Try each user's secret
for (user, secret_hex) in &config.access.users {
let secret = match hex::decode(secret_hex) {
Ok(s) => s,
Err(_) => continue,
};
// Derive decryption key
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(&secret);
let dec_key = sha256(&dec_key_input);
let dec_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap());
// Decrypt handshake to check protocol tag
let mut decryptor = AesCtr::new(&dec_key, dec_iv);
let decrypted = decryptor.decrypt(handshake);
// Check protocol tag
let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4]
.try_into()
.unwrap();
let proto_tag = match ProtoTag::from_bytes(tag_bytes) {
Some(tag) => tag,
None => continue,
};
// Check if mode is enabled
let mode_ok = match proto_tag {
ProtoTag::Secure => {
if is_tls { config.general.modes.tls } else { config.general.modes.secure }
}
ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic,
};
if !mode_ok {
debug!(peer = %peer, user = %user, proto = ?proto_tag, "Mode not enabled");
continue;
}
// Extract DC index
let dc_idx = i16::from_le_bytes(
decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()
);
// Derive encryption key
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
enc_key_input.extend_from_slice(enc_prekey);
enc_key_input.extend_from_slice(&secret);
let enc_key = sha256(&enc_key_input);
let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap());
// Record for replay protection
replay_checker.add_handshake(dec_prekey_iv);
// Create new cipher instances
let decryptor = AesCtr::new(&dec_key, dec_iv);
let encryptor = AesCtr::new(&enc_key, enc_iv);
let success = HandshakeSuccess {
user: user.clone(),
dc_idx,
@@ -238,7 +227,7 @@ where
peer,
is_tls,
};
info!(
peer = %peer,
user = %user,
@@ -247,14 +236,14 @@ where
tls = is_tls,
"MTProto handshake successful"
);
return HandshakeResult::Success((
CryptoReader::new(reader, decryptor),
CryptoWriter::new(writer, encryptor),
success,
));
}
debug!(peer = %peer, "MTProto handshake: no matching user found");
HandshakeResult::BadClient { reader, writer }
}
@@ -262,94 +251,154 @@ where
/// Generate nonce for Telegram connection
pub fn generate_tg_nonce(
proto_tag: ProtoTag,
client_dec_key: &[u8; 32],
client_dec_iv: u128,
dc_idx: i16,
_client_dec_key: &[u8; 32],
_client_dec_iv: u128,
client_enc_key: &[u8; 32],
client_enc_iv: u128,
rng: &SecureRandom,
fast_mode: bool,
) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) {
loop {
let bytes = SECURE_RANDOM.bytes(HANDSHAKE_LEN);
let bytes = rng.bytes(HANDSHAKE_LEN);
let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap();
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { continue; }
let first_four: [u8; 4] = nonce[..4].try_into().unwrap();
if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { continue; }
let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap();
if RESERVED_NONCE_CONTINUES.contains(&continue_four) { continue; }
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
// CRITICAL: write dc_idx so upstream DC knows where to route
nonce[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
if fast_mode {
nonce[SKIP_LEN..SKIP_LEN + KEY_LEN].copy_from_slice(client_dec_key);
nonce[SKIP_LEN + KEY_LEN..SKIP_LEN + KEY_LEN + IV_LEN]
.copy_from_slice(&client_dec_iv.to_be_bytes());
let mut key_iv = Vec::with_capacity(KEY_LEN + IV_LEN);
key_iv.extend_from_slice(client_enc_key);
key_iv.extend_from_slice(&client_enc_iv.to_be_bytes());
key_iv.reverse(); // Python/C behavior: reversed enc_key+enc_iv in nonce
nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN].copy_from_slice(&key_iv);
}
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();
let tg_enc_key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap();
let tg_enc_iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap());
let tg_dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap();
let tg_dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap());
return (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv);
}
}
/// Encrypt nonce for sending to Telegram
pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
/// Encrypt nonce for sending to Telegram and return cipher objects with correct counter state
pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec<u8>, AesCtr, AesCtr) {
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
let key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap();
let iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap());
let mut encryptor = AesCtr::new(&key, iv);
let encrypted_full = encryptor.encrypt(nonce);
let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();
let enc_key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap();
let enc_iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap());
let dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap();
let dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap());
let mut encryptor = AesCtr::new(&enc_key, enc_iv);
let encrypted_full = encryptor.encrypt(nonce); // counter: 0 → 4
let mut result = nonce[..PROTO_TAG_POS].to_vec();
result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]);
result
let decryptor = AesCtr::new(&dec_key, dec_iv);
(result, encryptor, decryptor)
}
/// Encrypt nonce for sending to Telegram (legacy function for compatibility)
pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
let (encrypted, _, _) = encrypt_tg_nonce_with_ciphers(nonce);
encrypted
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_generate_tg_nonce() {
let client_dec_key = [0x42u8; 32];
let client_dec_iv = 12345u128;
let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) =
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, false);
// Check length
let client_enc_key = [0x24u8; 32];
let client_enc_iv = 54321u128;
let rng = SecureRandom::new();
let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) =
generate_tg_nonce(
ProtoTag::Secure,
2,
&client_dec_key,
client_dec_iv,
&client_enc_key,
client_enc_iv,
&rng,
false,
);
assert_eq!(nonce.len(), HANDSHAKE_LEN);
// Check proto tag is set
let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap();
assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure));
}
#[test]
fn test_encrypt_tg_nonce() {
let client_dec_key = [0x42u8; 32];
let client_dec_iv = 12345u128;
let client_enc_key = [0x24u8; 32];
let client_enc_iv = 54321u128;
let rng = SecureRandom::new();
let (nonce, _, _, _, _) =
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, false);
generate_tg_nonce(
ProtoTag::Secure,
2,
&client_dec_key,
client_dec_iv,
&client_enc_key,
client_enc_iv,
&rng,
false,
);
let encrypted = encrypt_tg_nonce(&nonce);
assert_eq!(encrypted.len(), HANDSHAKE_LEN);
// First PROTO_TAG_POS bytes should be unchanged
assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]);
// Rest should be different (encrypted)
assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]);
}
}
#[test]
fn test_handshake_success_zeroize_on_drop() {
let success = HandshakeSuccess {
user: "test".to_string(),
dc_idx: 2,
proto_tag: ProtoTag::Secure,
dec_key: [0xAA; 32],
dec_iv: 0xBBBBBBBB,
enc_key: [0xCC; 32],
enc_iv: 0xDDDDDDDD,
peer: "127.0.0.1:1234".parse().unwrap(),
is_tls: true,
};
assert_eq!(success.dec_key, [0xAA; 32]);
assert_eq!(success.enc_key, [0xCC; 32]);
drop(success);
// Drop impl zeroizes key material without panic
}
}

View File

@@ -3,12 +3,17 @@
use std::time::Duration;
use std::str;
use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use tokio::time::timeout;
use tracing::debug;
use crate::config::ProxyConfig;
const MASK_TIMEOUT: Duration = Duration::from_secs(5);
/// Maximum duration for the entire masking relay.
/// Limits resource consumption from slow-loris attacks and port scanners.
const MASK_RELAY_TIMEOUT: Duration = Duration::from_secs(60);
const MASK_BUFFER_SIZE: usize = 8192;
/// Detect client type based on initial data
@@ -21,33 +26,33 @@ fn detect_client_type(data: &[u8]) -> &'static str {
return "HTTP";
}
}
// Check for TLS ClientHello (0x16 = handshake, 0x03 0x01-0x03 = TLS version)
if data.len() > 3 && data[0] == 0x16 && data[1] == 0x03 {
return "TLS-scanner";
}
// Check for SSH
if data.starts_with(b"SSH-") {
return "SSH";
}
// Port scanner (very short data)
if data.len() < 10 {
return "port-scanner";
}
"unknown"
}
/// Handle a bad client by forwarding to mask host
pub async fn handle_bad_client<R, W>(
mut reader: R,
mut writer: W,
reader: R,
writer: W,
initial_data: &[u8],
config: &ProxyConfig,
)
where
)
where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{
@@ -56,13 +61,41 @@ where
consume_client_data(reader).await;
return;
}
let client_type = detect_client_type(initial_data);
// Connect via Unix socket or TCP
#[cfg(unix)]
if let Some(ref sock_path) = config.censorship.mask_unix_sock {
debug!(
client_type = client_type,
sock = %sock_path,
data_len = initial_data.len(),
"Forwarding bad client to mask unix socket"
);
let connect_result = timeout(MASK_TIMEOUT, UnixStream::connect(sock_path)).await;
match connect_result {
Ok(Ok(stream)) => {
let (mask_read, mask_write) = stream.into_split();
relay_to_mask(reader, writer, mask_read, mask_write, initial_data).await;
}
Ok(Err(e)) => {
debug!(error = %e, "Failed to connect to mask unix socket");
consume_client_data(reader).await;
}
Err(_) => {
debug!("Timeout connecting to mask unix socket");
consume_client_data(reader).await;
}
}
return;
}
let mask_host = config.censorship.mask_host.as_deref()
.unwrap_or(&config.censorship.tls_domain);
let mask_port = config.censorship.mask_port;
debug!(
client_type = client_type,
host = %mask_host,
@@ -70,35 +103,45 @@ where
data_len = initial_data.len(),
"Forwarding bad client to mask host"
);
// Connect to mask host
let mask_addr = format!("{}:{}", mask_host, mask_port);
let connect_result = timeout(
MASK_TIMEOUT,
TcpStream::connect(&mask_addr)
).await;
let mask_stream = match connect_result {
Ok(Ok(s)) => s,
let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await;
match connect_result {
Ok(Ok(stream)) => {
let (mask_read, mask_write) = stream.into_split();
relay_to_mask(reader, writer, mask_read, mask_write, initial_data).await;
}
Ok(Err(e)) => {
debug!(error = %e, "Failed to connect to mask host");
consume_client_data(reader).await;
return;
}
Err(_) => {
debug!("Timeout connecting to mask host");
consume_client_data(reader).await;
return;
}
};
let (mut mask_read, mut mask_write) = mask_stream.into_split();
}
}
/// Relay traffic between client and mask backend
async fn relay_to_mask<R, W, MR, MW>(
mut reader: R,
mut writer: W,
mut mask_read: MR,
mut mask_write: MW,
initial_data: &[u8],
)
where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
MR: AsyncRead + Unpin + Send + 'static,
MW: AsyncWrite + Unpin + Send + 'static,
{
// Send initial data to mask host
if mask_write.write_all(initial_data).await.is_err() {
return;
}
// Relay traffic
let c2m = tokio::spawn(async move {
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
@@ -116,7 +159,7 @@ where
}
}
});
let m2c = tokio::spawn(async move {
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
loop {
@@ -133,7 +176,7 @@ where
}
}
});
// Wait for either to complete
tokio::select! {
_ = c2m => {}
@@ -149,4 +192,4 @@ async fn consume_client_data<R: AsyncRead + Unpin>(mut reader: R) {
break;
}
}
}
}

295
src/proxy/middle_relay.rs Normal file
View File

@@ -0,0 +1,295 @@
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tracing::{debug, info, trace};
use crate::config::ProxyConfig;
use crate::crypto::SecureRandom;
use crate::error::{ProxyError, Result};
use crate::protocol::constants::*;
use crate::proxy::handshake::HandshakeSuccess;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag};
pub(crate) async fn handle_via_middle_proxy<R, W>(
mut crypto_reader: CryptoReader<R>,
mut crypto_writer: CryptoWriter<W>,
success: HandshakeSuccess,
me_pool: Arc<MePool>,
stats: Arc<Stats>,
_config: Arc<ProxyConfig>,
_buffer_pool: Arc<BufferPool>,
local_addr: SocketAddr,
rng: Arc<SecureRandom>,
) -> Result<()>
where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{
let user = success.user.clone();
let peer = success.peer;
let proto_tag = success.proto_tag;
info!(
user = %user,
peer = %peer,
dc = success.dc_idx,
proto = ?proto_tag,
mode = "middle_proxy",
"Routing via Middle-End"
);
let (conn_id, mut me_rx) = me_pool.registry().register().await;
stats.increment_user_connects(&user);
stats.increment_user_curr_connects(&user);
let proto_flags = proto_flags_for_tag(proto_tag, me_pool.has_proxy_tag());
debug!(
user = %user,
conn_id,
proto_flags = format_args!("0x{:08x}", proto_flags),
"ME relay started"
);
let translated_local_addr = me_pool.translate_our_addr(local_addr);
let result: Result<()> = loop {
tokio::select! {
client_frame = read_client_payload(&mut crypto_reader, proto_tag) => {
match client_frame {
Ok(Some((payload, quickack))) => {
trace!(conn_id, bytes = payload.len(), "C->ME frame");
stats.add_user_octets_from(&user, payload.len() as u64);
let mut flags = proto_flags;
if quickack {
flags |= RPC_FLAG_QUICKACK;
}
if payload.len() >= 8 && payload[..8].iter().all(|b| *b == 0) {
flags |= RPC_FLAG_NOT_ENCRYPTED;
}
me_pool.send_proxy_req(
conn_id,
success.dc_idx,
peer,
translated_local_addr,
&payload,
flags,
).await?;
}
Ok(None) => {
debug!(conn_id, "Client EOF");
let _ = me_pool.send_close(conn_id).await;
break Ok(());
}
Err(e) => break Err(e),
}
}
me_msg = me_rx.recv() => {
match me_msg {
Some(MeResponse::Data { flags, data }) => {
trace!(conn_id, bytes = data.len(), flags, "ME->C data");
stats.add_user_octets_to(&user, data.len() as u64);
write_client_payload(&mut crypto_writer, proto_tag, flags, &data, rng.as_ref()).await?;
}
Some(MeResponse::Ack(confirm)) => {
trace!(conn_id, confirm, "ME->C quickack");
write_client_ack(&mut crypto_writer, proto_tag, confirm).await?;
}
Some(MeResponse::Close) => {
debug!(conn_id, "ME sent close");
break Ok(());
}
None => {
debug!(conn_id, "ME channel closed");
break Err(ProxyError::Proxy("ME connection lost".into()));
}
}
}
}
};
debug!(user = %user, conn_id, "ME relay cleanup");
me_pool.registry().unregister(conn_id).await;
stats.decrement_user_curr_connects(&user);
result
}
async fn read_client_payload<R>(
client_reader: &mut CryptoReader<R>,
proto_tag: ProtoTag,
) -> Result<Option<(Vec<u8>, bool)>>
where
R: AsyncRead + Unpin + Send + 'static,
{
let (len, quickack) = match proto_tag {
ProtoTag::Abridged => {
let mut first = [0u8; 1];
match client_reader.read_exact(&mut first).await {
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(ProxyError::Io(e)),
}
let quickack = (first[0] & 0x80) != 0;
let len_words = if (first[0] & 0x7f) == 0x7f {
let mut ext = [0u8; 3];
client_reader
.read_exact(&mut ext)
.await
.map_err(ProxyError::Io)?;
u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize
} else {
(first[0] & 0x7f) as usize
};
let len = len_words
.checked_mul(4)
.ok_or_else(|| ProxyError::Proxy("Abridged frame length overflow".into()))?;
(len, quickack)
}
ProtoTag::Intermediate | ProtoTag::Secure => {
let mut len_buf = [0u8; 4];
match client_reader.read_exact(&mut len_buf).await {
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(ProxyError::Io(e)),
}
let quickack = (len_buf[3] & 0x80) != 0;
((u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize, quickack)
}
};
if len > 16 * 1024 * 1024 {
return Err(ProxyError::Proxy(format!("Frame too large: {len}")));
}
let mut payload = vec![0u8; len];
client_reader
.read_exact(&mut payload)
.await
.map_err(ProxyError::Io)?;
// Secure Intermediate: remove random padding (last len%4 bytes)
if proto_tag == ProtoTag::Secure {
let rem = len % 4;
if rem != 0 && payload.len() >= rem {
payload.truncate(len - rem);
}
}
Ok(Some((payload, quickack)))
}
async fn write_client_payload<W>(
client_writer: &mut CryptoWriter<W>,
proto_tag: ProtoTag,
flags: u32,
data: &[u8],
rng: &SecureRandom,
) -> Result<()>
where
W: AsyncWrite + Unpin + Send + 'static,
{
let quickack = (flags & RPC_FLAG_QUICKACK) != 0;
match proto_tag {
ProtoTag::Abridged => {
if data.len() % 4 != 0 {
return Err(ProxyError::Proxy(format!(
"Abridged payload must be 4-byte aligned, got {}",
data.len()
)));
}
let len_words = data.len() / 4;
if len_words < 0x7f {
let mut first = len_words as u8;
if quickack {
first |= 0x80;
}
client_writer
.write_all(&[first])
.await
.map_err(ProxyError::Io)?;
} else if len_words < (1 << 24) {
let mut first = 0x7fu8;
if quickack {
first |= 0x80;
}
let lw = (len_words as u32).to_le_bytes();
client_writer
.write_all(&[first, lw[0], lw[1], lw[2]])
.await
.map_err(ProxyError::Io)?;
} else {
return Err(ProxyError::Proxy(format!(
"Abridged frame too large: {}",
data.len()
)));
}
client_writer
.write_all(data)
.await
.map_err(ProxyError::Io)?;
}
ProtoTag::Intermediate | ProtoTag::Secure => {
let padding_len = if proto_tag == ProtoTag::Secure {
(rng.bytes(1)[0] % 4) as usize
} else {
0
};
let mut len = (data.len() + padding_len) as u32;
if quickack {
len |= 0x8000_0000;
}
client_writer
.write_all(&len.to_le_bytes())
.await
.map_err(ProxyError::Io)?;
client_writer
.write_all(data)
.await
.map_err(ProxyError::Io)?;
if padding_len > 0 {
let pad = rng.bytes(padding_len);
client_writer
.write_all(&pad)
.await
.map_err(ProxyError::Io)?;
}
}
}
// Avoid unconditional per-frame flush (throughput killer on large downloads).
// Flush only when low-latency ack semantics are requested or when
// CryptoWriter has buffered pending ciphertext that must be drained.
if quickack || client_writer.has_pending() {
client_writer.flush().await.map_err(ProxyError::Io)?;
}
Ok(())
}
async fn write_client_ack<W>(
client_writer: &mut CryptoWriter<W>,
proto_tag: ProtoTag,
confirm: u32,
) -> Result<()>
where
W: AsyncWrite + Unpin + Send + 'static,
{
let bytes = if proto_tag == ProtoTag::Abridged {
confirm.to_be_bytes()
} else {
confirm.to_le_bytes()
};
client_writer
.write_all(&bytes)
.await
.map_err(ProxyError::Io)?;
// ACK should remain low-latency.
client_writer.flush().await.map_err(ProxyError::Io)
}

View File

@@ -1,11 +1,13 @@
//! Proxy Defs
pub mod handshake;
pub mod client;
pub mod relay;
pub mod direct_relay;
pub mod handshake;
pub mod masking;
pub mod middle_relay;
pub mod relay;
pub use handshake::*;
pub use client::ClientHandler;
pub use handshake::*;
pub use masking::*;
pub use relay::*;
pub use masking::*;

View File

@@ -1,27 +1,320 @@
//! Bidirectional Relay
//! Bidirectional Relay — poll-based, no head-of-line blocking
//!
//! ## What changed and why
//!
//! Previous implementation used a single-task `select! { biased; ... }` loop
//! where each branch called `write_all()`. This caused head-of-line blocking:
//! while `write_all()` waited for a slow writer (e.g. client on 3G downloading
//! media), the entire loop was blocked — the other direction couldn't make progress.
//!
//! Symptoms observed in production:
//! - Media loading at ~8 KB/s despite fast server connection
//! - Stop-and-go pattern with 50500ms gaps between chunks
//! - `biased` select starving S→C direction
//! - Some users unable to load media at all
//!
//! ## New architecture
//!
//! Uses `tokio::io::copy_bidirectional` which polls both directions concurrently
//! in a single task via non-blocking `poll_read` / `poll_write` calls:
//!
//! Old (select! + write_all — BLOCKING):
//!
//! loop {
//! select! {
//! biased;
//! data = client.read() => { server.write_all(data).await; } ← BLOCKS here
//! data = server.read() => { client.write_all(data).await; } ← can't run
//! }
//! }
//!
//! New (copy_bidirectional — CONCURRENT):
//!
//! poll(cx) {
//! // Both directions polled in the same poll cycle
//! C→S: poll_read(client) → poll_write(server) // non-blocking
//! S→C: poll_read(server) → poll_write(client) // non-blocking
//! // If one writer is Pending, the other direction still progresses
//! }
//!
//! Benefits:
//! - No head-of-line blocking: slow client download doesn't block uploads
//! - No biased starvation: fair polling of both directions
//! - Proper flush: `copy_bidirectional` calls `poll_flush` when reader stalls,
//! so CryptoWriter's pending ciphertext is always drained (fixes "stuck at 95%")
//! - No deadlock risk: old write_all could deadlock when both TCP buffers filled;
//! poll-based approach lets TCP flow control work correctly
//!
//! Stats tracking:
//! - `StatsIo` wraps client side, intercepts `poll_read` / `poll_write`
//! - `poll_read` on client = C→S (client sending) → `octets_from`, `msgs_from`
//! - `poll_write` on client = S→C (to client) → `octets_to`, `msgs_to`
//! - `SharedCounters` (atomics) let the watchdog read stats without locking
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional};
use tokio::time::Instant;
use tracing::{debug, trace, warn, info};
use tracing::{debug, trace, warn};
use crate::error::Result;
use crate::stats::Stats;
use crate::stream::BufferPool;
use std::sync::atomic::{AtomicU64, Ordering};
// Activity timeout for iOS compatibility (30 minutes)
const ACTIVITY_TIMEOUT_SECS: u64 = 1800;
// ============= Constants =============
/// Relay data bidirectionally between client and server
/// Activity timeout for iOS compatibility.
///
/// iOS keeps Telegram connections alive in background for up to 30 minutes.
/// Closing earlier causes unnecessary reconnects and handshake overhead.
const ACTIVITY_TIMEOUT: Duration = Duration::from_secs(1800);
/// Watchdog check interval — also used for periodic rate logging.
///
/// 10 seconds gives responsive timeout detection (±10s accuracy)
/// without measurable overhead from atomic reads.
const WATCHDOG_INTERVAL: Duration = Duration::from_secs(10);
// ============= CombinedStream =============
/// Combines separate read and write halves into a single bidirectional stream.
///
/// `copy_bidirectional` requires `AsyncRead + AsyncWrite` on each side,
/// but the handshake layer produces split reader/writer pairs
/// (e.g. `CryptoReader<FakeTlsReader<OwnedReadHalf>>` + `CryptoWriter<...>`).
///
/// This wrapper reunifies them with zero overhead — each trait method
/// delegates directly to the corresponding half. No buffering, no copies.
///
/// Safety: `poll_read` only touches `reader`, `poll_write` only touches `writer`,
/// so there's no aliasing even though both are called on the same `&mut self`.
struct CombinedStream<R, W> {
reader: R,
writer: W,
}
impl<R, W> CombinedStream<R, W> {
fn new(reader: R, writer: W) -> Self {
Self { reader, writer }
}
}
impl<R: AsyncRead + Unpin, W: Unpin> AsyncRead for CombinedStream<R, W> {
#[inline]
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.get_mut().reader).poll_read(cx, buf)
}
}
impl<R: Unpin, W: AsyncWrite + Unpin> AsyncWrite for CombinedStream<R, W> {
#[inline]
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.get_mut().writer).poll_write(cx, buf)
}
#[inline]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.get_mut().writer).poll_flush(cx)
}
#[inline]
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.get_mut().writer).poll_shutdown(cx)
}
}
// ============= SharedCounters =============
/// Atomic counters shared between the relay (via StatsIo) and the watchdog task.
///
/// Using `Relaxed` ordering is sufficient because:
/// - Counters are monotonically increasing (no ABA problem)
/// - Slight staleness in watchdog reads is harmless (±10s check interval anyway)
/// - No ordering dependencies between different counters
struct SharedCounters {
/// Bytes read from client (C→S direction)
c2s_bytes: AtomicU64,
/// Bytes written to client (S→C direction)
s2c_bytes: AtomicU64,
/// Number of poll_read completions (≈ C→S chunks)
c2s_ops: AtomicU64,
/// Number of poll_write completions (≈ S→C chunks)
s2c_ops: AtomicU64,
/// Milliseconds since relay epoch of last I/O activity
last_activity_ms: AtomicU64,
}
impl SharedCounters {
fn new() -> Self {
Self {
c2s_bytes: AtomicU64::new(0),
s2c_bytes: AtomicU64::new(0),
c2s_ops: AtomicU64::new(0),
s2c_ops: AtomicU64::new(0),
last_activity_ms: AtomicU64::new(0),
}
}
/// Record activity at this instant.
#[inline]
fn touch(&self, now: Instant, epoch: Instant) {
let ms = now.duration_since(epoch).as_millis() as u64;
self.last_activity_ms.store(ms, Ordering::Relaxed);
}
/// How long since last recorded activity.
fn idle_duration(&self, now: Instant, epoch: Instant) -> Duration {
let last_ms = self.last_activity_ms.load(Ordering::Relaxed);
let now_ms = now.duration_since(epoch).as_millis() as u64;
Duration::from_millis(now_ms.saturating_sub(last_ms))
}
}
// ============= StatsIo =============
/// Transparent I/O wrapper that tracks per-user statistics and activity.
///
/// Wraps the **client** side of the relay. Direction mapping:
///
/// | poll method | direction | stats updated |
/// |-------------|-----------|--------------------------------------|
/// | `poll_read` | C→S | `octets_from`, `msgs_from`, counters |
/// | `poll_write` | S→C | `octets_to`, `msgs_to`, counters |
///
/// Both update the shared activity timestamp for the watchdog.
///
/// Note on message counts: the original code counted one `read()`/`write_all()`
/// as one "message". Here we count `poll_read`/`poll_write` completions instead.
/// Byte counts are identical; op counts may differ slightly due to different
/// internal buffering in `copy_bidirectional`. This is fine for monitoring.
struct StatsIo<S> {
inner: S,
counters: Arc<SharedCounters>,
stats: Arc<Stats>,
user: String,
epoch: Instant,
}
impl<S> StatsIo<S> {
fn new(
inner: S,
counters: Arc<SharedCounters>,
stats: Arc<Stats>,
user: String,
epoch: Instant,
) -> Self {
// Mark initial activity so the watchdog doesn't fire before data flows
counters.touch(Instant::now(), epoch);
Self { inner, counters, stats, user, epoch }
}
}
impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let this = self.get_mut();
let before = buf.filled().len();
match Pin::new(&mut this.inner).poll_read(cx, buf) {
Poll::Ready(Ok(())) => {
let n = buf.filled().len() - before;
if n > 0 {
// C→S: client sent data
this.counters.c2s_bytes.fetch_add(n as u64, Ordering::Relaxed);
this.counters.c2s_ops.fetch_add(1, Ordering::Relaxed);
this.counters.touch(Instant::now(), this.epoch);
this.stats.add_user_octets_from(&this.user, n as u64);
this.stats.increment_user_msgs_from(&this.user);
trace!(user = %this.user, bytes = n, "C->S");
}
Poll::Ready(Ok(()))
}
other => other,
}
}
}
impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
match Pin::new(&mut this.inner).poll_write(cx, buf) {
Poll::Ready(Ok(n)) => {
if n > 0 {
// S→C: data written to client
this.counters.s2c_bytes.fetch_add(n as u64, Ordering::Relaxed);
this.counters.s2c_ops.fetch_add(1, Ordering::Relaxed);
this.counters.touch(Instant::now(), this.epoch);
this.stats.add_user_octets_to(&this.user, n as u64);
this.stats.increment_user_msgs_to(&this.user);
trace!(user = %this.user, bytes = n, "S->C");
}
Poll::Ready(Ok(n))
}
other => other,
}
}
#[inline]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.get_mut().inner).poll_flush(cx)
}
#[inline]
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
}
}
// ============= Relay =============
/// Relay data bidirectionally between client and server.
///
/// Uses `tokio::io::copy_bidirectional` for concurrent, non-blocking data transfer.
///
/// ## API compatibility
///
/// Signature is identical to the previous implementation. The `_buffer_pool`
/// parameter is retained for call-site compatibility — `copy_bidirectional`
/// manages its own internal buffers (8 KB per direction).
///
/// ## Guarantees preserved
///
/// - Activity timeout: 30 minutes of inactivity → clean shutdown
/// - Per-user stats: bytes and ops counted per direction
/// - Periodic rate logging: every 10 seconds when active
/// - Clean shutdown: both write sides are shut down on exit
/// - Error propagation: I/O errors are returned as `ProxyError::Io`
pub async fn relay_bidirectional<CR, CW, SR, SW>(
mut client_reader: CR,
mut client_writer: CW,
mut server_reader: SR,
mut server_writer: SW,
client_reader: CR,
client_writer: CW,
server_reader: SR,
server_writer: SW,
user: &str,
stats: Arc<Stats>,
buffer_pool: Arc<BufferPool>,
_buffer_pool: Arc<BufferPool>,
) -> Result<()>
where
CR: AsyncRead + Unpin + Send + 'static,
@@ -29,234 +322,145 @@ where
SR: AsyncRead + Unpin + Send + 'static,
SW: AsyncWrite + Unpin + Send + 'static,
{
let user_c2s = user.to_string();
let user_s2c = user.to_string();
let stats_c2s = Arc::clone(&stats);
let stats_s2c = Arc::clone(&stats);
let c2s_bytes = Arc::new(AtomicU64::new(0));
let s2c_bytes = Arc::new(AtomicU64::new(0));
let c2s_bytes_clone = Arc::clone(&c2s_bytes);
let s2c_bytes_clone = Arc::clone(&s2c_bytes);
let activity_timeout = Duration::from_secs(ACTIVITY_TIMEOUT_SECS);
let pool_c2s = buffer_pool.clone();
let pool_s2c = buffer_pool.clone();
// Client -> Server task
let c2s = tokio::spawn(async move {
// Get buffer from pool
let mut pooled_buf = pool_c2s.get();
// CRITICAL FIX: BytesMut from pool has len 0. We must resize it to be usable as &mut [u8].
// We use the full capacity.
let cap = pooled_buf.capacity();
pooled_buf.resize(cap, 0);
let mut total_bytes = 0u64;
let mut prev_total_bytes = 0u64;
let mut msg_count = 0u64;
let mut last_activity = Instant::now();
let mut last_log = Instant::now();
loop {
// Read with timeout
let read_result = tokio::time::timeout(
activity_timeout,
client_reader.read(&mut pooled_buf)
).await;
match read_result {
Err(_) => {
warn!(
user = %user_c2s,
total_bytes = total_bytes,
msgs = msg_count,
idle_secs = last_activity.elapsed().as_secs(),
"Activity timeout (C->S) - no data received"
);
let _ = server_writer.shutdown().await;
break;
}
Ok(Ok(0)) => {
debug!(
user = %user_c2s,
total_bytes = total_bytes,
msgs = msg_count,
"Client closed connection (C->S)"
);
let _ = server_writer.shutdown().await;
break;
}
Ok(Ok(n)) => {
total_bytes += n as u64;
msg_count += 1;
last_activity = Instant::now();
c2s_bytes_clone.store(total_bytes, Ordering::Relaxed);
stats_c2s.add_user_octets_from(&user_c2s, n as u64);
stats_c2s.increment_user_msgs_from(&user_c2s);
trace!(
user = %user_c2s,
bytes = n,
total = total_bytes,
"C->S data"
);
// Log activity every 10 seconds with correct rate
let elapsed = last_log.elapsed();
if elapsed > Duration::from_secs(10) {
let delta = total_bytes - prev_total_bytes;
let rate = delta as f64 / elapsed.as_secs_f64();
debug!(
user = %user_c2s,
total_bytes = total_bytes,
msgs = msg_count,
rate_kbps = (rate / 1024.0) as u64,
"C->S transfer in progress"
);
last_log = Instant::now();
prev_total_bytes = total_bytes;
}
if let Err(e) = server_writer.write_all(&pooled_buf[..n]).await {
debug!(user = %user_c2s, error = %e, "Failed to write to server");
break;
}
if let Err(e) = server_writer.flush().await {
debug!(user = %user_c2s, error = %e, "Failed to flush to server");
break;
}
}
Ok(Err(e)) => {
debug!(user = %user_c2s, error = %e, total_bytes = total_bytes, "Client read error");
break;
}
}
}
});
// Server -> Client task
let s2c = tokio::spawn(async move {
// Get buffer from pool
let mut pooled_buf = pool_s2c.get();
// CRITICAL FIX: Resize buffer
let cap = pooled_buf.capacity();
pooled_buf.resize(cap, 0);
let epoch = Instant::now();
let counters = Arc::new(SharedCounters::new());
let user_owned = user.to_string();
// ── Combine split halves into bidirectional streams ──────────────
let client_combined = CombinedStream::new(client_reader, client_writer);
let mut server = CombinedStream::new(server_reader, server_writer);
// Wrap client with stats/activity tracking
let mut client = StatsIo::new(
client_combined,
Arc::clone(&counters),
Arc::clone(&stats),
user_owned.clone(),
epoch,
);
// ── Watchdog: activity timeout + periodic rate logging ──────────
let wd_counters = Arc::clone(&counters);
let wd_user = user_owned.clone();
let watchdog = async {
let mut prev_c2s: u64 = 0;
let mut prev_s2c: u64 = 0;
let mut total_bytes = 0u64;
let mut prev_total_bytes = 0u64;
let mut msg_count = 0u64;
let mut last_activity = Instant::now();
let mut last_log = Instant::now();
loop {
let read_result = tokio::time::timeout(
activity_timeout,
server_reader.read(&mut pooled_buf)
).await;
match read_result {
Err(_) => {
warn!(
user = %user_s2c,
total_bytes = total_bytes,
msgs = msg_count,
idle_secs = last_activity.elapsed().as_secs(),
"Activity timeout (S->C) - no data received"
);
let _ = client_writer.shutdown().await;
break;
}
Ok(Ok(0)) => {
debug!(
user = %user_s2c,
total_bytes = total_bytes,
msgs = msg_count,
"Server closed connection (S->C)"
);
let _ = client_writer.shutdown().await;
break;
}
Ok(Ok(n)) => {
total_bytes += n as u64;
msg_count += 1;
last_activity = Instant::now();
s2c_bytes_clone.store(total_bytes, Ordering::Relaxed);
stats_s2c.add_user_octets_to(&user_s2c, n as u64);
stats_s2c.increment_user_msgs_to(&user_s2c);
trace!(
user = %user_s2c,
bytes = n,
total = total_bytes,
"S->C data"
);
let elapsed = last_log.elapsed();
if elapsed > Duration::from_secs(10) {
let delta = total_bytes - prev_total_bytes;
let rate = delta as f64 / elapsed.as_secs_f64();
debug!(
user = %user_s2c,
total_bytes = total_bytes,
msgs = msg_count,
rate_kbps = (rate / 1024.0) as u64,
"S->C transfer in progress"
);
last_log = Instant::now();
prev_total_bytes = total_bytes;
}
if let Err(e) = client_writer.write_all(&pooled_buf[..n]).await {
debug!(user = %user_s2c, error = %e, "Failed to write to client");
break;
}
if let Err(e) = client_writer.flush().await {
debug!(user = %user_s2c, error = %e, "Failed to flush to client");
break;
}
}
Ok(Err(e)) => {
debug!(user = %user_s2c, error = %e, total_bytes = total_bytes, "Server read error");
break;
}
tokio::time::sleep(WATCHDOG_INTERVAL).await;
let now = Instant::now();
let idle = wd_counters.idle_duration(now, epoch);
// ── Activity timeout ────────────────────────────────────
if idle >= ACTIVITY_TIMEOUT {
let c2s = wd_counters.c2s_bytes.load(Ordering::Relaxed);
let s2c = wd_counters.s2c_bytes.load(Ordering::Relaxed);
warn!(
user = %wd_user,
c2s_bytes = c2s,
s2c_bytes = s2c,
idle_secs = idle.as_secs(),
"Activity timeout"
);
return; // Causes select! to cancel copy_bidirectional
}
// ── Periodic rate logging ───────────────────────────────
let c2s = wd_counters.c2s_bytes.load(Ordering::Relaxed);
let s2c = wd_counters.s2c_bytes.load(Ordering::Relaxed);
let c2s_delta = c2s - prev_c2s;
let s2c_delta = s2c - prev_s2c;
if c2s_delta > 0 || s2c_delta > 0 {
let secs = WATCHDOG_INTERVAL.as_secs_f64();
debug!(
user = %wd_user,
c2s_kbps = (c2s_delta as f64 / secs / 1024.0) as u64,
s2c_kbps = (s2c_delta as f64 / secs / 1024.0) as u64,
c2s_total = c2s,
s2c_total = s2c,
"Relay active"
);
}
prev_c2s = c2s;
prev_s2c = s2c;
}
});
// Wait for either direction to complete
tokio::select! {
result = c2s => {
if let Err(e) = result {
warn!(error = %e, "C->S task panicked");
}
};
// ── Run bidirectional copy + watchdog concurrently ───────────────
//
// copy_bidirectional polls both directions in the same poll() call:
// C→S: poll_read(client/StatsIo) → poll_write(server)
// S→C: poll_read(server) → poll_write(client/StatsIo)
//
// When one direction's writer returns Pending, the other direction
// continues — no head-of-line blocking.
//
// When the watchdog fires, select! drops the copy future,
// releasing the &mut borrows on client and server.
let copy_result = tokio::select! {
result = copy_bidirectional(&mut client, &mut server) => Some(result),
_ = watchdog => None, // Activity timeout — cancel relay
};
// ── Clean shutdown ──────────────────────────────────────────────
// After select!, the losing future is dropped, borrows released.
// Shut down both write sides for clean TCP FIN.
let _ = client.shutdown().await;
let _ = server.shutdown().await;
// ── Final logging ───────────────────────────────────────────────
let c2s_ops = counters.c2s_ops.load(Ordering::Relaxed);
let s2c_ops = counters.s2c_ops.load(Ordering::Relaxed);
let duration = epoch.elapsed();
match copy_result {
Some(Ok((c2s, s2c))) => {
// Normal completion — one side closed the connection
debug!(
user = %user_owned,
c2s_bytes = c2s,
s2c_bytes = s2c,
c2s_msgs = c2s_ops,
s2c_msgs = s2c_ops,
duration_secs = duration.as_secs(),
"Relay finished"
);
Ok(())
}
result = s2c => {
if let Err(e) = result {
warn!(error = %e, "S->C task panicked");
}
Some(Err(e)) => {
// I/O error in one of the directions
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
debug!(
user = %user_owned,
c2s_bytes = c2s,
s2c_bytes = s2c,
c2s_msgs = c2s_ops,
s2c_msgs = s2c_ops,
duration_secs = duration.as_secs(),
error = %e,
"Relay error"
);
Err(e.into())
}
None => {
// Activity timeout (watchdog fired)
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
debug!(
user = %user_owned,
c2s_bytes = c2s,
s2c_bytes = s2c,
c2s_msgs = c2s_ops,
s2c_msgs = s2c_ops,
duration_secs = duration.as_secs(),
"Relay finished (activity timeout)"
);
Ok(())
}
}
debug!(
c2s_bytes = c2s_bytes.load(Ordering::Relaxed),
s2c_bytes = s2c_bytes.load(Ordering::Relaxed),
"Relay finished"
);
Ok(())
}

View File

@@ -1,31 +1,28 @@
//! Statistics
//! Statistics and replay protection
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use std::time::{Instant, Duration};
use dashmap::DashMap;
use parking_lot::{RwLock, Mutex};
use parking_lot::Mutex;
use lru::LruCache;
use std::num::NonZeroUsize;
use std::hash::{Hash, Hasher};
use std::collections::hash_map::DefaultHasher;
use std::collections::VecDeque;
use tracing::debug;
// ============= Stats =============
/// Thread-safe statistics
#[derive(Default)]
pub struct Stats {
// Global counters
connects_all: AtomicU64,
connects_bad: AtomicU64,
handshake_timeouts: AtomicU64,
// Per-user stats
user_stats: DashMap<String, UserStats>,
// Start time
start_time: RwLock<Option<Instant>>,
start_time: parking_lot::RwLock<Option<Instant>>,
}
/// Per-user statistics
#[derive(Default)]
pub struct UserStats {
pub connects: AtomicU64,
@@ -43,42 +40,20 @@ impl Stats {
stats
}
// Global stats
pub fn increment_connects_all(&self) {
self.connects_all.fetch_add(1, Ordering::Relaxed);
}
pub fn increment_connects_all(&self) { self.connects_all.fetch_add(1, Ordering::Relaxed); }
pub fn increment_connects_bad(&self) { self.connects_bad.fetch_add(1, Ordering::Relaxed); }
pub fn increment_handshake_timeouts(&self) { self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); }
pub fn get_connects_all(&self) -> u64 { self.connects_all.load(Ordering::Relaxed) }
pub fn get_connects_bad(&self) -> u64 { self.connects_bad.load(Ordering::Relaxed) }
pub fn increment_connects_bad(&self) {
self.connects_bad.fetch_add(1, Ordering::Relaxed);
}
pub fn increment_handshake_timeouts(&self) {
self.handshake_timeouts.fetch_add(1, Ordering::Relaxed);
}
pub fn get_connects_all(&self) -> u64 {
self.connects_all.load(Ordering::Relaxed)
}
pub fn get_connects_bad(&self) -> u64 {
self.connects_bad.load(Ordering::Relaxed)
}
// User stats
pub fn increment_user_connects(&self, user: &str) {
self.user_stats
.entry(user.to_string())
.or_default()
.connects
.fetch_add(1, Ordering::Relaxed);
self.user_stats.entry(user.to_string()).or_default()
.connects.fetch_add(1, Ordering::Relaxed);
}
pub fn increment_user_curr_connects(&self, user: &str) {
self.user_stats
.entry(user.to_string())
.or_default()
.curr_connects
.fetch_add(1, Ordering::Relaxed);
self.user_stats.entry(user.to_string()).or_default()
.curr_connects.fetch_add(1, Ordering::Relaxed);
}
pub fn decrement_user_curr_connects(&self, user: &str) {
@@ -88,47 +63,33 @@ impl Stats {
}
pub fn get_user_curr_connects(&self, user: &str) -> u64 {
self.user_stats
.get(user)
self.user_stats.get(user)
.map(|s| s.curr_connects.load(Ordering::Relaxed))
.unwrap_or(0)
}
pub fn add_user_octets_from(&self, user: &str, bytes: u64) {
self.user_stats
.entry(user.to_string())
.or_default()
.octets_from_client
.fetch_add(bytes, Ordering::Relaxed);
self.user_stats.entry(user.to_string()).or_default()
.octets_from_client.fetch_add(bytes, Ordering::Relaxed);
}
pub fn add_user_octets_to(&self, user: &str, bytes: u64) {
self.user_stats
.entry(user.to_string())
.or_default()
.octets_to_client
.fetch_add(bytes, Ordering::Relaxed);
self.user_stats.entry(user.to_string()).or_default()
.octets_to_client.fetch_add(bytes, Ordering::Relaxed);
}
pub fn increment_user_msgs_from(&self, user: &str) {
self.user_stats
.entry(user.to_string())
.or_default()
.msgs_from_client
.fetch_add(1, Ordering::Relaxed);
self.user_stats.entry(user.to_string()).or_default()
.msgs_from_client.fetch_add(1, Ordering::Relaxed);
}
pub fn increment_user_msgs_to(&self, user: &str) {
self.user_stats
.entry(user.to_string())
.or_default()
.msgs_to_client
.fetch_add(1, Ordering::Relaxed);
self.user_stats.entry(user.to_string()).or_default()
.msgs_to_client.fetch_add(1, Ordering::Relaxed);
}
pub fn get_user_total_octets(&self, user: &str) -> u64 {
self.user_stats
.get(user)
self.user_stats.get(user)
.map(|s| {
s.octets_from_client.load(Ordering::Relaxed) +
s.octets_to_client.load(Ordering::Relaxed)
@@ -136,6 +97,12 @@ impl Stats {
.unwrap_or(0)
}
pub fn get_handshake_timeouts(&self) -> u64 { self.handshake_timeouts.load(Ordering::Relaxed) }
pub fn iter_user_stats(&self) -> dashmap::iter::Iter<'_, String, UserStats> {
self.user_stats.iter()
}
pub fn uptime_secs(&self) -> f64 {
self.start_time.read()
.map(|t| t.elapsed().as_secs_f64())
@@ -143,57 +110,222 @@ impl Stats {
}
}
/// Sharded Replay attack checker using LRU cache
/// Uses multiple independent LRU caches to reduce lock contention
// ============= Replay Checker =============
pub struct ReplayChecker {
shards: Vec<Mutex<LruCache<Vec<u8>, ()>>>,
shards: Vec<Mutex<ReplayShard>>,
shard_mask: usize,
window: Duration,
checks: AtomicU64,
hits: AtomicU64,
additions: AtomicU64,
cleanups: AtomicU64,
}
impl ReplayChecker {
/// Create new replay checker with specified capacity per shard
/// Total capacity = capacity * num_shards
pub fn new(total_capacity: usize) -> Self {
// Use 64 shards for good concurrency
let num_shards = 64;
let shard_capacity = (total_capacity / num_shards).max(1);
let cap = NonZeroUsize::new(shard_capacity).unwrap();
let mut shards = Vec::with_capacity(num_shards);
for _ in 0..num_shards {
shards.push(Mutex::new(LruCache::new(cap)));
}
struct ReplayEntry {
seen_at: Instant,
seq: u64,
}
struct ReplayShard {
cache: LruCache<Box<[u8]>, ReplayEntry>,
queue: VecDeque<(Instant, Box<[u8]>, u64)>,
seq_counter: u64,
}
impl ReplayShard {
fn new(cap: NonZeroUsize) -> Self {
Self {
shards,
shard_mask: num_shards - 1,
cache: LruCache::new(cap),
queue: VecDeque::with_capacity(cap.get()),
seq_counter: 0,
}
}
fn get_shard(&self, key: &[u8]) -> usize {
fn next_seq(&mut self) -> u64 {
self.seq_counter += 1;
self.seq_counter
}
fn cleanup(&mut self, now: Instant, window: Duration) {
if window.is_zero() {
return;
}
let cutoff = now.checked_sub(window).unwrap_or(now);
while let Some((ts, _, _)) = self.queue.front() {
if *ts >= cutoff {
break;
}
let (_, key, queue_seq) = self.queue.pop_front().unwrap();
// Use key.as_ref() to get &[u8] — avoids Borrow<Q> ambiguity
// between Borrow<[u8]> and Borrow<Box<[u8]>>
if let Some(entry) = self.cache.peek(key.as_ref()) {
if entry.seq == queue_seq {
self.cache.pop(key.as_ref());
}
}
}
}
fn check(&mut self, key: &[u8], now: Instant, window: Duration) -> bool {
self.cleanup(now, window);
// key is &[u8], resolves Q=[u8] via Box<[u8]>: Borrow<[u8]>
self.cache.get(key).is_some()
}
fn add(&mut self, key: &[u8], now: Instant, window: Duration) {
self.cleanup(now, window);
let seq = self.next_seq();
let boxed_key: Box<[u8]> = key.into();
self.cache.put(boxed_key.clone(), ReplayEntry { seen_at: now, seq });
self.queue.push_back((now, boxed_key, seq));
}
fn len(&self) -> usize {
self.cache.len()
}
}
impl ReplayChecker {
pub fn new(total_capacity: usize, window: Duration) -> Self {
let num_shards = 64;
let shard_capacity = (total_capacity / num_shards).max(1);
let cap = NonZeroUsize::new(shard_capacity).unwrap();
let mut shards = Vec::with_capacity(num_shards);
for _ in 0..num_shards {
shards.push(Mutex::new(ReplayShard::new(cap)));
}
Self {
shards,
shard_mask: num_shards - 1,
window,
checks: AtomicU64::new(0),
hits: AtomicU64::new(0),
additions: AtomicU64::new(0),
cleanups: AtomicU64::new(0),
}
}
fn get_shard_idx(&self, key: &[u8]) -> usize {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
(hasher.finish() as usize) & self.shard_mask
}
fn check_and_add_internal(&self, data: &[u8]) -> bool {
self.checks.fetch_add(1, Ordering::Relaxed);
let idx = self.get_shard_idx(data);
let mut shard = self.shards[idx].lock();
let now = Instant::now();
let found = shard.check(data, now, self.window);
if found {
self.hits.fetch_add(1, Ordering::Relaxed);
} else {
shard.add(data, now, self.window);
self.additions.fetch_add(1, Ordering::Relaxed);
}
found
}
fn add_only(&self, data: &[u8]) {
self.additions.fetch_add(1, Ordering::Relaxed);
let idx = self.get_shard_idx(data);
let mut shard = self.shards[idx].lock();
shard.add(data, Instant::now(), self.window);
}
pub fn check_and_add_handshake(&self, data: &[u8]) -> bool {
self.check_and_add_internal(data)
}
pub fn check_and_add_tls_digest(&self, data: &[u8]) -> bool {
self.check_and_add_internal(data)
}
// Compatibility helpers (non-atomic split operations) — prefer check_and_add_*.
pub fn check_handshake(&self, data: &[u8]) -> bool { self.check_and_add_handshake(data) }
pub fn add_handshake(&self, data: &[u8]) { self.add_only(data) }
pub fn check_tls_digest(&self, data: &[u8]) -> bool { self.check_and_add_tls_digest(data) }
pub fn add_tls_digest(&self, data: &[u8]) { self.add_only(data) }
pub fn check_handshake(&self, data: &[u8]) -> bool {
let shard_idx = self.get_shard(data);
self.shards[shard_idx].lock().contains(&data.to_vec())
pub fn stats(&self) -> ReplayStats {
let mut total_entries = 0;
let mut total_queue_len = 0;
for shard in &self.shards {
let s = shard.lock();
total_entries += s.cache.len();
total_queue_len += s.queue.len();
}
ReplayStats {
total_entries,
total_queue_len,
total_checks: self.checks.load(Ordering::Relaxed),
total_hits: self.hits.load(Ordering::Relaxed),
total_additions: self.additions.load(Ordering::Relaxed),
total_cleanups: self.cleanups.load(Ordering::Relaxed),
num_shards: self.shards.len(),
window_secs: self.window.as_secs(),
}
}
pub fn add_handshake(&self, data: &[u8]) {
let shard_idx = self.get_shard(data);
self.shards[shard_idx].lock().put(data.to_vec(), ());
pub async fn run_periodic_cleanup(&self) {
let interval = if self.window.as_secs() > 60 {
Duration::from_secs(30)
} else {
Duration::from_secs(self.window.as_secs().max(1) / 2)
};
loop {
tokio::time::sleep(interval).await;
let now = Instant::now();
let mut cleaned = 0usize;
for shard_mutex in &self.shards {
let mut shard = shard_mutex.lock();
let before = shard.len();
shard.cleanup(now, self.window);
let after = shard.len();
cleaned += before.saturating_sub(after);
}
self.cleanups.fetch_add(1, Ordering::Relaxed);
if cleaned > 0 {
debug!(cleaned = cleaned, "Replay checker: periodic cleanup");
}
}
}
}
#[derive(Debug, Clone)]
pub struct ReplayStats {
pub total_entries: usize,
pub total_queue_len: usize,
pub total_checks: u64,
pub total_hits: u64,
pub total_additions: u64,
pub total_cleanups: u64,
pub num_shards: usize,
pub window_secs: u64,
}
impl ReplayStats {
pub fn hit_rate(&self) -> f64 {
if self.total_checks == 0 { 0.0 }
else { (self.total_hits as f64 / self.total_checks as f64) * 100.0 }
}
pub fn check_tls_digest(&self, data: &[u8]) -> bool {
let shard_idx = self.get_shard(data);
self.shards[shard_idx].lock().contains(&data.to_vec())
}
pub fn add_tls_digest(&self, data: &[u8]) {
let shard_idx = self.get_shard(data);
self.shards[shard_idx].lock().put(data.to_vec(), ());
pub fn ghost_ratio(&self) -> f64 {
if self.total_entries == 0 { 0.0 }
else { self.total_queue_len as f64 / self.total_entries as f64 }
}
}
@@ -204,28 +336,59 @@ mod tests {
#[test]
fn test_stats_shared_counters() {
let stats = Arc::new(Stats::new());
let stats1 = Arc::clone(&stats);
let stats2 = Arc::clone(&stats);
stats1.increment_connects_all();
stats2.increment_connects_all();
stats1.increment_connects_all();
stats.increment_connects_all();
stats.increment_connects_all();
stats.increment_connects_all();
assert_eq!(stats.get_connects_all(), 3);
}
#[test]
fn test_replay_checker_sharding() {
let checker = ReplayChecker::new(100);
let data1 = b"test1";
let data2 = b"test2";
checker.add_handshake(data1);
assert!(checker.check_handshake(data1));
assert!(!checker.check_handshake(data2));
checker.add_handshake(data2);
assert!(checker.check_handshake(data2));
fn test_replay_checker_basic() {
let checker = ReplayChecker::new(100, Duration::from_secs(60));
assert!(!checker.check_handshake(b"test1")); // first time, inserts
assert!(checker.check_handshake(b"test1")); // duplicate
assert!(!checker.check_handshake(b"test2")); // new key inserts
}
}
#[test]
fn test_replay_checker_duplicate_add() {
let checker = ReplayChecker::new(100, Duration::from_secs(60));
checker.add_handshake(b"dup");
checker.add_handshake(b"dup");
assert!(checker.check_handshake(b"dup"));
}
#[test]
fn test_replay_checker_expiration() {
let checker = ReplayChecker::new(100, Duration::from_millis(50));
assert!(!checker.check_handshake(b"expire"));
assert!(checker.check_handshake(b"expire"));
std::thread::sleep(Duration::from_millis(100));
assert!(!checker.check_handshake(b"expire"));
}
#[test]
fn test_replay_checker_stats() {
let checker = ReplayChecker::new(100, Duration::from_secs(60));
assert!(!checker.check_handshake(b"k1"));
assert!(!checker.check_handshake(b"k2"));
assert!(checker.check_handshake(b"k1"));
assert!(!checker.check_handshake(b"k3"));
let stats = checker.stats();
assert_eq!(stats.total_additions, 3);
assert_eq!(stats.total_checks, 4);
assert_eq!(stats.total_hits, 1);
}
#[test]
fn test_replay_checker_many_keys() {
let checker = ReplayChecker::new(10_000, Duration::from_secs(60));
for i in 0..500u32 {
checker.add_only(&i.to_le_bytes());
}
for i in 0..500u32 {
assert!(checker.check_handshake(&i.to_le_bytes()));
}
assert_eq!(checker.stats().total_entries, 500);
}
}

View File

@@ -381,9 +381,14 @@ mod tests {
// Add a buffer to pool
pool.preallocate(1);
// Now try_get should succeed
assert!(pool.try_get().is_some());
// Now try_get should succeed once while the buffer is held
let buf = pool.try_get();
assert!(buf.is_some());
// While buffer is held, pool is empty
assert!(pool.try_get().is_none());
// Drop buffer -> returns to pool, should be obtainable again
drop(buf);
assert!(pool.try_get().is_some());
}
#[test]
@@ -448,4 +453,4 @@ mod tests {
// All buffers should be returned
assert!(stats.pooled > 0);
}
}
}

View File

@@ -5,8 +5,10 @@
use bytes::{Bytes, BytesMut};
use std::io::Result;
use std::sync::Arc;
use crate::protocol::constants::ProtoTag;
use crate::crypto::SecureRandom;
// ============= Frame Types =============
@@ -147,11 +149,11 @@ pub trait FrameCodec: Send + Sync {
// ============= Codec Factory =============
/// Create a frame codec for the given protocol tag
pub fn create_codec(proto_tag: ProtoTag) -> Box<dyn FrameCodec> {
pub fn create_codec(proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Box<dyn FrameCodec> {
match proto_tag {
ProtoTag::Abridged => Box::new(crate::stream::frame_codec::AbridgedCodec::new()),
ProtoTag::Intermediate => Box::new(crate::stream::frame_codec::IntermediateCodec::new()),
ProtoTag::Secure => Box::new(crate::stream::frame_codec::SecureCodec::new()),
ProtoTag::Secure => Box::new(crate::stream::frame_codec::SecureCodec::new(rng)),
}
}

View File

@@ -5,9 +5,11 @@
use bytes::{Bytes, BytesMut, BufMut};
use std::io::{self, Error, ErrorKind};
use std::sync::Arc;
use tokio_util::codec::{Decoder, Encoder};
use crate::protocol::constants::ProtoTag;
use crate::crypto::SecureRandom;
use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait};
// ============= Unified Codec =============
@@ -21,14 +23,17 @@ pub struct FrameCodec {
proto_tag: ProtoTag,
/// Maximum allowed frame size
max_frame_size: usize,
/// RNG for secure padding
rng: Arc<SecureRandom>,
}
impl FrameCodec {
/// Create a new codec for the given protocol
pub fn new(proto_tag: ProtoTag) -> Self {
pub fn new(proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Self {
Self {
proto_tag,
max_frame_size: 16 * 1024 * 1024, // 16MB default
rng,
}
}
@@ -64,7 +69,7 @@ impl Encoder<Frame> for FrameCodec {
match self.proto_tag {
ProtoTag::Abridged => encode_abridged(&frame, dst),
ProtoTag::Intermediate => encode_intermediate(&frame, dst),
ProtoTag::Secure => encode_secure(&frame, dst),
ProtoTag::Secure => encode_secure(&frame, dst, &self.rng),
}
}
}
@@ -288,9 +293,7 @@ fn decode_secure(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame
Ok(Some(Frame::with_meta(data, meta)))
}
fn encode_secure(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
use crate::crypto::random::SECURE_RANDOM;
fn encode_secure(frame: &Frame, dst: &mut BytesMut, rng: &SecureRandom) -> io::Result<()> {
let data = &frame.data;
// Simple ACK: just send data
@@ -303,10 +306,10 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
// Generate padding to make length not divisible by 4
let padding_len = if data.len() % 4 == 0 {
// Add 1-3 bytes to make it non-aligned
(SECURE_RANDOM.range(3) + 1) as usize
(rng.range(3) + 1) as usize
} else {
// Already non-aligned, can add 0-3
SECURE_RANDOM.range(4) as usize
rng.range(4) as usize
};
let total_len = data.len() + padding_len;
@@ -321,7 +324,7 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
dst.extend_from_slice(data);
if padding_len > 0 {
let padding = SECURE_RANDOM.bytes(padding_len);
let padding = rng.bytes(padding_len);
dst.extend_from_slice(&padding);
}
@@ -445,19 +448,21 @@ impl FrameCodecTrait for IntermediateCodec {
/// Secure Intermediate protocol codec
pub struct SecureCodec {
max_frame_size: usize,
rng: Arc<SecureRandom>,
}
impl SecureCodec {
pub fn new() -> Self {
pub fn new(rng: Arc<SecureRandom>) -> Self {
Self {
max_frame_size: 16 * 1024 * 1024,
rng,
}
}
}
impl Default for SecureCodec {
fn default() -> Self {
Self::new()
Self::new(Arc::new(SecureRandom::new()))
}
}
@@ -474,7 +479,7 @@ impl Encoder<Frame> for SecureCodec {
type Error = io::Error;
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
encode_secure(&frame, dst)
encode_secure(&frame, dst, &self.rng)
}
}
@@ -485,7 +490,7 @@ impl FrameCodecTrait for SecureCodec {
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result<usize> {
let before = dst.len();
encode_secure(frame, dst)?;
encode_secure(frame, dst, &self.rng)?;
Ok(dst.len() - before)
}
@@ -506,6 +511,8 @@ mod tests {
use tokio_util::codec::{FramedRead, FramedWrite};
use tokio::io::duplex;
use futures::{SinkExt, StreamExt};
use crate::crypto::SecureRandom;
use std::sync::Arc;
#[tokio::test]
async fn test_framed_abridged() {
@@ -541,8 +548,8 @@ mod tests {
async fn test_framed_secure() {
let (client, server) = duplex(4096);
let mut writer = FramedWrite::new(client, SecureCodec::new());
let mut reader = FramedRead::new(server, SecureCodec::new());
let mut writer = FramedWrite::new(client, SecureCodec::new(Arc::new(SecureRandom::new())));
let mut reader = FramedRead::new(server, SecureCodec::new(Arc::new(SecureRandom::new())));
let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
let frame = Frame::new(original.clone());
@@ -557,8 +564,8 @@ mod tests {
for proto_tag in [ProtoTag::Abridged, ProtoTag::Intermediate, ProtoTag::Secure] {
let (client, server) = duplex(4096);
let mut writer = FramedWrite::new(client, FrameCodec::new(proto_tag));
let mut reader = FramedRead::new(server, FrameCodec::new(proto_tag));
let mut writer = FramedWrite::new(client, FrameCodec::new(proto_tag, Arc::new(SecureRandom::new())));
let mut reader = FramedRead::new(server, FrameCodec::new(proto_tag, Arc::new(SecureRandom::new())));
// Use 4-byte aligned data for abridged compatibility
let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
@@ -607,7 +614,7 @@ mod tests {
#[test]
fn test_frame_too_large() {
let mut codec = FrameCodec::new(ProtoTag::Intermediate)
let mut codec = FrameCodec::new(ProtoTag::Intermediate, Arc::new(SecureRandom::new()))
.with_max_frame_size(100);
// Create a "frame" that claims to be very large

View File

@@ -4,8 +4,8 @@ use bytes::{Bytes, BytesMut};
use std::io::{Error, ErrorKind, Result};
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use crate::protocol::constants::*;
use crate::crypto::crc32;
use crate::crypto::random::SECURE_RANDOM;
use crate::crypto::{crc32, SecureRandom};
use std::sync::Arc;
use super::traits::{FrameMeta, LayeredStream};
// ============= Abridged (Compact) Frame =============
@@ -251,11 +251,12 @@ impl<R> LayeredStream<R> for SecureIntermediateFrameReader<R> {
/// Writer for secure intermediate MTProto framing
pub struct SecureIntermediateFrameWriter<W> {
upstream: W,
rng: Arc<SecureRandom>,
}
impl<W> SecureIntermediateFrameWriter<W> {
pub fn new(upstream: W) -> Self {
Self { upstream }
pub fn new(upstream: W, rng: Arc<SecureRandom>) -> Self {
Self { upstream, rng }
}
}
@@ -267,8 +268,8 @@ impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
}
// Add random padding (0-3 bytes)
let padding_len = SECURE_RANDOM.range(4);
let padding = SECURE_RANDOM.bytes(padding_len);
let padding_len = self.rng.range(4);
let padding = self.rng.bytes(padding_len);
let total_len = data.len() + padding_len;
let len_bytes = (total_len as u32).to_le_bytes();
@@ -454,11 +455,11 @@ pub enum FrameWriterKind<W> {
}
impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
pub fn new(upstream: W, proto_tag: ProtoTag) -> Self {
pub fn new(upstream: W, proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Self {
match proto_tag {
ProtoTag::Abridged => FrameWriterKind::Abridged(AbridgedFrameWriter::new(upstream)),
ProtoTag::Intermediate => FrameWriterKind::Intermediate(IntermediateFrameWriter::new(upstream)),
ProtoTag::Secure => FrameWriterKind::SecureIntermediate(SecureIntermediateFrameWriter::new(upstream)),
ProtoTag::Secure => FrameWriterKind::SecureIntermediate(SecureIntermediateFrameWriter::new(upstream, rng)),
}
}
@@ -483,6 +484,8 @@ impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
mod tests {
use super::*;
use tokio::io::duplex;
use std::sync::Arc;
use crate::crypto::SecureRandom;
#[tokio::test]
async fn test_abridged_roundtrip() {
@@ -539,7 +542,7 @@ mod tests {
async fn test_secure_intermediate_padding() {
let (client, server) = duplex(1024);
let mut writer = SecureIntermediateFrameWriter::new(client);
let mut writer = SecureIntermediateFrameWriter::new(client, Arc::new(SecureRandom::new()));
let mut reader = SecureIntermediateFrameReader::new(server);
let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
@@ -572,7 +575,7 @@ mod tests {
async fn test_frame_reader_kind() {
let (client, server) = duplex(1024);
let mut writer = FrameWriterKind::new(client, ProtoTag::Intermediate);
let mut writer = FrameWriterKind::new(client, ProtoTag::Intermediate, Arc::new(SecureRandom::new()));
let mut reader = FrameReaderKind::new(server, ProtoTag::Intermediate);
let data = vec![1u8, 2, 3, 4];

View File

@@ -32,7 +32,7 @@
//! and uploads from iOS will break (media/file sending), while small traffic
//! may still work.
use bytes::{Bytes, BytesMut, BufMut};
use bytes::{Bytes, BytesMut};
use std::io::{self, Error, ErrorKind, Result};
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -51,9 +51,10 @@ use super::state::{StreamState, HeaderBuffer, YieldBuffer, WriteBuffer};
/// TLS record header size (type + version + length)
const TLS_HEADER_SIZE: usize = 5;
/// Maximum TLS fragment size per spec (plaintext fragment).
/// We use this for *outgoing* chunking, because we build plain ApplicationData records.
const MAX_TLS_PAYLOAD: usize = 16384;
/// Maximum TLS fragment size we emit for Application Data.
/// Real TLS 1.3 ciphertexts often add ~16-24 bytes AEAD overhead, so to mimic
/// on-the-wire record sizes we allow up to 16384 + 24 bytes of plaintext.
const MAX_TLS_PAYLOAD: usize = 16384 + 24;
/// Maximum pending write buffer for one record remainder.
/// Note: we never queue unlimited amount of data here; state holds at most one record.
@@ -918,10 +919,8 @@ mod tests {
let reader = ChunkedReader::new(&record, 100);
let mut tls_reader = FakeTlsReader::new(reader);
let mut buf = vec![0u8; payload.len()];
tls_reader.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, payload);
let buf = tls_reader.read_exact(payload.len()).await.unwrap();
assert_eq!(&buf[..], payload);
}
#[tokio::test]
@@ -935,13 +934,11 @@ mod tests {
let reader = ChunkedReader::new(&data, 100);
let mut tls_reader = FakeTlsReader::new(reader);
let mut buf1 = vec![0u8; payload1.len()];
tls_reader.read_exact(&mut buf1).await.unwrap();
assert_eq!(&buf1, payload1);
let buf1 = tls_reader.read_exact(payload1.len()).await.unwrap();
assert_eq!(&buf1[..], payload1);
let mut buf2 = vec![0u8; payload2.len()];
tls_reader.read_exact(&mut buf2).await.unwrap();
assert_eq!(&buf2, payload2);
let buf2 = tls_reader.read_exact(payload2.len()).await.unwrap();
assert_eq!(&buf2[..], payload2);
}
#[tokio::test]
@@ -953,10 +950,9 @@ mod tests {
let reader = ChunkedReader::new(&record, 1); // 1 byte at a time!
let mut tls_reader = FakeTlsReader::new(reader);
let mut buf = vec![0u8; payload.len()];
tls_reader.read_exact(&mut buf).await.unwrap();
let buf = tls_reader.read_exact(payload.len()).await.unwrap();
assert_eq!(&buf, payload);
assert_eq!(&buf[..], payload);
}
#[tokio::test]
@@ -967,10 +963,9 @@ mod tests {
let reader = ChunkedReader::new(&record, 7); // Awkward chunk size
let mut tls_reader = FakeTlsReader::new(reader);
let mut buf = vec![0u8; payload.len()];
tls_reader.read_exact(&mut buf).await.unwrap();
let buf = tls_reader.read_exact(payload.len()).await.unwrap();
assert_eq!(&buf, payload);
assert_eq!(&buf[..], payload);
}
#[tokio::test]
@@ -983,10 +978,9 @@ mod tests {
let reader = ChunkedReader::new(&data, 100);
let mut tls_reader = FakeTlsReader::new(reader);
let mut buf = vec![0u8; payload.len()];
tls_reader.read_exact(&mut buf).await.unwrap();
let buf = tls_reader.read_exact(payload.len()).await.unwrap();
assert_eq!(&buf, payload);
assert_eq!(&buf[..], payload);
}
#[tokio::test]
@@ -1000,10 +994,9 @@ mod tests {
let reader = ChunkedReader::new(&data, 3); // Small chunks
let mut tls_reader = FakeTlsReader::new(reader);
let mut buf = vec![0u8; payload.len()];
tls_reader.read_exact(&mut buf).await.unwrap();
let buf = tls_reader.read_exact(payload.len()).await.unwrap();
assert_eq!(&buf, payload);
assert_eq!(&buf[..], payload);
}
#[tokio::test]
@@ -1244,4 +1237,4 @@ mod tests {
let bytes = header.to_bytes();
assert_eq!(bytes, [0x17, 0x03, 0x03, 0x12, 0x34]);
}
}
}

View File

@@ -0,0 +1,184 @@
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use crate::crypto::{AesCbc, crc32};
use crate::error::{ProxyError, Result};
use crate::protocol::constants::*;
pub(crate) fn build_rpc_frame(seq_no: i32, payload: &[u8]) -> Vec<u8> {
let total_len = (4 + 4 + payload.len() + 4) as u32;
let mut frame = Vec::with_capacity(total_len as usize);
frame.extend_from_slice(&total_len.to_le_bytes());
frame.extend_from_slice(&seq_no.to_le_bytes());
frame.extend_from_slice(payload);
let c = crc32(&frame);
frame.extend_from_slice(&c.to_le_bytes());
frame
}
pub(crate) async fn read_rpc_frame_plaintext(
rd: &mut (impl AsyncReadExt + Unpin),
) -> Result<(i32, Vec<u8>)> {
let mut len_buf = [0u8; 4];
rd.read_exact(&mut len_buf).await.map_err(ProxyError::Io)?;
let total_len = u32::from_le_bytes(len_buf) as usize;
if !(12..=(1 << 24)).contains(&total_len) {
return Err(ProxyError::InvalidHandshake(format!(
"Bad RPC frame length: {total_len}"
)));
}
let mut rest = vec![0u8; total_len - 4];
rd.read_exact(&mut rest).await.map_err(ProxyError::Io)?;
let mut full = Vec::with_capacity(total_len);
full.extend_from_slice(&len_buf);
full.extend_from_slice(&rest);
let crc_offset = total_len - 4;
let expected_crc = u32::from_le_bytes(full[crc_offset..crc_offset + 4].try_into().unwrap());
let actual_crc = crc32(&full[..crc_offset]);
if expected_crc != actual_crc {
return Err(ProxyError::InvalidHandshake(format!(
"CRC mismatch: 0x{expected_crc:08x} vs 0x{actual_crc:08x}"
)));
}
let seq_no = i32::from_le_bytes(full[4..8].try_into().unwrap());
let payload = full[8..crc_offset].to_vec();
Ok((seq_no, payload))
}
pub(crate) fn build_nonce_payload(key_selector: u32, crypto_ts: u32, nonce: &[u8; 16]) -> [u8; 32] {
let mut p = [0u8; 32];
p[0..4].copy_from_slice(&RPC_NONCE_U32.to_le_bytes());
p[4..8].copy_from_slice(&key_selector.to_le_bytes());
p[8..12].copy_from_slice(&RPC_CRYPTO_AES_U32.to_le_bytes());
p[12..16].copy_from_slice(&crypto_ts.to_le_bytes());
p[16..32].copy_from_slice(nonce);
p
}
pub(crate) fn parse_nonce_payload(d: &[u8]) -> Result<(u32, u32, u32, [u8; 16])> {
if d.len() < 32 {
return Err(ProxyError::InvalidHandshake(format!(
"Nonce payload too short: {} bytes",
d.len()
)));
}
let t = u32::from_le_bytes(d[0..4].try_into().unwrap());
if t != RPC_NONCE_U32 {
return Err(ProxyError::InvalidHandshake(format!(
"Expected RPC_NONCE 0x{RPC_NONCE_U32:08x}, got 0x{t:08x}"
)));
}
let key_select = u32::from_le_bytes(d[4..8].try_into().unwrap());
let schema = u32::from_le_bytes(d[8..12].try_into().unwrap());
let ts = u32::from_le_bytes(d[12..16].try_into().unwrap());
let mut nonce = [0u8; 16];
nonce.copy_from_slice(&d[16..32]);
Ok((key_select, schema, ts, nonce))
}
pub(crate) fn build_handshake_payload(
our_ip: [u8; 4],
our_port: u16,
peer_ip: [u8; 4],
peer_port: u16,
) -> [u8; 32] {
let mut p = [0u8; 32];
p[0..4].copy_from_slice(&RPC_HANDSHAKE_U32.to_le_bytes());
// Keep C memory layout compatibility for PID IPv4 bytes.
p[8..12].copy_from_slice(&our_ip);
p[12..14].copy_from_slice(&our_port.to_le_bytes());
let pid = (std::process::id() & 0xffff) as u16;
p[14..16].copy_from_slice(&pid.to_le_bytes());
let utime = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as u32;
p[16..20].copy_from_slice(&utime.to_le_bytes());
p[20..24].copy_from_slice(&peer_ip);
p[24..26].copy_from_slice(&peer_port.to_le_bytes());
p
}
pub(crate) fn cbc_encrypt_padded(
key: &[u8; 32],
iv: &[u8; 16],
plaintext: &[u8],
) -> Result<(Vec<u8>, [u8; 16])> {
let pad = (16 - (plaintext.len() % 16)) % 16;
let mut buf = plaintext.to_vec();
let pad_pattern: [u8; 4] = [0x04, 0x00, 0x00, 0x00];
for i in 0..pad {
buf.push(pad_pattern[i % 4]);
}
let cipher = AesCbc::new(*key, *iv);
cipher
.encrypt_in_place(&mut buf)
.map_err(|e| ProxyError::Crypto(format!("CBC encrypt: {e}")))?;
let mut new_iv = [0u8; 16];
if buf.len() >= 16 {
new_iv.copy_from_slice(&buf[buf.len() - 16..]);
}
Ok((buf, new_iv))
}
pub(crate) fn cbc_decrypt_inplace(
key: &[u8; 32],
iv: &[u8; 16],
data: &mut [u8],
) -> Result<[u8; 16]> {
let mut new_iv = [0u8; 16];
if data.len() >= 16 {
new_iv.copy_from_slice(&data[data.len() - 16..]);
}
AesCbc::new(*key, *iv)
.decrypt_in_place(data)
.map_err(|e| ProxyError::Crypto(format!("CBC decrypt: {e}")))?;
Ok(new_iv)
}
pub(crate) struct RpcWriter {
pub(crate) writer: tokio::io::WriteHalf<tokio::net::TcpStream>,
pub(crate) key: [u8; 32],
pub(crate) iv: [u8; 16],
pub(crate) seq_no: i32,
}
impl RpcWriter {
pub(crate) async fn send(&mut self, payload: &[u8]) -> Result<()> {
let frame = build_rpc_frame(self.seq_no, payload);
self.seq_no += 1;
let pad = (16 - (frame.len() % 16)) % 16;
let mut buf = frame;
let pad_pattern: [u8; 4] = [0x04, 0x00, 0x00, 0x00];
for i in 0..pad {
buf.push(pad_pattern[i % 4]);
}
let cipher = AesCbc::new(self.key, self.iv);
cipher
.encrypt_in_place(&mut buf)
.map_err(|e| ProxyError::Crypto(format!("{e}")))?;
if buf.len() >= 16 {
self.iv.copy_from_slice(&buf[buf.len() - 16..]);
}
self.writer.write_all(&buf).await.map_err(ProxyError::Io)
}
pub(crate) async fn send_and_flush(&mut self, payload: &[u8]) -> Result<()> {
self.send(payload).await?;
self.writer.flush().await.map_err(ProxyError::Io)
}
}

View File

@@ -0,0 +1,183 @@
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
use httpdate;
use tracing::{debug, info, warn};
use crate::error::Result;
use super::MePool;
use super::secret::download_proxy_secret;
use crate::crypto::SecureRandom;
use std::time::SystemTime;
#[derive(Debug, Clone, Default)]
pub struct ProxyConfigData {
pub map: HashMap<i32, Vec<(IpAddr, u16)>>,
pub default_dc: Option<i32>,
}
fn parse_host_port(s: &str) -> Option<(IpAddr, u16)> {
if let Some(bracket_end) = s.rfind(']') {
if s.starts_with('[') && bracket_end + 1 < s.len() && s.as_bytes().get(bracket_end + 1) == Some(&b':') {
let host = &s[1..bracket_end];
let port_str = &s[bracket_end + 2..];
let ip = host.parse::<IpAddr>().ok()?;
let port = port_str.parse::<u16>().ok()?;
return Some((ip, port));
}
}
let idx = s.rfind(':')?;
let host = &s[..idx];
let port_str = &s[idx + 1..];
let ip = host.parse::<IpAddr>().ok()?;
let port = port_str.parse::<u16>().ok()?;
Some((ip, port))
}
fn parse_proxy_line(line: &str) -> Option<(i32, IpAddr, u16)> {
// Accepts lines like:
// proxy_for 4 91.108.4.195:8888;
// proxy_for 2 [2001:67c:04e8:f002::d]:80;
// proxy_for 2 2001:67c:04e8:f002::d:80;
let trimmed = line.trim();
if !trimmed.starts_with("proxy_for") {
return None;
}
// Capture everything between dc and trailing ';'
let without_prefix = trimmed.trim_start_matches("proxy_for").trim();
let mut parts = without_prefix.split_whitespace();
let dc_str = parts.next()?;
let rest = parts.next()?;
let host_port = rest.trim_end_matches(';');
let dc = dc_str.parse::<i32>().ok()?;
let (ip, port) = parse_host_port(host_port)?;
Some((dc, ip, port))
}
pub async fn fetch_proxy_config(url: &str) -> Result<ProxyConfigData> {
let resp = reqwest::get(url)
.await
.map_err(|e| crate::error::ProxyError::Proxy(format!("fetch_proxy_config GET failed: {e}")))?
;
if let Some(date) = resp.headers().get(reqwest::header::DATE) {
if let Ok(date_str) = date.to_str() {
if let Ok(server_time) = httpdate::parse_http_date(date_str) {
if let Ok(skew) = SystemTime::now().duration_since(server_time).or_else(|e| {
server_time.duration_since(SystemTime::now()).map_err(|_| e)
}) {
let skew_secs = skew.as_secs();
if skew_secs > 60 {
warn!(skew_secs, "Time skew >60s detected from fetch_proxy_config Date header");
} else if skew_secs > 30 {
warn!(skew_secs, "Time skew >30s detected from fetch_proxy_config Date header");
}
}
}
}
}
let text = resp
.text()
.await
.map_err(|e| crate::error::ProxyError::Proxy(format!("fetch_proxy_config read failed: {e}")))?;
let mut map: HashMap<i32, Vec<(IpAddr, u16)>> = HashMap::new();
for line in text.lines() {
if let Some((dc, ip, port)) = parse_proxy_line(line) {
map.entry(dc).or_default().push((ip, port));
}
}
let default_dc = text
.lines()
.find_map(|l| {
let t = l.trim();
if let Some(rest) = t.strip_prefix("default") {
return rest
.trim()
.trim_end_matches(';')
.parse::<i32>()
.ok();
}
None
});
Ok(ProxyConfigData { map, default_dc })
}
pub async fn me_config_updater(pool: Arc<MePool>, rng: Arc<SecureRandom>, interval: Duration) {
let mut tick = tokio::time::interval(interval);
// skip immediate tick to avoid double-fetch right after startup
tick.tick().await;
loop {
tick.tick().await;
// Update proxy config v4
if let Ok(cfg) = fetch_proxy_config("https://core.telegram.org/getProxyConfig").await {
let changed = pool.update_proxy_maps(cfg.map.clone(), None).await;
if let Some(dc) = cfg.default_dc {
pool.default_dc.store(dc, std::sync::atomic::Ordering::Relaxed);
}
if changed {
info!("ME config updated (v4), reconciling connections");
pool.reconcile_connections(&rng).await;
} else {
debug!("ME config v4 unchanged");
}
} else {
warn!("getProxyConfig update failed");
}
// Update proxy config v6 (optional)
if let Ok(cfg_v6) = fetch_proxy_config("https://core.telegram.org/getProxyConfigV6").await {
let _ = pool.update_proxy_maps(HashMap::new(), Some(cfg_v6.map)).await;
}
// Update proxy-secret
match download_proxy_secret().await {
Ok(secret) => {
if pool.update_secret(secret).await {
info!("proxy-secret updated and pool reconnect scheduled");
}
}
Err(e) => warn!(error = %e, "proxy-secret update failed"),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_ipv6_bracketed() {
let line = "proxy_for 2 [2001:67c:04e8:f002::d]:80;";
let res = parse_proxy_line(line).unwrap();
assert_eq!(res.0, 2);
assert_eq!(res.1, "2001:67c:04e8:f002::d".parse::<IpAddr>().unwrap());
assert_eq!(res.2, 80);
}
#[test]
fn parse_ipv6_plain() {
let line = "proxy_for 2 2001:67c:04e8:f002::d:80;";
let res = parse_proxy_line(line).unwrap();
assert_eq!(res.0, 2);
assert_eq!(res.1, "2001:67c:04e8:f002::d".parse::<IpAddr>().unwrap());
assert_eq!(res.2, 80);
}
#[test]
fn parse_ipv4() {
let line = "proxy_for 4 91.108.4.195:8888;";
let res = parse_proxy_line(line).unwrap();
assert_eq!(res.0, 4);
assert_eq!(res.1, "91.108.4.195".parse::<IpAddr>().unwrap());
assert_eq!(res.2, 8888);
}
}

View File

@@ -0,0 +1,439 @@
use std::net::{IpAddr, SocketAddr};
use std::time::{Duration, Instant};
use socket2::{SockRef, TcpKeepalive};
#[cfg(target_os = "linux")]
use libc;
#[cfg(target_os = "linux")]
use std::os::fd::{AsRawFd, RawFd};
#[cfg(target_os = "linux")]
use std::os::raw::c_int;
use bytes::BytesMut;
use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf};
use tokio::net::{TcpStream, TcpSocket};
use tokio::time::timeout;
use tracing::{debug, info, warn};
use crate::crypto::{SecureRandom, build_middleproxy_prekey, derive_middleproxy_keys, sha256};
use crate::error::{ProxyError, Result};
use crate::network::IpFamily;
use crate::protocol::constants::{
ME_CONNECT_TIMEOUT_SECS, ME_HANDSHAKE_TIMEOUT_SECS, RPC_CRYPTO_AES_U32, RPC_HANDSHAKE_ERROR_U32,
RPC_HANDSHAKE_U32, RPC_PING_U32, RPC_PONG_U32, RPC_NONCE_U32,
};
use super::codec::{
build_handshake_payload, build_nonce_payload, build_rpc_frame, cbc_decrypt_inplace,
cbc_encrypt_padded, parse_nonce_payload, read_rpc_frame_plaintext,
};
use super::wire::{extract_ip_material, IpMaterial};
use super::MePool;
/// Result of a successful ME handshake with timings.
pub(crate) struct HandshakeOutput {
pub rd: ReadHalf<TcpStream>,
pub wr: WriteHalf<TcpStream>,
pub read_key: [u8; 32],
pub read_iv: [u8; 16],
pub write_key: [u8; 32],
pub write_iv: [u8; 16],
pub handshake_ms: f64,
}
impl MePool {
/// TCP connect with timeout + return RTT in milliseconds.
pub(crate) async fn connect_tcp(&self, addr: SocketAddr) -> Result<(TcpStream, f64)> {
let start = Instant::now();
let connect_fut = async {
if addr.is_ipv6() {
if let Some(v6) = self.detected_ipv6 {
match TcpSocket::new_v6() {
Ok(sock) => {
if let Err(e) = sock.bind(SocketAddr::new(IpAddr::V6(v6), 0)) {
debug!(error = %e, bind_ip = %v6, "ME IPv6 bind failed, falling back to default bind");
} else {
match sock.connect(addr).await {
Ok(stream) => return Ok(stream),
Err(e) => debug!(error = %e, target = %addr, "ME IPv6 bound connect failed, retrying default connect"),
}
}
}
Err(e) => debug!(error = %e, "ME IPv6 socket creation failed, falling back to default connect"),
}
}
}
TcpStream::connect(addr).await
};
let stream = timeout(Duration::from_secs(ME_CONNECT_TIMEOUT_SECS), connect_fut)
.await
.map_err(|_| ProxyError::ConnectionTimeout { addr: addr.to_string() })??;
let connect_ms = start.elapsed().as_secs_f64() * 1000.0;
stream.set_nodelay(true).ok();
if let Err(e) = Self::configure_keepalive(&stream) {
warn!(error = %e, "ME keepalive setup failed");
}
#[cfg(target_os = "linux")]
if let Err(e) = Self::configure_user_timeout(stream.as_raw_fd()) {
warn!(error = %e, "ME TCP_USER_TIMEOUT setup failed");
}
Ok((stream, connect_ms))
}
fn configure_keepalive(stream: &TcpStream) -> std::io::Result<()> {
let sock = SockRef::from(stream);
let ka = TcpKeepalive::new()
.with_time(Duration::from_secs(30))
.with_interval(Duration::from_secs(10))
.with_retries(3);
sock.set_tcp_keepalive(&ka)?;
sock.set_keepalive(true)?;
Ok(())
}
#[cfg(target_os = "linux")]
fn configure_user_timeout(fd: RawFd) -> std::io::Result<()> {
let timeout_ms: c_int = 30_000;
let rc = unsafe {
libc::setsockopt(
fd,
libc::IPPROTO_TCP,
libc::TCP_USER_TIMEOUT,
&timeout_ms as *const _ as *const libc::c_void,
std::mem::size_of_val(&timeout_ms) as libc::socklen_t,
)
};
if rc != 0 {
return Err(std::io::Error::last_os_error());
}
Ok(())
}
/// Perform full ME RPC handshake on an established TCP stream.
/// Returns cipher keys/ivs and split halves; does not register writer.
pub(crate) async fn handshake_only(
&self,
stream: TcpStream,
addr: SocketAddr,
rng: &SecureRandom,
) -> Result<HandshakeOutput> {
let hs_start = Instant::now();
let local_addr = stream.local_addr().map_err(ProxyError::Io)?;
let peer_addr = stream.peer_addr().map_err(ProxyError::Io)?;
let _ = self.maybe_detect_nat_ip(local_addr.ip()).await;
let family = if local_addr.ip().is_ipv4() {
IpFamily::V4
} else {
IpFamily::V6
};
let reflected = if self.nat_probe {
self.maybe_reflect_public_addr(family).await
} else {
None
};
let local_addr_nat = self.translate_our_addr_with_reflection(local_addr, reflected);
let peer_addr_nat = SocketAddr::new(self.translate_ip_for_nat(peer_addr.ip()), peer_addr.port());
let (mut rd, mut wr) = tokio::io::split(stream);
let my_nonce: [u8; 16] = rng.bytes(16).try_into().unwrap();
let crypto_ts = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as u32;
let ks = self.key_selector().await;
let nonce_payload = build_nonce_payload(ks, crypto_ts, &my_nonce);
let nonce_frame = build_rpc_frame(-2, &nonce_payload);
let dump = hex_dump(&nonce_frame[..nonce_frame.len().min(44)]);
debug!(
key_selector = format_args!("0x{ks:08x}"),
crypto_ts,
frame_len = nonce_frame.len(),
nonce_frame_hex = %dump,
"Sending ME nonce frame"
);
wr.write_all(&nonce_frame).await.map_err(ProxyError::Io)?;
wr.flush().await.map_err(ProxyError::Io)?;
let (srv_seq, srv_nonce_payload) = timeout(
Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS),
read_rpc_frame_plaintext(&mut rd),
)
.await
.map_err(|_| ProxyError::TgHandshakeTimeout)??;
if srv_seq != -2 {
return Err(ProxyError::InvalidHandshake(format!("Expected seq=-2, got {srv_seq}")));
}
let (srv_key_select, schema, srv_ts, srv_nonce) = parse_nonce_payload(&srv_nonce_payload)?;
if schema != RPC_CRYPTO_AES_U32 {
warn!(schema = format_args!("0x{schema:08x}"), "Unsupported ME crypto schema");
return Err(ProxyError::InvalidHandshake(format!(
"Unsupported crypto schema: 0x{schema:x}"
)));
}
if srv_key_select != ks {
return Err(ProxyError::InvalidHandshake(format!(
"Server key_select 0x{srv_key_select:08x} != client 0x{ks:08x}"
)));
}
let skew = crypto_ts.abs_diff(srv_ts);
if skew > 30 {
return Err(ProxyError::InvalidHandshake(format!(
"nonce crypto_ts skew too large: client={crypto_ts}, server={srv_ts}, skew={skew}s"
)));
}
info!(
%local_addr,
%local_addr_nat,
reflected_ip = reflected.map(|r| r.ip()).as_ref().map(ToString::to_string),
%peer_addr,
%peer_addr_nat,
key_selector = format_args!("0x{ks:08x}"),
crypto_schema = format_args!("0x{schema:08x}"),
skew_secs = skew,
"ME key derivation parameters"
);
let ts_bytes = crypto_ts.to_le_bytes();
let server_port_bytes = peer_addr_nat.port().to_le_bytes();
let client_port_bytes = local_addr_nat.port().to_le_bytes();
let server_ip = extract_ip_material(peer_addr_nat);
let client_ip = extract_ip_material(local_addr_nat);
let (srv_ip_opt, clt_ip_opt, clt_v6_opt, srv_v6_opt, hs_our_ip, hs_peer_ip) = match (server_ip, client_ip) {
(IpMaterial::V4(mut srv), IpMaterial::V4(mut clt)) => {
srv.reverse();
clt.reverse();
(Some(srv), Some(clt), None, None, clt, srv)
}
(IpMaterial::V6(srv), IpMaterial::V6(clt)) => {
let zero = [0u8; 4];
(None, None, Some(clt), Some(srv), zero, zero)
}
_ => {
return Err(ProxyError::InvalidHandshake(
"mixed IPv4/IPv6 endpoints are not supported for ME key derivation".to_string(),
));
}
};
let diag_level: u8 = std::env::var("ME_DIAG").ok().and_then(|v| v.parse().ok()).unwrap_or(0);
let secret: Vec<u8> = self.proxy_secret.read().await.clone();
let prekey_client = build_middleproxy_prekey(
&srv_nonce,
&my_nonce,
&ts_bytes,
srv_ip_opt.as_ref().map(|x| &x[..]),
&client_port_bytes,
b"CLIENT",
clt_ip_opt.as_ref().map(|x| &x[..]),
&server_port_bytes,
&secret,
clt_v6_opt.as_ref(),
srv_v6_opt.as_ref(),
);
let prekey_server = build_middleproxy_prekey(
&srv_nonce,
&my_nonce,
&ts_bytes,
srv_ip_opt.as_ref().map(|x| &x[..]),
&client_port_bytes,
b"SERVER",
clt_ip_opt.as_ref().map(|x| &x[..]),
&server_port_bytes,
&secret,
clt_v6_opt.as_ref(),
srv_v6_opt.as_ref(),
);
let (wk, wi) = derive_middleproxy_keys(
&srv_nonce,
&my_nonce,
&ts_bytes,
srv_ip_opt.as_ref().map(|x| &x[..]),
&client_port_bytes,
b"CLIENT",
clt_ip_opt.as_ref().map(|x| &x[..]),
&server_port_bytes,
&secret,
clt_v6_opt.as_ref(),
srv_v6_opt.as_ref(),
);
let (rk, ri) = derive_middleproxy_keys(
&srv_nonce,
&my_nonce,
&ts_bytes,
srv_ip_opt.as_ref().map(|x| &x[..]),
&client_port_bytes,
b"SERVER",
clt_ip_opt.as_ref().map(|x| &x[..]),
&server_port_bytes,
&secret,
clt_v6_opt.as_ref(),
srv_v6_opt.as_ref(),
);
let hs_payload = build_handshake_payload(hs_our_ip, local_addr.port(), hs_peer_ip, peer_addr.port());
let hs_frame = build_rpc_frame(-1, &hs_payload);
if diag_level >= 1 {
info!(
write_key = %hex_dump(&wk),
write_iv = %hex_dump(&wi),
read_key = %hex_dump(&rk),
read_iv = %hex_dump(&ri),
srv_ip = %srv_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(),
clt_ip = %clt_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(),
srv_port = %hex_dump(&server_port_bytes),
clt_port = %hex_dump(&client_port_bytes),
crypto_ts = %hex_dump(&ts_bytes),
nonce_srv = %hex_dump(&srv_nonce),
nonce_clt = %hex_dump(&my_nonce),
prekey_sha256_client = %hex_dump(&sha256(&prekey_client)),
prekey_sha256_server = %hex_dump(&sha256(&prekey_server)),
hs_plain = %hex_dump(&hs_frame),
proxy_secret_sha256 = %hex_dump(&sha256(&secret)),
"ME diag: derived keys and handshake plaintext"
);
}
if diag_level >= 2 {
info!(
prekey_client = %hex_dump(&prekey_client),
prekey_server = %hex_dump(&prekey_server),
"ME diag: full prekey buffers"
);
}
let (encrypted_hs, mut write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?;
if diag_level >= 1 {
info!(
hs_cipher = %hex_dump(&encrypted_hs),
"ME diag: handshake ciphertext"
);
}
wr.write_all(&encrypted_hs).await.map_err(ProxyError::Io)?;
wr.flush().await.map_err(ProxyError::Io)?;
let deadline = Instant::now() + Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS);
let mut enc_buf = BytesMut::with_capacity(256);
let mut dec_buf = BytesMut::with_capacity(256);
let mut read_iv = ri;
let mut handshake_ok = false;
while Instant::now() < deadline && !handshake_ok {
let remaining = deadline - Instant::now();
let mut tmp = [0u8; 256];
let n = match timeout(remaining, rd.read(&mut tmp)).await {
Ok(Ok(0)) => {
return Err(ProxyError::Io(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"ME closed during handshake",
)));
}
Ok(Ok(n)) => n,
Ok(Err(e)) => return Err(ProxyError::Io(e)),
Err(_) => return Err(ProxyError::TgHandshakeTimeout),
};
enc_buf.extend_from_slice(&tmp[..n]);
let blocks = enc_buf.len() / 16 * 16;
if blocks > 0 {
let mut chunk = vec![0u8; blocks];
chunk.copy_from_slice(&enc_buf[..blocks]);
read_iv = cbc_decrypt_inplace(&rk, &read_iv, &mut chunk)?;
dec_buf.extend_from_slice(&chunk);
let _ = enc_buf.split_to(blocks);
}
while dec_buf.len() >= 4 {
let fl = u32::from_le_bytes(dec_buf[0..4].try_into().unwrap()) as usize;
if fl == 4 {
let _ = dec_buf.split_to(4);
continue;
}
if !(12..=(1 << 24)).contains(&fl) {
return Err(ProxyError::InvalidHandshake(format!(
"Bad HS response frame len: {fl}"
)));
}
if dec_buf.len() < fl {
break;
}
let frame = dec_buf.split_to(fl);
let pe = fl - 4;
let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap());
let ac = crate::crypto::crc32(&frame[..pe]);
if ec != ac {
return Err(ProxyError::InvalidHandshake(format!(
"HS CRC mismatch: 0x{ec:08x} vs 0x{ac:08x}"
)));
}
let hs_type = u32::from_le_bytes(frame[8..12].try_into().unwrap());
if hs_type == RPC_HANDSHAKE_ERROR_U32 {
let err_code = if frame.len() >= 16 {
i32::from_le_bytes(frame[12..16].try_into().unwrap())
} else {
-1
};
return Err(ProxyError::InvalidHandshake(format!(
"ME rejected handshake (error={err_code})"
)));
}
if hs_type != RPC_HANDSHAKE_U32 {
return Err(ProxyError::InvalidHandshake(format!(
"Expected HANDSHAKE 0x{RPC_HANDSHAKE_U32:08x}, got 0x{hs_type:08x}"
)));
}
handshake_ok = true;
break;
}
}
if !handshake_ok {
return Err(ProxyError::TgHandshakeTimeout);
}
let handshake_ms = hs_start.elapsed().as_secs_f64() * 1000.0;
info!(%addr, "RPC handshake OK");
Ok(HandshakeOutput {
rd,
wr,
read_key: rk,
read_iv,
write_key: wk,
write_iv,
handshake_ms,
})
}
}
fn hex_dump(data: &[u8]) -> String {
const MAX: usize = 64;
let mut out = String::with_capacity(data.len() * 2 + 3);
for (i, b) in data.iter().take(MAX).enumerate() {
if i > 0 {
out.push(' ');
}
out.push_str(&format!("{b:02x}"));
}
if data.len() > MAX {
out.push_str("");
}
out
}

View File

@@ -0,0 +1,174 @@
use std::collections::{HashMap, HashSet};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tracing::{debug, info, warn};
use rand::seq::SliceRandom;
use crate::crypto::SecureRandom;
use crate::network::IpFamily;
use super::MePool;
pub async fn me_health_monitor(pool: Arc<MePool>, rng: Arc<SecureRandom>, _min_connections: usize) {
let mut backoff: HashMap<(i32, IpFamily), u64> = HashMap::new();
let mut last_attempt: HashMap<(i32, IpFamily), Instant> = HashMap::new();
let mut inflight_single: HashSet<(i32, IpFamily)> = HashSet::new();
loop {
tokio::time::sleep(Duration::from_secs(30)).await;
check_family(
IpFamily::V4,
&pool,
&rng,
&mut backoff,
&mut last_attempt,
&mut inflight_single,
)
.await;
check_family(
IpFamily::V6,
&pool,
&rng,
&mut backoff,
&mut last_attempt,
&mut inflight_single,
)
.await;
}
}
async fn check_family(
family: IpFamily,
pool: &Arc<MePool>,
rng: &Arc<SecureRandom>,
backoff: &mut HashMap<(i32, IpFamily), u64>,
last_attempt: &mut HashMap<(i32, IpFamily), Instant>,
inflight_single: &mut HashSet<(i32, IpFamily)>,
) {
let enabled = match family {
IpFamily::V4 => pool.decision.ipv4_me,
IpFamily::V6 => pool.decision.ipv6_me,
};
if !enabled {
return;
}
let map = match family {
IpFamily::V4 => pool.proxy_map_v4.read().await.clone(),
IpFamily::V6 => pool.proxy_map_v6.read().await.clone(),
};
let writer_addrs: HashSet<SocketAddr> = pool
.writers
.read()
.await
.iter()
.map(|w| w.addr)
.collect();
let entries: Vec<(i32, Vec<SocketAddr>)> = map
.iter()
.map(|(dc, addrs)| {
let list = addrs
.iter()
.map(|(ip, port)| SocketAddr::new(*ip, *port))
.collect::<Vec<_>>();
(*dc, list)
})
.collect();
for (dc, dc_addrs) in entries {
let has_coverage = dc_addrs.iter().any(|a| writer_addrs.contains(a));
if has_coverage {
inflight_single.remove(&(dc, family));
continue;
}
let key = (dc, family);
let delay = *backoff.get(&key).unwrap_or(&30);
let now = Instant::now();
if let Some(last) = last_attempt.get(&key) {
if now.duration_since(*last).as_secs() < delay {
continue;
}
}
if dc_addrs.len() == 1 {
// Single ME address: fast retries then slower background retries.
if inflight_single.contains(&key) {
continue;
}
inflight_single.insert(key);
let addr = dc_addrs[0];
let dc_id = dc;
let pool_clone = pool.clone();
let rng_clone = rng.clone();
let timeout = pool.me_one_timeout;
let quick_attempts = pool.me_one_retry.max(1);
tokio::spawn(async move {
let mut success = false;
for _ in 0..quick_attempts {
let res = tokio::time::timeout(timeout, pool_clone.connect_one(addr, rng_clone.as_ref())).await;
match res {
Ok(Ok(())) => {
info!(%addr, dc = %dc_id, ?family, "ME reconnected for DC coverage");
success = true;
break;
}
Ok(Err(e)) => debug!(%addr, dc = %dc_id, error = %e, ?family, "ME reconnect failed"),
Err(_) => debug!(%addr, dc = %dc_id, ?family, "ME reconnect timed out"),
}
tokio::time::sleep(Duration::from_millis(1000)).await;
}
if success {
return;
}
let timeout_ms = timeout.as_millis();
warn!(
dc = %dc_id,
?family,
attempts = quick_attempts,
timeout_ms,
"DC={} has no ME coverage: {} tries * {} ms... retry in 5 seconds...",
dc_id,
quick_attempts,
timeout_ms
);
loop {
tokio::time::sleep(Duration::from_secs(5)).await;
let res = tokio::time::timeout(timeout, pool_clone.connect_one(addr, rng_clone.as_ref())).await;
match res {
Ok(Ok(())) => {
info!(%addr, dc = %dc_id, ?family, "ME reconnected for DC coverage");
break;
}
Ok(Err(e)) => debug!(%addr, dc = %dc_id, error = %e, ?family, "ME reconnect failed"),
Err(_) => debug!(%addr, dc = %dc_id, ?family, "ME reconnect timed out"),
}
}
// will drop inflight flag in outer loop when coverage detected
});
continue;
}
warn!(dc = %dc, delay, ?family, "DC has no ME coverage, reconnecting...");
let mut shuffled = dc_addrs.clone();
shuffled.shuffle(&mut rand::rng());
let mut reconnected = false;
for addr in shuffled {
match pool.connect_one(addr, rng.as_ref()).await {
Ok(()) => {
info!(%addr, dc = %dc, ?family, "ME reconnected for DC coverage");
backoff.insert(key, 30);
last_attempt.insert(key, now);
reconnected = true;
break;
}
Err(e) => debug!(%addr, dc = %dc, error = %e, ?family, "ME reconnect failed"),
}
}
if !reconnected {
let next = (*backoff.get(&key).unwrap_or(&30)).saturating_mul(2).min(300);
backoff.insert(key, next);
last_attempt.insert(key, now);
}
}
}

View File

@@ -0,0 +1,34 @@
//! Middle Proxy RPC transport.
mod codec;
mod handshake;
mod health;
mod pool;
mod pool_nat;
mod ping;
mod reader;
mod registry;
mod send;
mod secret;
mod rotation;
mod config_updater;
mod wire;
use bytes::Bytes;
pub use health::me_health_monitor;
pub use ping::{run_me_ping, format_sample_line, MePingReport, MePingSample, MePingFamily};
pub use pool::MePool;
pub use pool_nat::{stun_probe, detect_public_ip};
pub use registry::ConnRegistry;
pub use secret::fetch_proxy_secret;
pub use config_updater::{fetch_proxy_config, me_config_updater};
pub use rotation::me_rotation_task;
pub use wire::proto_flags_for_tag;
#[derive(Debug)]
pub enum MeResponse {
Data { flags: u32, data: Bytes },
Ack(u32),
Close,
}

View File

@@ -0,0 +1,173 @@
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use crate::crypto::SecureRandom;
use crate::error::ProxyError;
use super::MePool;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MePingFamily {
V4,
V6,
}
#[derive(Debug, Clone)]
pub struct MePingSample {
pub dc: i32,
pub addr: SocketAddr,
pub connect_ms: Option<f64>,
pub handshake_ms: Option<f64>,
pub error: Option<String>,
pub family: MePingFamily,
}
#[derive(Debug, Clone)]
pub struct MePingReport {
pub dc: i32,
pub family: MePingFamily,
pub samples: Vec<MePingSample>,
}
pub fn format_sample_line(sample: &MePingSample) -> String {
let sign = if sample.dc >= 0 { "+" } else { "-" };
let addr = format!("{}:{}", sample.addr.ip(), sample.addr.port());
match (sample.connect_ms, sample.handshake_ms.as_ref(), sample.error.as_ref()) {
(Some(conn), Some(hs), None) => format!(
" {sign} {addr}\tPing: {:.0} ms / RPC: {:.0} ms / OK",
conn, hs
),
(Some(conn), None, Some(err)) => format!(
" {sign} {addr}\tPing: {:.0} ms / RPC: FAIL ({err})",
conn
),
(None, _, Some(err)) => format!(" {sign} {addr}\tPing: FAIL ({err})"),
(Some(conn), None, None) => format!(" {sign} {addr}\tPing: {:.0} ms / RPC: FAIL", conn),
_ => format!(" {sign} {addr}\tPing: FAIL"),
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
fn sample(base: MePingSample) -> MePingSample {
base
}
#[test]
fn ok_line_contains_both_timings() {
let s = sample(MePingSample {
dc: 4,
addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), 8888),
connect_ms: Some(12.3),
handshake_ms: Some(34.7),
error: None,
family: MePingFamily::V4,
});
let line = format_sample_line(&s);
assert!(line.contains("Ping: 12 ms"));
assert!(line.contains("RPC: 35 ms"));
assert!(line.contains("OK"));
}
#[test]
fn error_line_mentions_reason() {
let s = sample(MePingSample {
dc: -5,
addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(5, 6, 7, 8)), 80),
connect_ms: Some(10.0),
handshake_ms: None,
error: Some("handshake timeout".to_string()),
family: MePingFamily::V4,
});
let line = format_sample_line(&s);
assert!(line.contains("- 5.6.7.8:80"));
assert!(line.contains("handshake timeout"));
}
}
pub async fn run_me_ping(pool: &Arc<MePool>, rng: &SecureRandom) -> Vec<MePingReport> {
let mut reports = Vec::new();
let v4_map = if pool.decision.ipv4_me {
pool.proxy_map_v4.read().await.clone()
} else {
HashMap::new()
};
let v6_map = if pool.decision.ipv6_me {
pool.proxy_map_v6.read().await.clone()
} else {
HashMap::new()
};
let mut grouped: Vec<(MePingFamily, i32, Vec<(IpAddr, u16)>)> = Vec::new();
for (dc, addrs) in v4_map {
grouped.push((MePingFamily::V4, dc, addrs));
}
for (dc, addrs) in v6_map {
grouped.push((MePingFamily::V6, dc, addrs));
}
for (family, dc, addrs) in grouped {
let mut samples = Vec::new();
for (ip, port) in addrs {
let addr = SocketAddr::new(ip, port);
let mut connect_ms = None;
let mut handshake_ms = None;
let mut error = None;
match pool.connect_tcp(addr).await {
Ok((stream, conn_rtt)) => {
connect_ms = Some(conn_rtt);
match pool.handshake_only(stream, addr, rng).await {
Ok(hs) => {
handshake_ms = Some(hs.handshake_ms);
// drop halves to close
drop(hs.rd);
drop(hs.wr);
}
Err(e) => {
error = Some(short_err(&e));
}
}
}
Err(e) => {
error = Some(short_err(&e));
}
}
samples.push(MePingSample {
dc,
addr,
connect_ms,
handshake_ms,
error,
family,
});
}
reports.push(MePingReport {
dc,
family,
samples,
});
}
reports
}
fn short_err(err: &ProxyError) -> String {
match err {
ProxyError::ConnectionTimeout { .. } => "connect timeout".to_string(),
ProxyError::TgHandshakeTimeout => "handshake timeout".to_string(),
ProxyError::InvalidHandshake(e) => format!("bad handshake: {e}"),
ProxyError::Crypto(e) => format!("crypto: {e}"),
ProxyError::Proxy(e) => format!("proxy: {e}"),
ProxyError::Io(e) => format!("io: {e}"),
_ => format!("{err}"),
}
}

View File

@@ -0,0 +1,491 @@
use std::collections::HashMap;
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU64, Ordering};
use bytes::BytesMut;
use rand::Rng;
use rand::seq::SliceRandom;
use tokio::sync::{Mutex, RwLock};
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, warn};
use std::time::Duration;
use crate::crypto::SecureRandom;
use crate::error::{ProxyError, Result};
use crate::network::probe::NetworkDecision;
use crate::network::IpFamily;
use crate::protocol::constants::*;
use super::ConnRegistry;
use super::registry::{BoundConn, ConnMeta};
use super::codec::RpcWriter;
use super::reader::reader_loop;
use super::MeResponse;
const ME_ACTIVE_PING_SECS: u64 = 25;
const ME_ACTIVE_PING_JITTER_SECS: i64 = 5;
#[derive(Clone)]
pub struct MeWriter {
pub id: u64,
pub addr: SocketAddr,
pub writer: Arc<Mutex<RpcWriter>>,
pub cancel: CancellationToken,
pub degraded: Arc<AtomicBool>,
pub draining: Arc<AtomicBool>,
}
pub struct MePool {
pub(super) registry: Arc<ConnRegistry>,
pub(super) writers: Arc<RwLock<Vec<MeWriter>>>,
pub(super) rr: AtomicU64,
pub(super) decision: NetworkDecision,
pub(super) rng: Arc<SecureRandom>,
pub(super) proxy_tag: Option<Vec<u8>>,
pub(super) proxy_secret: Arc<RwLock<Vec<u8>>>,
pub(super) nat_ip_cfg: Option<IpAddr>,
pub(super) nat_ip_detected: Arc<RwLock<Option<IpAddr>>>,
pub(super) nat_probe: bool,
pub(super) nat_stun: Option<String>,
pub(super) detected_ipv6: Option<Ipv6Addr>,
pub(super) nat_probe_attempts: std::sync::atomic::AtomicU8,
pub(super) nat_probe_disabled: std::sync::atomic::AtomicBool,
pub(super) me_one_retry: u8,
pub(super) me_one_timeout: Duration,
pub(super) proxy_map_v4: Arc<RwLock<HashMap<i32, Vec<(IpAddr, u16)>>>>,
pub(super) proxy_map_v6: Arc<RwLock<HashMap<i32, Vec<(IpAddr, u16)>>>>,
pub(super) default_dc: AtomicI32,
pub(super) next_writer_id: AtomicU64,
pub(super) ping_tracker: Arc<Mutex<HashMap<i64, (std::time::Instant, u64)>>>,
pub(super) rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>,
pub(super) nat_reflection_cache: Arc<Mutex<NatReflectionCache>>,
pool_size: usize,
}
#[derive(Debug, Default)]
pub struct NatReflectionCache {
pub v4: Option<(std::time::Instant, std::net::SocketAddr)>,
pub v6: Option<(std::time::Instant, std::net::SocketAddr)>,
}
impl MePool {
pub fn new(
proxy_tag: Option<Vec<u8>>,
proxy_secret: Vec<u8>,
nat_ip: Option<IpAddr>,
nat_probe: bool,
nat_stun: Option<String>,
detected_ipv6: Option<Ipv6Addr>,
me_one_retry: u8,
me_one_timeout_ms: u64,
proxy_map_v4: HashMap<i32, Vec<(IpAddr, u16)>>,
proxy_map_v6: HashMap<i32, Vec<(IpAddr, u16)>>,
default_dc: Option<i32>,
decision: NetworkDecision,
rng: Arc<SecureRandom>,
) -> Arc<Self> {
Arc::new(Self {
registry: Arc::new(ConnRegistry::new()),
writers: Arc::new(RwLock::new(Vec::new())),
rr: AtomicU64::new(0),
decision,
rng,
proxy_tag,
proxy_secret: Arc::new(RwLock::new(proxy_secret)),
nat_ip_cfg: nat_ip,
nat_ip_detected: Arc::new(RwLock::new(None)),
nat_probe,
nat_stun,
detected_ipv6,
nat_probe_attempts: std::sync::atomic::AtomicU8::new(0),
nat_probe_disabled: std::sync::atomic::AtomicBool::new(false),
me_one_retry,
me_one_timeout: Duration::from_millis(me_one_timeout_ms),
pool_size: 2,
proxy_map_v4: Arc::new(RwLock::new(proxy_map_v4)),
proxy_map_v6: Arc::new(RwLock::new(proxy_map_v6)),
default_dc: AtomicI32::new(default_dc.unwrap_or(0)),
next_writer_id: AtomicU64::new(1),
ping_tracker: Arc::new(Mutex::new(HashMap::new())),
rtt_stats: Arc::new(Mutex::new(HashMap::new())),
nat_reflection_cache: Arc::new(Mutex::new(NatReflectionCache::default())),
})
}
pub fn has_proxy_tag(&self) -> bool {
self.proxy_tag.is_some()
}
pub fn translate_our_addr(&self, addr: SocketAddr) -> SocketAddr {
let ip = self.translate_ip_for_nat(addr.ip());
SocketAddr::new(ip, addr.port())
}
pub fn registry(&self) -> &Arc<ConnRegistry> {
&self.registry
}
fn writers_arc(&self) -> Arc<RwLock<Vec<MeWriter>>> {
self.writers.clone()
}
pub async fn reconcile_connections(self: &Arc<Self>, rng: &SecureRandom) {
use std::collections::HashSet;
let writers = self.writers.read().await;
let current: HashSet<SocketAddr> = writers.iter().map(|w| w.addr).collect();
drop(writers);
for family in self.family_order() {
let map = self.proxy_map_for_family(family).await;
for (_dc, addrs) in map.iter() {
let dc_addrs: Vec<SocketAddr> = addrs
.iter()
.map(|(ip, port)| SocketAddr::new(*ip, *port))
.collect();
if !dc_addrs.iter().any(|a| current.contains(a)) {
let mut shuffled = dc_addrs.clone();
shuffled.shuffle(&mut rand::rng());
for addr in shuffled {
if self.connect_one(addr, rng).await.is_ok() {
break;
}
}
}
}
if !self.decision.effective_multipath && !current.is_empty() {
break;
}
}
}
pub async fn update_proxy_maps(
&self,
new_v4: HashMap<i32, Vec<(IpAddr, u16)>>,
new_v6: Option<HashMap<i32, Vec<(IpAddr, u16)>>>,
) -> bool {
let mut changed = false;
{
let mut guard = self.proxy_map_v4.write().await;
if !new_v4.is_empty() && *guard != new_v4 {
*guard = new_v4;
changed = true;
}
}
if let Some(v6) = new_v6 {
let mut guard = self.proxy_map_v6.write().await;
if !v6.is_empty() && *guard != v6 {
*guard = v6;
}
}
changed
}
pub async fn update_secret(&self, new_secret: Vec<u8>) -> bool {
if new_secret.len() < 32 {
warn!(len = new_secret.len(), "proxy-secret update ignored (too short)");
return false;
}
let mut guard = self.proxy_secret.write().await;
if *guard != new_secret {
*guard = new_secret;
drop(guard);
self.reconnect_all().await;
return true;
}
false
}
pub async fn reconnect_all(&self) {
// Graceful: do not drop all at once. New connections will use updated secret.
// Existing writers remain until health monitor replaces them.
// No-op here to avoid total outage.
}
pub(super) async fn key_selector(&self) -> u32 {
let secret = self.proxy_secret.read().await;
if secret.len() >= 4 {
u32::from_le_bytes([secret[0], secret[1], secret[2], secret[3]])
} else {
0
}
}
pub(super) fn family_order(&self) -> Vec<IpFamily> {
let mut order = Vec::new();
if self.decision.prefer_ipv6() {
if self.decision.ipv6_me {
order.push(IpFamily::V6);
}
if self.decision.ipv4_me {
order.push(IpFamily::V4);
}
} else {
if self.decision.ipv4_me {
order.push(IpFamily::V4);
}
if self.decision.ipv6_me {
order.push(IpFamily::V6);
}
}
order
}
async fn proxy_map_for_family(&self, family: IpFamily) -> HashMap<i32, Vec<(IpAddr, u16)>> {
match family {
IpFamily::V4 => self.proxy_map_v4.read().await.clone(),
IpFamily::V6 => self.proxy_map_v6.read().await.clone(),
}
}
pub async fn init(self: &Arc<Self>, pool_size: usize, rng: &Arc<SecureRandom>) -> Result<()> {
let family_order = self.family_order();
let ks = self.key_selector().await;
info!(
me_servers = self.proxy_map_v4.read().await.len(),
pool_size,
key_selector = format_args!("0x{ks:08x}"),
secret_len = self.proxy_secret.read().await.len(),
"Initializing ME pool"
);
for family in family_order {
let map = self.proxy_map_for_family(family).await;
let dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = map
.iter()
.map(|(dc, addrs)| (*dc, addrs.clone()))
.collect();
// Ensure at least one connection per DC; run DCs in parallel.
let mut join = tokio::task::JoinSet::new();
let mut dc_failures = 0usize;
for (dc, addrs) in dc_addrs.iter().cloned() {
if addrs.is_empty() {
continue;
}
let pool = Arc::clone(self);
let rng_clone = Arc::clone(rng);
join.spawn(async move {
pool.connect_primary_for_dc(dc, addrs, rng_clone).await
});
}
while let Some(res) = join.join_next().await {
if let Ok(false) = res {
dc_failures += 1;
}
}
if dc_failures > 2 {
return Err(ProxyError::Proxy("Too many ME DC init failures, falling back to direct".into()));
}
// Additional connections up to pool_size total (round-robin across DCs)
for (dc, addrs) in dc_addrs.iter() {
for (ip, port) in addrs {
if self.connection_count() >= pool_size {
break;
}
let addr = SocketAddr::new(*ip, *port);
if let Err(e) = self.connect_one(addr, rng.as_ref()).await {
debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed");
}
}
if self.connection_count() >= pool_size {
break;
}
}
if !self.decision.effective_multipath && self.connection_count() > 0 {
break;
}
}
if self.writers.read().await.is_empty() {
return Err(ProxyError::Proxy("No ME connections".into()));
}
Ok(())
}
pub(crate) async fn connect_one(self: &Arc<Self>, addr: SocketAddr, rng: &SecureRandom) -> Result<()> {
let secret_len = self.proxy_secret.read().await.len();
if secret_len < 32 {
return Err(ProxyError::Proxy("proxy-secret too short for ME auth".into()));
}
let (stream, _connect_ms) = self.connect_tcp(addr).await?;
let hs = self.handshake_only(stream, addr, rng).await?;
let writer_id = self.next_writer_id.fetch_add(1, Ordering::Relaxed);
let cancel = CancellationToken::new();
let degraded = Arc::new(AtomicBool::new(false));
let draining = Arc::new(AtomicBool::new(false));
let rpc_w = Arc::new(Mutex::new(RpcWriter {
writer: hs.wr,
key: hs.write_key,
iv: hs.write_iv,
seq_no: 0,
}));
let writer = MeWriter {
id: writer_id,
addr,
writer: rpc_w.clone(),
cancel: cancel.clone(),
degraded: degraded.clone(),
draining: draining.clone(),
};
self.writers.write().await.push(writer.clone());
let reg = self.registry.clone();
let writers_arc = self.writers_arc();
let ping_tracker = self.ping_tracker.clone();
let rtt_stats = self.rtt_stats.clone();
let pool = Arc::downgrade(self);
let cancel_ping = cancel.clone();
let rpc_w_ping = rpc_w.clone();
let ping_tracker_ping = ping_tracker.clone();
tokio::spawn(async move {
let cancel_reader = cancel.clone();
let res = reader_loop(
hs.rd,
hs.read_key,
hs.read_iv,
reg.clone(),
BytesMut::new(),
BytesMut::new(),
rpc_w.clone(),
ping_tracker.clone(),
rtt_stats.clone(),
writer_id,
degraded.clone(),
cancel_reader.clone(),
)
.await;
if let Some(pool) = pool.upgrade() {
pool.remove_writer_and_close_clients(writer_id).await;
}
if let Err(e) = res {
warn!(error = %e, "ME reader ended");
}
let mut ws = writers_arc.write().await;
ws.retain(|w| w.id != writer_id);
info!(remaining = ws.len(), "Dead ME writer removed from pool");
});
let pool_ping = Arc::downgrade(self);
tokio::spawn(async move {
let mut ping_id: i64 = rand::random::<i64>();
loop {
let jitter = rand::rng()
.random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS);
let wait = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64;
tokio::select! {
_ = cancel_ping.cancelled() => {
break;
}
_ = tokio::time::sleep(Duration::from_secs(wait)) => {}
}
let sent_id = ping_id;
let mut p = Vec::with_capacity(12);
p.extend_from_slice(&RPC_PING_U32.to_le_bytes());
p.extend_from_slice(&sent_id.to_le_bytes());
{
let mut tracker = ping_tracker_ping.lock().await;
tracker.insert(sent_id, (std::time::Instant::now(), writer_id));
}
ping_id = ping_id.wrapping_add(1);
if let Err(e) = rpc_w_ping.lock().await.send_and_flush(&p).await {
debug!(error = %e, "Active ME ping failed, removing dead writer");
cancel_ping.cancel();
if let Some(pool) = pool_ping.upgrade() {
pool.remove_writer_and_close_clients(writer_id).await;
}
break;
}
}
});
Ok(())
}
async fn connect_primary_for_dc(
self: Arc<Self>,
dc: i32,
mut addrs: Vec<(IpAddr, u16)>,
rng: Arc<SecureRandom>,
) -> bool {
if addrs.is_empty() {
return false;
}
addrs.shuffle(&mut rand::rng());
for (ip, port) in addrs {
let addr = SocketAddr::new(ip, port);
match self.connect_one(addr, rng.as_ref()).await {
Ok(()) => {
info!(%addr, dc = %dc, "ME connected");
return true;
}
Err(e) => warn!(%addr, dc = %dc, error = %e, "ME connect failed, trying next"),
}
}
warn!(dc = %dc, "All ME servers for DC failed at init");
false
}
pub(crate) async fn remove_writer_and_close_clients(&self, writer_id: u64) {
let conns = self.remove_writer_only(writer_id).await;
for bound in conns {
let _ = self.registry.route(bound.conn_id, super::MeResponse::Close).await;
let _ = self.registry.unregister(bound.conn_id).await;
}
}
async fn remove_writer_only(&self, writer_id: u64) -> Vec<BoundConn> {
{
let mut ws = self.writers.write().await;
if let Some(pos) = ws.iter().position(|w| w.id == writer_id) {
let w = ws.remove(pos);
w.cancel.cancel();
}
}
self.registry.writer_lost(writer_id).await
}
pub(crate) async fn mark_writer_draining(self: &Arc<Self>, writer_id: u64) {
{
let mut ws = self.writers.write().await;
if let Some(w) = ws.iter_mut().find(|w| w.id == writer_id) {
w.draining.store(true, Ordering::Relaxed);
}
}
let pool = Arc::downgrade(self);
tokio::spawn(async move {
loop {
if let Some(p) = pool.upgrade() {
if p.registry.is_writer_empty(writer_id).await {
let _ = p.remove_writer_only(writer_id).await;
break;
}
tokio::time::sleep(Duration::from_secs(1)).await;
} else {
break;
}
}
});
}
}
fn hex_dump(data: &[u8]) -> String {
const MAX: usize = 64;
let mut out = String::with_capacity(data.len() * 2 + 3);
for (i, b) in data.iter().take(MAX).enumerate() {
if i > 0 {
out.push(' ');
}
out.push_str(&format!("{b:02x}"));
}
if data.len() > MAX {
out.push_str("");
}
out
}

View File

@@ -0,0 +1,196 @@
use std::net::{IpAddr, Ipv4Addr};
use std::time::Duration;
use tracing::{info, warn, debug};
use crate::error::{ProxyError, Result};
use crate::network::probe::is_bogon;
use crate::network::stun::{stun_probe_dual, IpFamily, StunProbeResult};
use super::MePool;
use std::time::Instant;
pub async fn stun_probe(stun_addr: Option<String>) -> Result<crate::network::stun::DualStunResult> {
let stun_addr = stun_addr.unwrap_or_else(|| "stun.l.google.com:19302".to_string());
stun_probe_dual(&stun_addr).await
}
pub async fn detect_public_ip() -> Option<IpAddr> {
fetch_public_ipv4_with_retry().await.ok().flatten().map(IpAddr::V4)
}
impl MePool {
pub(super) fn translate_ip_for_nat(&self, ip: IpAddr) -> IpAddr {
let nat_ip = self
.nat_ip_cfg
.or_else(|| self.nat_ip_detected.try_read().ok().and_then(|g| (*g).clone()));
let Some(nat_ip) = nat_ip else {
return ip;
};
match (ip, nat_ip) {
(IpAddr::V4(src), IpAddr::V4(dst))
if is_bogon(IpAddr::V4(src))
|| src.is_loopback()
|| src.is_unspecified() =>
{
IpAddr::V4(dst)
}
(IpAddr::V6(src), IpAddr::V6(dst)) if src.is_loopback() || src.is_unspecified() => {
IpAddr::V6(dst)
}
(orig, _) => orig,
}
}
pub(super) fn translate_our_addr_with_reflection(
&self,
addr: std::net::SocketAddr,
reflected: Option<std::net::SocketAddr>,
) -> std::net::SocketAddr {
let ip = if let Some(r) = reflected {
// Use reflected IP (not port) only when local address is non-public.
if is_bogon(addr.ip()) || addr.ip().is_loopback() || addr.ip().is_unspecified() {
r.ip()
} else {
self.translate_ip_for_nat(addr.ip())
}
} else {
self.translate_ip_for_nat(addr.ip())
};
// Keep the kernel-assigned TCP source port; STUN port can differ.
std::net::SocketAddr::new(ip, addr.port())
}
pub(super) async fn maybe_detect_nat_ip(&self, local_ip: IpAddr) -> Option<IpAddr> {
if self.nat_ip_cfg.is_some() {
return self.nat_ip_cfg;
}
if !(is_bogon(local_ip) || local_ip.is_loopback() || local_ip.is_unspecified()) {
return None;
}
if let Some(ip) = self.nat_ip_detected.read().await.clone() {
return Some(ip);
}
match fetch_public_ipv4_with_retry().await {
Ok(Some(ip)) => {
{
let mut guard = self.nat_ip_detected.write().await;
*guard = Some(IpAddr::V4(ip));
}
info!(public_ip = %ip, "Auto-detected public IP for NAT translation");
Some(IpAddr::V4(ip))
}
Ok(None) => None,
Err(e) => {
warn!(error = %e, "Failed to auto-detect public IP");
None
}
}
}
pub(super) async fn maybe_reflect_public_addr(
&self,
family: IpFamily,
) -> Option<std::net::SocketAddr> {
const STUN_CACHE_TTL: Duration = Duration::from_secs(600);
// If STUN probing was disabled after attempts, reuse cached (even stale) or skip.
if self.nat_probe_disabled.load(std::sync::atomic::Ordering::Relaxed) {
if let Ok(cache) = self.nat_reflection_cache.try_lock() {
let slot = match family {
IpFamily::V4 => cache.v4,
IpFamily::V6 => cache.v6,
};
return slot.map(|(_, addr)| addr);
}
return None;
}
if let Ok(mut cache) = self.nat_reflection_cache.try_lock() {
let slot = match family {
IpFamily::V4 => &mut cache.v4,
IpFamily::V6 => &mut cache.v6,
};
if let Some((ts, addr)) = slot {
if ts.elapsed() < STUN_CACHE_TTL {
return Some(*addr);
}
}
}
let attempt = self.nat_probe_attempts.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
if attempt >= 2 {
self.nat_probe_disabled.store(true, std::sync::atomic::Ordering::Relaxed);
return None;
}
let stun_addr = self
.nat_stun
.clone()
.unwrap_or_else(|| "stun.l.google.com:19302".to_string());
match stun_probe_dual(&stun_addr).await {
Ok(res) => {
let picked: Option<StunProbeResult> = match family {
IpFamily::V4 => res.v4,
IpFamily::V6 => res.v6,
};
if let Some(result) = picked {
info!(local = %result.local_addr, reflected = %result.reflected_addr, family = ?family, "NAT probe: reflected address");
if let Ok(mut cache) = self.nat_reflection_cache.try_lock() {
let slot = match family {
IpFamily::V4 => &mut cache.v4,
IpFamily::V6 => &mut cache.v6,
};
*slot = Some((Instant::now(), result.reflected_addr));
}
Some(result.reflected_addr)
} else {
None
}
}
Err(e) => {
let attempts = attempt + 1;
if attempts <= 2 {
warn!(error = %e, attempt = attempts, "NAT probe failed");
} else {
debug!(error = %e, attempt = attempts, "NAT probe suppressed after max attempts");
}
if attempts >= 2 {
self.nat_probe_disabled.store(true, std::sync::atomic::Ordering::Relaxed);
}
None
}
}
}
}
async fn fetch_public_ipv4_with_retry() -> Result<Option<Ipv4Addr>> {
let providers = [
"https://checkip.amazonaws.com",
"http://v4.ident.me",
"http://ipv4.icanhazip.com",
];
for url in providers {
if let Ok(Some(ip)) = fetch_public_ipv4_once(url).await {
return Ok(Some(ip));
}
}
Ok(None)
}
async fn fetch_public_ipv4_once(url: &str) -> Result<Option<Ipv4Addr>> {
let res = reqwest::get(url).await.map_err(|e| {
ProxyError::Proxy(format!("public IP detection request failed: {e}"))
})?;
let text = res.text().await.map_err(|e| {
ProxyError::Proxy(format!("public IP detection read failed: {e}"))
})?;
let ip = text.trim().parse().ok();
Ok(ip)
}

View File

@@ -0,0 +1,182 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Instant;
use bytes::{Bytes, BytesMut};
use tokio::io::AsyncReadExt;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken;
use tracing::{debug, trace, warn};
use crate::crypto::{AesCbc, crc32};
use crate::error::{ProxyError, Result};
use crate::protocol::constants::*;
use super::codec::RpcWriter;
use super::{ConnRegistry, MeResponse};
pub(crate) async fn reader_loop(
mut rd: tokio::io::ReadHalf<TcpStream>,
dk: [u8; 32],
mut div: [u8; 16],
reg: Arc<ConnRegistry>,
enc_leftover: BytesMut,
mut dec: BytesMut,
writer: Arc<Mutex<RpcWriter>>,
ping_tracker: Arc<Mutex<HashMap<i64, (Instant, u64)>>>,
rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>,
_writer_id: u64,
degraded: Arc<AtomicBool>,
cancel: CancellationToken,
) -> Result<()> {
let mut raw = enc_leftover;
let mut expected_seq: i32 = 0;
loop {
let mut tmp = [0u8; 16_384];
let n = tokio::select! {
res = rd.read(&mut tmp) => res.map_err(ProxyError::Io)?,
_ = cancel.cancelled() => return Ok(()),
};
if n == 0 {
return Ok(());
}
raw.extend_from_slice(&tmp[..n]);
let blocks = raw.len() / 16 * 16;
if blocks > 0 {
let mut new_iv = [0u8; 16];
new_iv.copy_from_slice(&raw[blocks - 16..blocks]);
let mut chunk = vec![0u8; blocks];
chunk.copy_from_slice(&raw[..blocks]);
AesCbc::new(dk, div)
.decrypt_in_place(&mut chunk)
.map_err(|e| ProxyError::Crypto(format!("{e}")))?;
div = new_iv;
dec.extend_from_slice(&chunk);
let _ = raw.split_to(blocks);
}
while dec.len() >= 12 {
let fl = u32::from_le_bytes(dec[0..4].try_into().unwrap()) as usize;
if fl == 4 {
let _ = dec.split_to(4);
continue;
}
if !(12..=(1 << 24)).contains(&fl) {
warn!(frame_len = fl, "Invalid RPC frame len");
dec.clear();
break;
}
if dec.len() < fl {
break;
}
let frame = dec.split_to(fl);
let pe = fl - 4;
let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap());
if crc32(&frame[..pe]) != ec {
warn!("CRC mismatch in data frame");
continue;
}
let seq_no = i32::from_le_bytes(frame[4..8].try_into().unwrap());
if seq_no != expected_seq {
warn!(seq_no, expected = expected_seq, "ME RPC seq mismatch");
expected_seq = seq_no.wrapping_add(1);
} else {
expected_seq = expected_seq.wrapping_add(1);
}
let payload = &frame[8..pe];
if payload.len() < 4 {
continue;
}
let pt = u32::from_le_bytes(payload[0..4].try_into().unwrap());
let body = &payload[4..];
if pt == RPC_PROXY_ANS_U32 && body.len() >= 12 {
let flags = u32::from_le_bytes(body[0..4].try_into().unwrap());
let cid = u64::from_le_bytes(body[4..12].try_into().unwrap());
let data = Bytes::copy_from_slice(&body[12..]);
trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS");
let routed = reg.route(cid, MeResponse::Data { flags, data }).await;
if !routed {
reg.unregister(cid).await;
send_close_conn(&writer, cid).await;
}
} else if pt == RPC_SIMPLE_ACK_U32 && body.len() >= 12 {
let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
let cfm = u32::from_le_bytes(body[8..12].try_into().unwrap());
trace!(cid, cfm, "RPC_SIMPLE_ACK");
let routed = reg.route(cid, MeResponse::Ack(cfm)).await;
if !routed {
reg.unregister(cid).await;
send_close_conn(&writer, cid).await;
}
} else if pt == RPC_CLOSE_EXT_U32 && body.len() >= 8 {
let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
debug!(cid, "RPC_CLOSE_EXT from ME");
reg.route(cid, MeResponse::Close).await;
reg.unregister(cid).await;
} else if pt == RPC_CLOSE_CONN_U32 && body.len() >= 8 {
let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
debug!(cid, "RPC_CLOSE_CONN from ME");
reg.route(cid, MeResponse::Close).await;
reg.unregister(cid).await;
} else if pt == RPC_PING_U32 && body.len() >= 8 {
let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap());
trace!(ping_id, "RPC_PING -> RPC_PONG");
let mut pong = Vec::with_capacity(12);
pong.extend_from_slice(&RPC_PONG_U32.to_le_bytes());
pong.extend_from_slice(&ping_id.to_le_bytes());
if let Err(e) = writer.lock().await.send_and_flush(&pong).await {
warn!(error = %e, "PONG send failed");
break;
}
} else if pt == RPC_PONG_U32 && body.len() >= 8 {
let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap());
if let Some((sent, wid)) = {
let mut guard = ping_tracker.lock().await;
guard.remove(&ping_id)
} {
let rtt = sent.elapsed().as_secs_f64() * 1000.0;
let mut stats = rtt_stats.lock().await;
let entry = stats.entry(wid).or_insert((rtt, rtt));
entry.1 = entry.1 * 0.8 + rtt * 0.2;
if rtt < entry.0 {
entry.0 = rtt;
} else {
// allow slow baseline drift upward to avoid stale minimum
entry.0 = entry.0 * 0.99 + rtt * 0.01;
}
let degraded_now = entry.1 > entry.0 * 2.0;
degraded.store(degraded_now, Ordering::Relaxed);
trace!(writer_id = wid, rtt_ms = rtt, ema_ms = entry.1, base_ms = entry.0, degraded = degraded_now, "ME RTT sample");
}
} else {
debug!(
rpc_type = format_args!("0x{pt:08x}"),
len = body.len(),
"Unknown RPC"
);
}
}
}
}
async fn send_close_conn(writer: &Arc<Mutex<RpcWriter>>, conn_id: u64) {
let mut p = Vec::with_capacity(12);
p.extend_from_slice(&RPC_CLOSE_CONN_U32.to_le_bytes());
p.extend_from_slice(&conn_id.to_le_bytes());
if let Err(e) = writer.lock().await.send_and_flush(&p).await {
debug!(conn_id, error = %e, "Failed to send RPC_CLOSE_CONN");
}
}

View File

@@ -0,0 +1,156 @@
use std::collections::{HashMap, HashSet};
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex, RwLock};
use super::codec::RpcWriter;
use super::MeResponse;
#[derive(Clone)]
pub struct ConnMeta {
pub target_dc: i16,
pub client_addr: SocketAddr,
pub our_addr: SocketAddr,
pub proto_flags: u32,
}
#[derive(Clone)]
pub struct BoundConn {
pub conn_id: u64,
pub meta: ConnMeta,
}
#[derive(Clone)]
pub struct ConnWriter {
pub writer_id: u64,
pub writer: Arc<Mutex<RpcWriter>>,
}
struct RegistryInner {
map: HashMap<u64, mpsc::Sender<MeResponse>>,
writers: HashMap<u64, Arc<Mutex<RpcWriter>>>,
writer_for_conn: HashMap<u64, u64>,
conns_for_writer: HashMap<u64, HashSet<u64>>,
meta: HashMap<u64, ConnMeta>,
}
impl RegistryInner {
fn new() -> Self {
Self {
map: HashMap::new(),
writers: HashMap::new(),
writer_for_conn: HashMap::new(),
conns_for_writer: HashMap::new(),
meta: HashMap::new(),
}
}
}
pub struct ConnRegistry {
inner: RwLock<RegistryInner>,
next_id: AtomicU64,
}
impl ConnRegistry {
pub fn new() -> Self {
let start = rand::random::<u64>() | 1;
Self {
inner: RwLock::new(RegistryInner::new()),
next_id: AtomicU64::new(start),
}
}
pub async fn register(&self) -> (u64, mpsc::Receiver<MeResponse>) {
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
let (tx, rx) = mpsc::channel(1024);
self.inner.write().await.map.insert(id, tx);
(id, rx)
}
/// Unregister connection, returning associated writer_id if any.
pub async fn unregister(&self, id: u64) -> Option<u64> {
let mut inner = self.inner.write().await;
inner.map.remove(&id);
inner.meta.remove(&id);
if let Some(writer_id) = inner.writer_for_conn.remove(&id) {
if let Some(set) = inner.conns_for_writer.get_mut(&writer_id) {
set.remove(&id);
}
return Some(writer_id);
}
None
}
pub async fn route(&self, id: u64, resp: MeResponse) -> bool {
let inner = self.inner.read().await;
if let Some(tx) = inner.map.get(&id) {
tx.try_send(resp).is_ok()
} else {
false
}
}
pub async fn bind_writer(
&self,
conn_id: u64,
writer_id: u64,
writer: Arc<Mutex<RpcWriter>>,
meta: ConnMeta,
) {
let mut inner = self.inner.write().await;
inner.meta.entry(conn_id).or_insert(meta);
inner.writer_for_conn.insert(conn_id, writer_id);
inner.writers.entry(writer_id).or_insert_with(|| writer.clone());
inner
.conns_for_writer
.entry(writer_id)
.or_insert_with(HashSet::new)
.insert(conn_id);
}
pub async fn get_writer(&self, conn_id: u64) -> Option<ConnWriter> {
let inner = self.inner.read().await;
let writer_id = inner.writer_for_conn.get(&conn_id).cloned()?;
let writer = inner.writers.get(&writer_id).cloned()?;
Some(ConnWriter { writer_id, writer })
}
pub async fn writer_lost(&self, writer_id: u64) -> Vec<BoundConn> {
let mut inner = self.inner.write().await;
inner.writers.remove(&writer_id);
let conns = inner
.conns_for_writer
.remove(&writer_id)
.unwrap_or_default()
.into_iter()
.collect::<Vec<_>>();
let mut out = Vec::new();
for conn_id in conns {
inner.writer_for_conn.remove(&conn_id);
if let Some(m) = inner.meta.get(&conn_id) {
out.push(BoundConn {
conn_id,
meta: m.clone(),
});
}
}
out
}
pub async fn get_meta(&self, conn_id: u64) -> Option<ConnMeta> {
let inner = self.inner.read().await;
inner.meta.get(&conn_id).cloned()
}
pub async fn is_writer_empty(&self, writer_id: u64) -> bool {
let inner = self.inner.read().await;
inner
.conns_for_writer
.get(&writer_id)
.map(|s| s.is_empty())
.unwrap_or(true)
}
}

View File

@@ -0,0 +1,42 @@
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::time::Duration;
use tracing::{info, warn};
use crate::crypto::SecureRandom;
use super::MePool;
/// Periodically refresh ME connections to avoid long-lived degradation.
pub async fn me_rotation_task(pool: Arc<MePool>, rng: Arc<SecureRandom>, interval: Duration) {
let interval = interval.max(Duration::from_secs(600));
loop {
tokio::time::sleep(interval).await;
let candidate = {
let ws = pool.writers.read().await;
if ws.is_empty() {
None
} else {
let idx = (pool.rr.load(std::sync::atomic::Ordering::Relaxed) as usize) % ws.len();
ws.get(idx).cloned()
}
};
let Some(w) = candidate else {
continue;
};
info!(addr = %w.addr, writer_id = w.id, "Rotating ME connection");
match pool.connect_one(w.addr, rng.as_ref()).await {
Ok(()) => {
// Mark old writer for graceful drain; removal happens when sessions finish.
pool.mark_writer_draining(w.id).await;
}
Err(e) => {
warn!(addr = %w.addr, writer_id = w.id, error = %e, "ME rotation connect failed");
}
}
}
}

View File

@@ -0,0 +1,100 @@
use std::time::Duration;
use tracing::{debug, info, warn};
use std::time::SystemTime;
use httpdate;
use crate::error::{ProxyError, Result};
/// Fetch Telegram proxy-secret binary.
pub async fn fetch_proxy_secret(cache_path: Option<&str>) -> Result<Vec<u8>> {
let cache = cache_path.unwrap_or("proxy-secret");
// 1) Try fresh download first.
match download_proxy_secret().await {
Ok(data) => {
if let Err(e) = tokio::fs::write(cache, &data).await {
warn!(error = %e, "Failed to cache proxy-secret (non-fatal)");
} else {
debug!(path = cache, len = data.len(), "Cached proxy-secret");
}
return Ok(data);
}
Err(download_err) => {
warn!(error = %download_err, "Proxy-secret download failed, trying cache/file fallback");
// Fall through to cache/file.
}
}
// 2) Fallback to cache/file regardless of age; require len>=32.
match tokio::fs::read(cache).await {
Ok(data) if data.len() >= 32 => {
let age_hours = tokio::fs::metadata(cache)
.await
.ok()
.and_then(|m| m.modified().ok())
.and_then(|m| std::time::SystemTime::now().duration_since(m).ok())
.map(|d| d.as_secs() / 3600);
info!(
path = cache,
len = data.len(),
age_hours,
"Loaded proxy-secret from cache/file after download failure"
);
Ok(data)
}
Ok(data) => Err(ProxyError::Proxy(format!(
"Cached proxy-secret too short: {} bytes (need >= 32)",
data.len()
))),
Err(e) => Err(ProxyError::Proxy(format!(
"Failed to read proxy-secret cache after download failure: {e}"
))),
}
}
pub async fn download_proxy_secret() -> Result<Vec<u8>> {
let resp = reqwest::get("https://core.telegram.org/getProxySecret")
.await
.map_err(|e| ProxyError::Proxy(format!("Failed to download proxy-secret: {e}")))?;
if !resp.status().is_success() {
return Err(ProxyError::Proxy(format!(
"proxy-secret download HTTP {}",
resp.status()
)));
}
if let Some(date) = resp.headers().get(reqwest::header::DATE) {
if let Ok(date_str) = date.to_str() {
if let Ok(server_time) = httpdate::parse_http_date(date_str) {
if let Ok(skew) = SystemTime::now().duration_since(server_time).or_else(|e| {
server_time.duration_since(SystemTime::now()).map_err(|_| e)
}) {
let skew_secs = skew.as_secs();
if skew_secs > 60 {
warn!(skew_secs, "Time skew >60s detected from proxy-secret Date header");
} else if skew_secs > 30 {
warn!(skew_secs, "Time skew >30s detected from proxy-secret Date header");
}
}
}
}
}
let data = resp
.bytes()
.await
.map_err(|e| ProxyError::Proxy(format!("Read proxy-secret body: {e}")))?
.to_vec();
if data.len() < 32 {
return Err(ProxyError::Proxy(format!(
"proxy-secret too short: {} bytes (need >= 32)",
data.len()
)));
}
info!(len = data.len(), "Downloaded proxy-secret OK");
Ok(data)
}

View File

@@ -0,0 +1,250 @@
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::time::Duration;
use tracing::{debug, warn};
use crate::error::{ProxyError, Result};
use crate::network::IpFamily;
use crate::protocol::constants::RPC_CLOSE_EXT_U32;
use super::MePool;
use super::wire::build_proxy_req_payload;
use rand::seq::SliceRandom;
use super::registry::ConnMeta;
impl MePool {
pub async fn send_proxy_req(
self: &Arc<Self>,
conn_id: u64,
target_dc: i16,
client_addr: SocketAddr,
our_addr: SocketAddr,
data: &[u8],
proto_flags: u32,
) -> Result<()> {
let payload = build_proxy_req_payload(
conn_id,
client_addr,
our_addr,
data,
self.proxy_tag.as_deref(),
proto_flags,
);
let meta = ConnMeta {
target_dc,
client_addr,
our_addr,
proto_flags,
};
let mut emergency_attempts = 0;
loop {
if let Some(current) = self.registry.get_writer(conn_id).await {
let send_res = {
if let Ok(mut guard) = current.writer.try_lock() {
let r = guard.send(&payload).await;
drop(guard);
r
} else {
current.writer.lock().await.send(&payload).await
}
};
match send_res {
Ok(()) => return Ok(()),
Err(e) => {
warn!(error = %e, writer_id = current.writer_id, "ME write failed");
self.remove_writer_and_close_clients(current.writer_id).await;
continue;
}
}
}
let mut writers_snapshot = {
let ws = self.writers.read().await;
if ws.is_empty() {
return Err(ProxyError::Proxy("All ME connections dead".into()));
}
ws.clone()
};
let mut candidate_indices = self.candidate_indices_for_dc(&writers_snapshot, target_dc).await;
if candidate_indices.is_empty() {
// Emergency connect-on-demand
if emergency_attempts >= 3 {
return Err(ProxyError::Proxy("No ME writers available for target DC".into()));
}
emergency_attempts += 1;
for family in self.family_order() {
let map_guard = match family {
IpFamily::V4 => self.proxy_map_v4.read().await,
IpFamily::V6 => self.proxy_map_v6.read().await,
};
if let Some(addrs) = map_guard.get(&(target_dc as i32)) {
let mut shuffled = addrs.clone();
shuffled.shuffle(&mut rand::rng());
drop(map_guard);
for (ip, port) in shuffled {
let addr = SocketAddr::new(ip, port);
if self.connect_one(addr, self.rng.as_ref()).await.is_ok() {
break;
}
}
tokio::time::sleep(Duration::from_millis(100 * emergency_attempts)).await;
let ws2 = self.writers.read().await;
writers_snapshot = ws2.clone();
drop(ws2);
candidate_indices = self.candidate_indices_for_dc(&writers_snapshot, target_dc).await;
break;
}
drop(map_guard);
}
if candidate_indices.is_empty() {
return Err(ProxyError::Proxy("No ME writers available for target DC".into()));
}
}
candidate_indices.sort_by_key(|idx| {
let w = &writers_snapshot[*idx];
let degraded = w.degraded.load(Ordering::Relaxed);
let draining = w.draining.load(Ordering::Relaxed);
(draining as usize, degraded as usize)
});
let start = self.rr.fetch_add(1, Ordering::Relaxed) as usize % candidate_indices.len();
for offset in 0..candidate_indices.len() {
let idx = candidate_indices[(start + offset) % candidate_indices.len()];
let w = &writers_snapshot[idx];
if w.draining.load(Ordering::Relaxed) {
continue;
}
if let Ok(mut guard) = w.writer.try_lock() {
let send_res = guard.send(&payload).await;
drop(guard);
match send_res {
Ok(()) => {
self.registry
.bind_writer(conn_id, w.id, w.writer.clone(), meta.clone())
.await;
return Ok(());
}
Err(e) => {
warn!(error = %e, writer_id = w.id, "ME write failed");
self.remove_writer_and_close_clients(w.id).await;
continue;
}
}
}
}
let w = writers_snapshot[candidate_indices[start]].clone();
if w.draining.load(Ordering::Relaxed) {
continue;
}
match w.writer.lock().await.send(&payload).await {
Ok(()) => {
self.registry
.bind_writer(conn_id, w.id, w.writer.clone(), meta.clone())
.await;
return Ok(());
}
Err(e) => {
warn!(error = %e, writer_id = w.id, "ME write failed (blocking)");
self.remove_writer_and_close_clients(w.id).await;
}
}
}
}
pub async fn send_close(self: &Arc<Self>, conn_id: u64) -> Result<()> {
if let Some(w) = self.registry.get_writer(conn_id).await {
let mut p = Vec::with_capacity(12);
p.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes());
p.extend_from_slice(&conn_id.to_le_bytes());
if let Err(e) = w.writer.lock().await.send_and_flush(&p).await {
debug!(error = %e, "ME close write failed");
self.remove_writer_and_close_clients(w.writer_id).await;
}
} else {
debug!(conn_id, "ME close skipped (writer missing)");
}
self.registry.unregister(conn_id).await;
Ok(())
}
pub fn connection_count(&self) -> usize {
self.writers.try_read().map(|w| w.len()).unwrap_or(0)
}
pub(super) async fn candidate_indices_for_dc(
&self,
writers: &[super::pool::MeWriter],
target_dc: i16,
) -> Vec<usize> {
let key = target_dc as i32;
let mut preferred = Vec::<SocketAddr>::new();
for family in self.family_order() {
let map_guard = match family {
IpFamily::V4 => self.proxy_map_v4.read().await,
IpFamily::V6 => self.proxy_map_v6.read().await,
};
if let Some(v) = map_guard.get(&key) {
preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port)));
}
if preferred.is_empty() {
let abs = key.abs();
if let Some(v) = map_guard.get(&abs) {
preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port)));
}
}
if preferred.is_empty() {
let abs = key.abs();
if let Some(v) = map_guard.get(&-abs) {
preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port)));
}
}
if preferred.is_empty() {
let def = self.default_dc.load(Ordering::Relaxed);
if def != 0 {
if let Some(v) = map_guard.get(&def) {
preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port)));
}
}
}
drop(map_guard);
if !preferred.is_empty() && !self.decision.effective_multipath {
break;
}
}
if preferred.is_empty() {
return (0..writers.len())
.filter(|i| !writers[*i].draining.load(Ordering::Relaxed))
.collect();
}
let mut out = Vec::new();
for (idx, w) in writers.iter().enumerate() {
if w.draining.load(Ordering::Relaxed) {
continue;
}
if preferred.iter().any(|p| *p == w.addr) {
out.push(idx);
}
}
if out.is_empty() {
return (0..writers.len())
.filter(|i| !writers[*i].draining.load(Ordering::Relaxed))
.collect();
}
out
}
}

View File

@@ -0,0 +1,118 @@
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use crate::protocol::constants::*;
#[derive(Clone, Copy)]
pub(crate) enum IpMaterial {
V4([u8; 4]),
V6([u8; 16]),
}
pub(crate) fn extract_ip_material(addr: SocketAddr) -> IpMaterial {
match addr.ip() {
IpAddr::V4(v4) => IpMaterial::V4(v4.octets()),
IpAddr::V6(v6) => {
if let Some(v4) = v6.to_ipv4_mapped() {
IpMaterial::V4(v4.octets())
} else {
IpMaterial::V6(v6.octets())
}
}
}
}
fn ipv4_to_mapped_v6_c_compat(ip: Ipv4Addr) -> [u8; 16] {
let mut buf = [0u8; 16];
// Matches tl_store_long(0) + tl_store_int(-0x10000).
buf[8..12].copy_from_slice(&(-0x10000i32).to_le_bytes());
// Matches tl_store_int(htonl(remote_ip_host_order)).
buf[12..16].copy_from_slice(&ip.octets());
buf
}
fn append_mapped_addr_and_port(buf: &mut Vec<u8>, addr: SocketAddr) {
match addr.ip() {
IpAddr::V4(v4) => buf.extend_from_slice(&ipv4_to_mapped_v6_c_compat(v4)),
IpAddr::V6(v6) => buf.extend_from_slice(&v6.octets()),
}
buf.extend_from_slice(&(addr.port() as u32).to_le_bytes());
}
pub(crate) fn build_proxy_req_payload(
conn_id: u64,
client_addr: SocketAddr,
our_addr: SocketAddr,
data: &[u8],
proxy_tag: Option<&[u8]>,
proto_flags: u32,
) -> Vec<u8> {
let mut b = Vec::with_capacity(128 + data.len());
b.extend_from_slice(&RPC_PROXY_REQ_U32.to_le_bytes());
b.extend_from_slice(&proto_flags.to_le_bytes());
b.extend_from_slice(&conn_id.to_le_bytes());
append_mapped_addr_and_port(&mut b, client_addr);
append_mapped_addr_and_port(&mut b, our_addr);
if proto_flags & RPC_FLAG_HAS_AD_TAG != 0 {
let extra_start = b.len();
b.extend_from_slice(&0u32.to_le_bytes());
if let Some(tag) = proxy_tag {
b.extend_from_slice(&TL_PROXY_TAG_U32.to_le_bytes());
if tag.len() < 254 {
b.push(tag.len() as u8);
b.extend_from_slice(tag);
let pad = (4 - ((1 + tag.len()) % 4)) % 4;
b.extend(std::iter::repeat_n(0u8, pad));
} else {
b.push(0xfe);
let len_bytes = (tag.len() as u32).to_le_bytes();
b.extend_from_slice(&len_bytes[..3]);
b.extend_from_slice(tag);
let pad = (4 - (tag.len() % 4)) % 4;
b.extend(std::iter::repeat_n(0u8, pad));
}
}
let extra_bytes = (b.len() - extra_start - 4) as u32;
b[extra_start..extra_start + 4].copy_from_slice(&extra_bytes.to_le_bytes());
}
b.extend_from_slice(data);
b
}
pub fn proto_flags_for_tag(tag: crate::protocol::constants::ProtoTag, has_proxy_tag: bool) -> u32 {
use crate::protocol::constants::ProtoTag;
let mut flags = RPC_FLAG_MAGIC | RPC_FLAG_EXTMODE2;
if has_proxy_tag {
flags |= RPC_FLAG_HAS_AD_TAG;
}
match tag {
ProtoTag::Abridged => flags | RPC_FLAG_ABRIDGED,
ProtoTag::Intermediate => flags | RPC_FLAG_INTERMEDIATE,
ProtoTag::Secure => flags | RPC_FLAG_PAD | RPC_FLAG_INTERMEDIATE,
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ipv4_mapped_encoding() {
let ip = Ipv4Addr::new(149, 154, 175, 50);
let buf = ipv4_to_mapped_v6_c_compat(ip);
assert_eq!(&buf[0..10], &[0u8; 10]);
assert_eq!(&buf[10..12], &[0xff, 0xff]);
assert_eq!(&buf[12..16], &[149, 154, 175, 50]);
}
}

View File

@@ -10,4 +10,5 @@ pub use pool::ConnectionPool;
pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol};
pub use socket::*;
pub use socks::*;
pub use upstream::UpstreamManager;
pub use upstream::{DcPingResult, StartupPingResult, UpstreamManager};
pub mod middle_proxy;

View File

@@ -285,12 +285,17 @@ where
#[cfg(test)]
mod tests {
use super::*;
use std::io::ErrorKind;
use tokio::net::TcpListener;
#[tokio::test]
async fn test_pool_basic() {
// Start a test server
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let listener = match TcpListener::bind("127.0.0.1:0").await {
Ok(l) => l,
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
Err(e) => panic!("bind failed: {e}"),
};
let addr = listener.local_addr().unwrap();
// Accept connections in background
@@ -303,7 +308,11 @@ mod tests {
let pool = ConnectionPool::new();
// Get a connection
let conn1 = pool.get(addr).await.unwrap();
let conn1 = match pool.get(addr).await {
Ok(c) => c,
Err(ProxyError::Io(e)) if e.kind() == ErrorKind::PermissionDenied => return,
Err(e) => panic!("connect failed: {e}"),
};
// Return it to pool
pool.put(addr, conn1).await;
@@ -335,4 +344,4 @@ mod tests {
assert_eq!(stats.endpoints, 0);
assert_eq!(stats.total_connections, 0);
}
}
}

View File

@@ -205,15 +205,29 @@ pub fn create_listener(addr: SocketAddr, options: &ListenOptions) -> Result<Sock
#[cfg(test)]
mod tests {
use super::*;
use std::io::ErrorKind;
use tokio::net::TcpListener;
#[tokio::test]
async fn test_configure_socket() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let listener = match TcpListener::bind("127.0.0.1:0").await {
Ok(l) => l,
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
Err(e) => panic!("bind failed: {e}"),
};
let addr = listener.local_addr().unwrap();
let stream = TcpStream::connect(addr).await.unwrap();
configure_tcp_socket(&stream, true, Duration::from_secs(30)).unwrap();
let stream = match TcpStream::connect(addr).await {
Ok(s) => s,
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
Err(e) => panic!("connect failed: {e}"),
};
if let Err(e) = configure_tcp_socket(&stream, true, Duration::from_secs(30)) {
if e.kind() == ErrorKind::PermissionDenied {
return;
}
panic!("configure_tcp_socket failed: {e}");
}
}
#[test]
@@ -234,4 +248,4 @@ mod tests {
assert!(opts.reuse_port);
assert_eq!(opts.backlog, 1024);
}
}
}

View File

@@ -1,26 +1,154 @@
//! Upstream Management
//! Upstream Management with per-DC latency-weighted selection
//!
//! IPv6/IPv4 connectivity checks with configurable preference.
use std::collections::HashMap;
use std::net::{SocketAddr, IpAddr};
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::RwLock;
use tokio::time::Instant;
use rand::Rng;
use tracing::{debug, warn, error, info};
use tracing::{debug, warn, info, trace};
use crate::config::{UpstreamConfig, UpstreamType};
use crate::error::{Result, ProxyError};
use crate::protocol::constants::{TG_DATACENTERS_V4, TG_DATACENTERS_V6, TG_DATACENTER_PORT};
use crate::transport::socket::create_outgoing_socket_bound;
use crate::transport::socks::{connect_socks4, connect_socks5};
/// Number of Telegram datacenters
const NUM_DCS: usize = 5;
/// Timeout for individual DC ping attempt
const DC_PING_TIMEOUT_SECS: u64 = 5;
// ============= RTT Tracking =============
#[derive(Debug, Clone, Copy)]
struct LatencyEma {
value_ms: Option<f64>,
alpha: f64,
}
impl LatencyEma {
const fn new(alpha: f64) -> Self {
Self { value_ms: None, alpha }
}
fn update(&mut self, sample_ms: f64) {
self.value_ms = Some(match self.value_ms {
None => sample_ms,
Some(prev) => prev * (1.0 - self.alpha) + sample_ms * self.alpha,
});
}
fn get(&self) -> Option<f64> {
self.value_ms
}
}
// ============= Per-DC IP Preference Tracking =============
/// Tracks which IP version works for each DC
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IpPreference {
/// Not yet tested
Unknown,
/// IPv6 works
PreferV6,
/// Only IPv4 works (IPv6 failed)
PreferV4,
/// Both work
BothWork,
/// Both failed
Unavailable,
}
impl Default for IpPreference {
fn default() -> Self {
Self::Unknown
}
}
// ============= Upstream State =============
#[derive(Debug)]
struct UpstreamState {
config: UpstreamConfig,
healthy: bool,
fails: u32,
last_check: std::time::Instant,
/// Per-DC latency EMA (index 0 = DC1, index 4 = DC5)
dc_latency: [LatencyEma; NUM_DCS],
/// Per-DC IP version preference (learned from connectivity tests)
dc_ip_pref: [IpPreference; NUM_DCS],
}
impl UpstreamState {
fn new(config: UpstreamConfig) -> Self {
Self {
config,
healthy: true,
fails: 0,
last_check: std::time::Instant::now(),
dc_latency: [LatencyEma::new(0.3); NUM_DCS],
dc_ip_pref: [IpPreference::Unknown; NUM_DCS],
}
}
/// Map DC index to latency array slot (0..NUM_DCS).
fn dc_array_idx(dc_idx: i16) -> Option<usize> {
let abs_dc = dc_idx.unsigned_abs() as usize;
if abs_dc == 0 {
return None;
}
if abs_dc >= 1 && abs_dc <= NUM_DCS {
Some(abs_dc - 1)
} else {
// Unknown DC → default cluster (DC 2, index 1)
Some(1)
}
}
/// Get latency for a specific DC, falling back to average across all known DCs
fn effective_latency(&self, dc_idx: Option<i16>) -> Option<f64> {
if let Some(di) = dc_idx.and_then(Self::dc_array_idx) {
if let Some(ms) = self.dc_latency[di].get() {
return Some(ms);
}
}
let (sum, count) = self.dc_latency.iter()
.filter_map(|l| l.get())
.fold((0.0, 0u32), |(s, c), v| (s + v, c + 1));
if count > 0 { Some(sum / count as f64) } else { None }
}
}
/// Result of a single DC ping
#[derive(Debug, Clone)]
pub struct DcPingResult {
pub dc_idx: usize,
pub dc_addr: SocketAddr,
pub rtt_ms: Option<f64>,
pub error: Option<String>,
}
/// Result of startup ping for one upstream (separate v6/v4 results)
#[derive(Debug, Clone)]
pub struct StartupPingResult {
pub v6_results: Vec<DcPingResult>,
pub v4_results: Vec<DcPingResult>,
pub upstream_name: String,
/// True if both IPv6 and IPv4 have at least one working DC
pub both_available: bool,
}
// ============= Upstream Manager =============
#[derive(Clone)]
pub struct UpstreamManager {
upstreams: Arc<RwLock<Vec<UpstreamState>>>,
@@ -30,230 +158,573 @@ impl UpstreamManager {
pub fn new(configs: Vec<UpstreamConfig>) -> Self {
let states = configs.into_iter()
.filter(|c| c.enabled)
.map(|c| UpstreamState {
config: c,
healthy: true, // Optimistic start
fails: 0,
last_check: std::time::Instant::now(),
})
.map(UpstreamState::new)
.collect();
Self {
upstreams: Arc::new(RwLock::new(states)),
}
}
/// Select an upstream using Weighted Round Robin (simplified)
async fn select_upstream(&self) -> Option<usize> {
/// Select upstream using latency-weighted random selection.
async fn select_upstream(&self, dc_idx: Option<i16>) -> Option<usize> {
let upstreams = self.upstreams.read().await;
if upstreams.is_empty() {
return None;
}
let healthy_indices: Vec<usize> = upstreams.iter()
let healthy: Vec<usize> = upstreams.iter()
.enumerate()
.filter(|(_, u)| u.healthy)
.map(|(i, _)| i)
.collect();
if healthy_indices.is_empty() {
// If all unhealthy, try any random one
return Some(rand::thread_rng().gen_range(0..upstreams.len()));
if healthy.is_empty() {
return Some(rand::rng().gen_range(0..upstreams.len()));
}
// Weighted selection
let total_weight: u32 = healthy_indices.iter()
.map(|&i| upstreams[i].config.weight as u32)
.sum();
if total_weight == 0 {
return Some(healthy_indices[rand::thread_rng().gen_range(0..healthy_indices.len())]);
if healthy.len() == 1 {
return Some(healthy[0]);
}
let mut choice = rand::thread_rng().gen_range(0..total_weight);
for &idx in &healthy_indices {
let weight = upstreams[idx].config.weight as u32;
let weights: Vec<(usize, f64)> = healthy.iter().map(|&i| {
let base = upstreams[i].config.weight as f64;
let latency_factor = upstreams[i].effective_latency(dc_idx)
.map(|ms| if ms > 1.0 { 1000.0 / ms } else { 1000.0 })
.unwrap_or(1.0);
(i, base * latency_factor)
}).collect();
let total: f64 = weights.iter().map(|(_, w)| w).sum();
if total <= 0.0 {
return Some(healthy[rand::rng().gen_range(0..healthy.len())]);
}
let mut choice: f64 = rand::rng().gen_range(0.0..total);
for &(idx, weight) in &weights {
if choice < weight {
trace!(
upstream = idx,
dc = ?dc_idx,
weight = format!("{:.2}", weight),
total = format!("{:.2}", total),
"Upstream selected"
);
return Some(idx);
}
choice -= weight;
}
Some(healthy_indices[0])
Some(healthy[0])
}
pub async fn connect(&self, target: SocketAddr) -> Result<TcpStream> {
let idx = self.select_upstream().await
/// Connect to target through a selected upstream.
pub async fn connect(&self, target: SocketAddr, dc_idx: Option<i16>) -> Result<TcpStream> {
let idx = self.select_upstream(dc_idx).await
.ok_or_else(|| ProxyError::Config("No upstreams available".to_string()))?;
let upstream = {
let guard = self.upstreams.read().await;
guard[idx].config.clone()
};
let start = Instant::now();
match self.connect_via_upstream(&upstream, target).await {
Ok(stream) => {
// Mark success
let rtt_ms = start.elapsed().as_secs_f64() * 1000.0;
let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(idx) {
if !u.healthy {
debug!("Upstream recovered: {:?}", u.config);
debug!(rtt_ms = format!("{:.1}", rtt_ms), "Upstream recovered");
}
u.healthy = true;
u.fails = 0;
if let Some(di) = dc_idx.and_then(UpstreamState::dc_array_idx) {
u.dc_latency[di].update(rtt_ms);
}
}
Ok(stream)
},
Err(e) => {
// Mark failure
let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(idx) {
u.fails += 1;
warn!("Failed to connect via upstream {:?}: {}. Fails: {}", u.config, e, u.fails);
warn!(fails = u.fails, "Upstream failed: {}", e);
if u.fails > 3 {
u.healthy = false;
warn!("Upstream disabled due to failures: {:?}", u.config);
warn!("Upstream marked unhealthy");
}
}
Err(e)
}
}
}
async fn connect_via_upstream(&self, config: &UpstreamConfig, target: SocketAddr) -> Result<TcpStream> {
match &config.upstream_type {
UpstreamType::Direct { interface } => {
let bind_ip = interface.as_ref()
.and_then(|s| s.parse::<IpAddr>().ok());
let socket = create_outgoing_socket_bound(target, bind_ip)?;
// Non-blocking connect logic
socket.set_nonblocking(true)?;
match socket.connect(&target.into()) {
Ok(()) => {},
Err(err) if err.raw_os_error() == Some(115) || err.kind() == std::io::ErrorKind::WouldBlock => {},
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {},
Err(err) => return Err(ProxyError::Io(err)),
}
let std_stream: std::net::TcpStream = socket.into();
let stream = TcpStream::from_std(std_stream)?;
// Wait for connection to complete
stream.writable().await?;
if let Some(e) = stream.take_error()? {
return Err(ProxyError::Io(e));
}
Ok(stream)
},
UpstreamType::Socks4 { address, interface, user_id } => {
info!("Connecting to target {} via SOCKS4 proxy {}", target, address);
let proxy_addr: SocketAddr = address.parse()
.map_err(|_| ProxyError::Config("Invalid SOCKS4 address".to_string()))?;
let bind_ip = interface.as_ref()
.and_then(|s| s.parse::<IpAddr>().ok());
let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?;
// Non-blocking connect logic
socket.set_nonblocking(true)?;
match socket.connect(&proxy_addr.into()) {
Ok(()) => {},
Err(err) if err.raw_os_error() == Some(115) || err.kind() == std::io::ErrorKind::WouldBlock => {},
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {},
Err(err) => return Err(ProxyError::Io(err)),
}
let std_stream: std::net::TcpStream = socket.into();
let mut stream = TcpStream::from_std(std_stream)?;
// Wait for connection to complete
stream.writable().await?;
if let Some(e) = stream.take_error()? {
return Err(ProxyError::Io(e));
}
connect_socks4(&mut stream, target, user_id.as_deref()).await?;
Ok(stream)
},
UpstreamType::Socks5 { address, interface, username, password } => {
info!("Connecting to target {} via SOCKS5 proxy {}", target, address);
let proxy_addr: SocketAddr = address.parse()
.map_err(|_| ProxyError::Config("Invalid SOCKS5 address".to_string()))?;
let bind_ip = interface.as_ref()
.and_then(|s| s.parse::<IpAddr>().ok());
let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?;
// Non-blocking connect logic
socket.set_nonblocking(true)?;
match socket.connect(&proxy_addr.into()) {
Ok(()) => {},
Err(err) if err.raw_os_error() == Some(115) || err.kind() == std::io::ErrorKind::WouldBlock => {},
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {},
Err(err) => return Err(ProxyError::Io(err)),
}
let std_stream: std::net::TcpStream = socket.into();
let mut stream = TcpStream::from_std(std_stream)?;
// Wait for connection to complete
stream.writable().await?;
if let Some(e) = stream.take_error()? {
return Err(ProxyError::Io(e));
}
connect_socks5(&mut stream, target, username.as_deref(), password.as_deref()).await?;
Ok(stream)
},
}
}
/// Background task to check health
pub async fn run_health_checks(&self) {
// Simple TCP connect check to a known stable DC (e.g. 149.154.167.50:443 - DC2)
let check_target: SocketAddr = "149.154.167.50:443".parse().unwrap();
// ============= Startup Ping (test both IPv6 and IPv4) =============
/// Ping all Telegram DCs through all upstreams.
/// Tests BOTH IPv6 and IPv4, returns separate results for each.
pub async fn ping_all_dcs(
&self,
prefer_ipv6: bool,
dc_overrides: &HashMap<String, Vec<String>>,
ipv4_enabled: bool,
ipv6_enabled: bool,
) -> Vec<StartupPingResult> {
let upstreams: Vec<(usize, UpstreamConfig)> = {
let guard = self.upstreams.read().await;
guard.iter().enumerate()
.map(|(i, u)| (i, u.config.clone()))
.collect()
};
let mut all_results = Vec::new();
for (upstream_idx, upstream_config) in &upstreams {
let upstream_name = match &upstream_config.upstream_type {
UpstreamType::Direct { interface } => {
format!("direct{}", interface.as_ref().map(|i| format!(" ({})", i)).unwrap_or_default())
}
UpstreamType::Socks4 { address, .. } => format!("socks4://{}", address),
UpstreamType::Socks5 { address, .. } => format!("socks5://{}", address),
};
let mut v6_results = Vec::with_capacity(NUM_DCS);
if ipv6_enabled {
for dc_zero_idx in 0..NUM_DCS {
let dc_v6 = TG_DATACENTERS_V6[dc_zero_idx];
let addr_v6 = SocketAddr::new(dc_v6, TG_DATACENTER_PORT);
let result = tokio::time::timeout(
Duration::from_secs(DC_PING_TIMEOUT_SECS),
self.ping_single_dc(&upstream_config, addr_v6)
).await;
let ping_result = match result {
Ok(Ok(rtt_ms)) => {
let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(*upstream_idx) {
u.dc_latency[dc_zero_idx].update(rtt_ms);
}
DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr: addr_v6,
rtt_ms: Some(rtt_ms),
error: None,
}
}
Ok(Err(e)) => DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr: addr_v6,
rtt_ms: None,
error: Some(e.to_string()),
},
Err(_) => DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr: addr_v6,
rtt_ms: None,
error: Some("timeout".to_string()),
},
};
v6_results.push(ping_result);
}
} else {
for dc_zero_idx in 0..NUM_DCS {
let dc_v6 = TG_DATACENTERS_V6[dc_zero_idx];
v6_results.push(DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr: SocketAddr::new(dc_v6, TG_DATACENTER_PORT),
rtt_ms: None,
error: Some("ipv6 disabled".to_string()),
});
}
}
let mut v4_results = Vec::with_capacity(NUM_DCS);
if ipv4_enabled {
for dc_zero_idx in 0..NUM_DCS {
let dc_v4 = TG_DATACENTERS_V4[dc_zero_idx];
let addr_v4 = SocketAddr::new(dc_v4, TG_DATACENTER_PORT);
let result = tokio::time::timeout(
Duration::from_secs(DC_PING_TIMEOUT_SECS),
self.ping_single_dc(&upstream_config, addr_v4)
).await;
let ping_result = match result {
Ok(Ok(rtt_ms)) => {
let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(*upstream_idx) {
u.dc_latency[dc_zero_idx].update(rtt_ms);
}
DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr: addr_v4,
rtt_ms: Some(rtt_ms),
error: None,
}
}
Ok(Err(e)) => DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr: addr_v4,
rtt_ms: None,
error: Some(e.to_string()),
},
Err(_) => DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr: addr_v4,
rtt_ms: None,
error: Some("timeout".to_string()),
},
};
v4_results.push(ping_result);
}
} else {
for dc_zero_idx in 0..NUM_DCS {
let dc_v4 = TG_DATACENTERS_V4[dc_zero_idx];
v4_results.push(DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr: SocketAddr::new(dc_v4, TG_DATACENTER_PORT),
rtt_ms: None,
error: Some("ipv4 disabled".to_string()),
});
}
}
// === Ping DC overrides (v4/v6) ===
for (dc_key, addrs) in dc_overrides {
let dc_num: i16 = match dc_key.parse::<i16>() {
Ok(v) if v > 0 => v,
Err(_) => {
warn!(dc = %dc_key, "Invalid dc_overrides key, skipping");
continue;
},
_ => continue,
};
let dc_idx = dc_num as usize;
for addr_str in addrs {
match addr_str.parse::<SocketAddr>() {
Ok(addr) => {
let is_v6 = addr.is_ipv6();
if (is_v6 && !ipv6_enabled) || (!is_v6 && !ipv4_enabled) {
continue;
}
let result = tokio::time::timeout(
Duration::from_secs(DC_PING_TIMEOUT_SECS),
self.ping_single_dc(&upstream_config, addr)
).await;
let ping_result = match result {
Ok(Ok(rtt_ms)) => DcPingResult {
dc_idx,
dc_addr: addr,
rtt_ms: Some(rtt_ms),
error: None,
},
Ok(Err(e)) => DcPingResult {
dc_idx,
dc_addr: addr,
rtt_ms: None,
error: Some(e.to_string()),
},
Err(_) => DcPingResult {
dc_idx,
dc_addr: addr,
rtt_ms: None,
error: Some("timeout".to_string()),
},
};
if is_v6 {
v6_results.push(ping_result);
} else {
v4_results.push(ping_result);
}
}
Err(_) => warn!(dc = %dc_idx, addr = %addr_str, "Invalid dc_overrides address, skipping"),
}
}
}
// Check if both IP versions have at least one working DC
let v6_has_working = v6_results.iter().any(|r| r.rtt_ms.is_some());
let v4_has_working = v4_results.iter().any(|r| r.rtt_ms.is_some());
let both_available = v6_has_working && v4_has_working;
// Update IP preference for each DC
{
let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(*upstream_idx) {
for dc_zero_idx in 0..NUM_DCS {
let v6_ok = v6_results[dc_zero_idx].rtt_ms.is_some();
let v4_ok = v4_results[dc_zero_idx].rtt_ms.is_some();
u.dc_ip_pref[dc_zero_idx] = match (v6_ok, v4_ok) {
(true, true) => IpPreference::BothWork,
(true, false) => IpPreference::PreferV6,
(false, true) => IpPreference::PreferV4,
(false, false) => IpPreference::Unavailable,
};
}
}
}
all_results.push(StartupPingResult {
v6_results,
v4_results,
upstream_name,
both_available,
});
}
all_results
}
async fn ping_single_dc(&self, config: &UpstreamConfig, target: SocketAddr) -> Result<f64> {
let start = Instant::now();
let _stream = self.connect_via_upstream(config, target).await?;
Ok(start.elapsed().as_secs_f64() * 1000.0)
}
// ============= Health Checks =============
/// Background health check: rotates through DCs, 30s interval.
/// Uses preferred IP version based on config.
pub async fn run_health_checks(&self, prefer_ipv6: bool, ipv4_enabled: bool, ipv6_enabled: bool) {
let mut dc_rotation = 0usize;
loop {
tokio::time::sleep(Duration::from_secs(60)).await;
tokio::time::sleep(Duration::from_secs(30)).await;
let dc_zero_idx = dc_rotation % NUM_DCS;
dc_rotation += 1;
let primary_v6 = SocketAddr::new(TG_DATACENTERS_V6[dc_zero_idx], TG_DATACENTER_PORT);
let primary_v4 = SocketAddr::new(TG_DATACENTERS_V4[dc_zero_idx], TG_DATACENTER_PORT);
let dc_addr = if prefer_ipv6 && ipv6_enabled {
primary_v6
} else if ipv4_enabled {
primary_v4
} else if ipv6_enabled {
primary_v6
} else {
continue;
};
let fallback_addr = if dc_addr.is_ipv6() && ipv4_enabled {
Some(primary_v4)
} else if dc_addr.is_ipv4() && ipv6_enabled {
Some(primary_v6)
} else {
None
};
let count = self.upstreams.read().await.len();
for i in 0..count {
let config = {
let guard = self.upstreams.read().await;
guard[i].config.clone()
};
let start = Instant::now();
let result = tokio::time::timeout(
Duration::from_secs(10),
self.connect_via_upstream(&config, check_target)
self.connect_via_upstream(&config, dc_addr)
).await;
let mut guard = self.upstreams.write().await;
let u = &mut guard[i];
match result {
Ok(Ok(_stream)) => {
let rtt_ms = start.elapsed().as_secs_f64() * 1000.0;
let mut guard = self.upstreams.write().await;
let u = &mut guard[i];
u.dc_latency[dc_zero_idx].update(rtt_ms);
if !u.healthy {
debug!("Upstream recovered: {:?}", u.config);
info!(
rtt = format!("{:.0} ms", rtt_ms),
dc = dc_zero_idx + 1,
"Upstream recovered"
);
}
u.healthy = true;
u.fails = 0;
u.last_check = std::time::Instant::now();
}
Ok(Err(e)) => {
debug!("Health check failed for {:?}: {}", u.config, e);
// Don't mark unhealthy immediately in background check
}
Err(_) => {
debug!("Health check timeout for {:?}", u.config);
Ok(Err(_)) | Err(_) => {
// Try fallback
debug!(dc = dc_zero_idx + 1, "Health check failed, trying fallback");
if let Some(fallback_addr) = fallback_addr {
let start2 = Instant::now();
let result2 = tokio::time::timeout(
Duration::from_secs(10),
self.connect_via_upstream(&config, fallback_addr)
).await;
let mut guard = self.upstreams.write().await;
let u = &mut guard[i];
match result2 {
Ok(Ok(_stream)) => {
let rtt_ms = start2.elapsed().as_secs_f64() * 1000.0;
u.dc_latency[dc_zero_idx].update(rtt_ms);
if !u.healthy {
info!(
rtt = format!("{:.0} ms", rtt_ms),
dc = dc_zero_idx + 1,
"Upstream recovered (fallback)"
);
}
u.healthy = true;
u.fails = 0;
}
Ok(Err(e)) => {
u.fails += 1;
debug!(dc = dc_zero_idx + 1, fails = u.fails,
"Health check failed (both): {}", e);
if u.fails > 3 {
u.healthy = false;
warn!("Upstream unhealthy (fails)");
}
}
Err(_) => {
u.fails += 1;
debug!(dc = dc_zero_idx + 1, fails = u.fails,
"Health check timeout (both)");
if u.fails > 3 {
u.healthy = false;
warn!("Upstream unhealthy (timeout)");
}
}
}
u.last_check = std::time::Instant::now();
continue;
}
let mut guard = self.upstreams.write().await;
let u = &mut guard[i];
u.fails += 1;
if u.fails > 3 {
u.healthy = false;
warn!("Upstream unhealthy (no fallback family)");
}
u.last_check = std::time::Instant::now();
}
}
u.last_check = std::time::Instant::now();
}
}
}
}
/// Get the preferred IP for a DC (for use by other components)
pub async fn get_dc_ip_preference(&self, dc_idx: i16) -> Option<IpPreference> {
let guard = self.upstreams.read().await;
if guard.is_empty() {
return None;
}
UpstreamState::dc_array_idx(dc_idx)
.map(|idx| guard[0].dc_ip_pref[idx])
}
/// Get preferred DC address based on config preference
pub async fn get_dc_addr(&self, dc_idx: i16, prefer_ipv6: bool) -> Option<SocketAddr> {
let arr_idx = UpstreamState::dc_array_idx(dc_idx)?;
let ip = if prefer_ipv6 {
TG_DATACENTERS_V6[arr_idx]
} else {
TG_DATACENTERS_V4[arr_idx]
};
Some(SocketAddr::new(ip, TG_DATACENTER_PORT))
}
}

204
tools/dc.py Normal file
View File

@@ -0,0 +1,204 @@
"""Telegram datacenter server checker."""
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from itertools import groupby
from operator import attrgetter
from pathlib import Path
from typing import TYPE_CHECKING
from telethon import TelegramClient
from telethon.tl.functions.help import GetConfigRequest
if TYPE_CHECKING:
from telethon.tl.types import DcOption
API_ID: int = 123456
API_HASH: str = ""
SESSION_NAME: str = "session"
OUTPUT_FILE: Path = Path("telegram_servers.txt")
_CONSOLE_FLAG_MAP: dict[str, str] = {
"IPv6": "IPv6",
"MEDIA-ONLY": "🎬 MEDIA-ONLY",
"CDN": "📦 CDN",
"TCPO": "🔒 TCPO",
"STATIC": "📌 STATIC",
}
@dataclass(frozen=True, slots=True)
class DCServer:
"""Typed representation of a Telegram DC server.
Attributes:
dc_id: Datacenter identifier.
ip: Server IP address.
port: Server port.
flags: Active flag labels (plain, without emoji).
"""
dc_id: int
ip: str
port: int
flags: frozenset[str] = field(default_factory=frozenset)
@classmethod
def from_option(cls, dc: DcOption) -> DCServer:
"""Create from a Telethon DcOption.
Args:
dc: Raw DcOption object.
Returns:
Parsed DCServer instance.
"""
checks: dict[str, bool] = {
"IPv6": dc.ipv6,
"MEDIA-ONLY": dc.media_only,
"CDN": dc.cdn,
"TCPO": dc.tcpo_only,
"STATIC": dc.static,
}
return cls(
dc_id=dc.id,
ip=dc.ip_address,
port=dc.port,
flags=frozenset(k for k, v in checks.items() if v),
)
def flags_display(self, *, emoji: bool = False) -> str:
"""Formatted flags string.
Args:
emoji: Whether to include emoji prefixes.
Returns:
Bracketed flags or '[STANDARD]'.
"""
if not self.flags:
return "[STANDARD]"
labels = sorted(
_CONSOLE_FLAG_MAP[f] if emoji else f for f in self.flags
)
return f"[{', '.join(labels)}]"
class TelegramDCChecker:
"""Fetches and displays Telegram DC configuration.
Attributes:
_client: Telethon client instance.
_servers: Parsed server list.
"""
def __init__(self) -> None:
"""Initialize the checker."""
self._client = TelegramClient(SESSION_NAME, API_ID, API_HASH)
self._servers: list[DCServer] = []
async def run(self) -> None:
"""Connect, fetch config, display and save results."""
print("🔄 Подключаемся к Telegram...") # noqa: T201
try:
await self._client.start()
print("✅ Подключение установлено!\n") # noqa: T201
print("📡 Запрашиваем конфигурацию серверов...") # noqa: T201
config = await self._client(GetConfigRequest())
self._servers = [DCServer.from_option(dc) for dc in config.dc_options]
self._print(config)
self._save(config)
finally:
await self._client.disconnect()
print("\n👋 Отключились от Telegram") # noqa: T201
def _grouped(self) -> dict[int, list[DCServer]]:
"""Group servers by DC ID.
Returns:
Ordered mapping of DC ID to servers.
"""
ordered = sorted(self._servers, key=attrgetter("dc_id"))
return {k: list(g) for k, g in groupby(ordered, key=attrgetter("dc_id"))}
def _print(self, config: object) -> None:
"""Print results to stdout in original format.
Args:
config: Raw Telegram config.
"""
sep = "=" * 80
dash = "-" * 80
total = len(self._servers)
print(f"📊 Получено серверов: {total}\n") # noqa: T201
print(sep) # noqa: T201
for dc_id, servers in self._grouped().items():
print(f"\n🌐 DATACENTER {dc_id} ({len(servers)} серверов)") # noqa: T201
print(dash) # noqa: T201
for s in servers:
print(f" {s.ip:45}:{s.port:5} {s.flags_display(emoji=True)}") # noqa: T201
ipv4 = total - self._flag_count("IPv6")
print(f"\n{sep}") # noqa: T201
print("📈 СТАТИСТИКА:") # noqa: T201
print(sep) # noqa: T201
print(f" Всего серверов: {total}") # noqa: T201
print(f" IPv4 серверы: {ipv4}") # noqa: T201
print(f" IPv6 серверы: {self._flag_count('IPv6')}") # noqa: T201
print(f" Media-only: {self._flag_count('MEDIA-ONLY')}") # noqa: T201
print(f" CDN серверы: {self._flag_count('CDN')}") # noqa: T201
print(f" TCPO-only: {self._flag_count('TCPO')}") # noqa: T201
print(f" Static: {self._flag_count('STATIC')}") # noqa: T201
print(f"\n{sep}") # noqa: T201
print(" ДОПОЛНИТЕЛЬНАЯ ИНФОРМАЦИЯ:") # noqa: T201
print(sep) # noqa: T201
print(f" Дата конфигурации: {config.date}") # noqa: T201 # type: ignore[attr-defined]
print(f" Expires: {config.expires}") # noqa: T201 # type: ignore[attr-defined]
print(f" Test mode: {config.test_mode}") # noqa: T201 # type: ignore[attr-defined]
print(f" This DC: {config.this_dc}") # noqa: T201 # type: ignore[attr-defined]
def _flag_count(self, flag: str) -> int:
"""Count servers with a given flag.
Args:
flag: Flag name.
Returns:
Count of matching servers.
"""
return sum(1 for s in self._servers if flag in s.flags)
def _save(self, config: object) -> None:
"""Save results to file in original format.
Args:
config: Raw Telegram config.
"""
parts: list[str] = []
parts.append("TELEGRAM DATACENTER SERVERS\n")
parts.append("=" * 80 + "\n\n")
for dc_id, servers in self._grouped().items():
parts.append(f"\nDATACENTER {dc_id} ({len(servers)} servers)\n")
parts.append("-" * 80 + "\n")
for s in servers:
parts.append(f" {s.ip}:{s.port} {s.flags_display(emoji=False)}\n")
parts.append(f"\n\nTotal servers: {len(self._servers)}\n")
parts.append(f"Generated: {config.date}\n") # type: ignore[attr-defined]
OUTPUT_FILE.write_text("".join(parts), encoding="utf-8")
print(f"\n💾 Сохраняем результаты в файл {OUTPUT_FILE}...") # noqa: T201
print(f"✅ Результаты сохранены в {OUTPUT_FILE}") # noqa: T201
if __name__ == "__main__":
asyncio.run(TelegramDCChecker().run())

View File

@@ -0,0 +1,804 @@
{
"apiVersion": "dashboard.grafana.app/v1beta1",
"kind": "Dashboard",
"metadata": {
"annotations": {
"grafana.app/folder": "afd9kjusw2jnkb",
"grafana.app/saved-from-ui": "Grafana v12.4.0-21693836646 (f059795f04)"
},
"labels": {},
"name": "pi9trh5",
"namespace": "default"
},
"spec": {
"annotations": {
"list": [
{
"builtIn": 1,
"datasource": {
"type": "prometheus",
"uid": "${datasource}"
},
"enable": true,
"hide": true,
"iconColor": "rgba(0, 211, 255, 1)",
"name": "Annotations & Alerts",
"type": "dashboard"
}
]
},
"editable": true,
"fiscalYearStartMonth": 0,
"graphTooltip": 0,
"links": [],
"panels": [
{
"collapsed": false,
"gridPos": {
"h": 1,
"w": 24,
"x": 0,
"y": 0
},
"id": 5,
"panels": [],
"title": "Common",
"type": "row"
},
{
"datasource": {
"type": "prometheus",
"uid": "${datasource}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "thresholds"
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "red",
"value": 0
},
{
"color": "green",
"value": 300
}
]
},
"unit": "s"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 6,
"x": 0,
"y": 1
},
"id": 1,
"options": {
"colorMode": "value",
"graphMode": "area",
"justifyMode": "auto",
"orientation": "auto",
"percentChangeColorMode": "standard",
"reduceOptions": {
"calcs": [
"lastNotNull"
],
"fields": "",
"values": false
},
"showPercentChange": false,
"textMode": "auto",
"wideLayout": true
},
"pluginVersion": "12.4.0-21693836646",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "max(telemt_uptime_seconds) by (service)",
"format": "time_series",
"legendFormat": "__auto",
"range": true,
"refId": "A"
}
],
"title": "uptime",
"type": "stat"
},
{
"datasource": {
"type": "prometheus",
"uid": "${datasource}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "thresholds"
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": 0
},
{
"color": "red",
"value": 80
}
]
},
"unit": "none"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 6,
"x": 6,
"y": 1
},
"id": 2,
"options": {
"colorMode": "value",
"graphMode": "area",
"justifyMode": "auto",
"orientation": "auto",
"percentChangeColorMode": "standard",
"reduceOptions": {
"calcs": [
"lastNotNull"
],
"fields": "",
"values": false
},
"showPercentChange": false,
"textMode": "auto",
"wideLayout": true
},
"pluginVersion": "12.4.0-21693836646",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "max(telemt_connections_total) by (service)",
"format": "time_series",
"legendFormat": "__auto",
"range": true,
"refId": "A"
}
],
"title": "connections_total",
"type": "stat"
},
{
"datasource": {
"type": "prometheus",
"uid": "${datasource}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "thresholds"
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": 0
},
{
"color": "red",
"value": 80
}
]
},
"unit": "none"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 6,
"x": 12,
"y": 1
},
"id": 3,
"options": {
"colorMode": "value",
"graphMode": "area",
"justifyMode": "auto",
"orientation": "auto",
"percentChangeColorMode": "standard",
"reduceOptions": {
"calcs": [
"lastNotNull"
],
"fields": "",
"values": false
},
"showPercentChange": false,
"textMode": "auto",
"wideLayout": true
},
"pluginVersion": "12.4.0-21693836646",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "max(telemt_connections_bad_total) by (service)",
"format": "time_series",
"legendFormat": "__auto",
"range": true,
"refId": "A"
}
],
"title": "connections_bad",
"type": "stat"
},
{
"datasource": {
"type": "prometheus",
"uid": "${datasource}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "thresholds"
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": 0
},
{
"color": "red",
"value": 80
}
]
},
"unit": "none"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 6,
"x": 18,
"y": 1
},
"id": 4,
"options": {
"colorMode": "value",
"graphMode": "area",
"justifyMode": "auto",
"orientation": "auto",
"percentChangeColorMode": "standard",
"reduceOptions": {
"calcs": [
"lastNotNull"
],
"fields": "",
"values": false
},
"showPercentChange": false,
"textMode": "auto",
"wideLayout": true
},
"pluginVersion": "12.4.0-21693836646",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "max(telemt_handshake_timeouts_total) by (service)",
"format": "time_series",
"legendFormat": "__auto",
"range": true,
"refId": "A"
}
],
"title": "handshake_timeouts",
"type": "stat"
},
{
"collapsed": false,
"gridPos": {
"h": 1,
"w": 24,
"x": 0,
"y": 9
},
"id": 6,
"panels": [],
"repeat": "user",
"title": "$user",
"type": "row"
},
{
"datasource": {
"type": "prometheus",
"uid": "${datasource}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"barWidthFactor": 0.6,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "auto",
"showValues": false,
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": 0
},
{
"color": "red",
"value": 80
}
]
},
"unit": "none"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 0,
"y": 10
},
"id": 7,
"options": {
"legend": {
"calcs": [],
"displayMode": "list",
"placement": "bottom",
"showLegend": false
},
"tooltip": {
"hideZeros": false,
"mode": "single",
"sort": "none"
}
},
"pluginVersion": "12.4.0-21693836646",
"targets": [
{
"editorMode": "code",
"expr": "sum(telemt_user_connections_total{user=\"$user\"}) by (user)",
"format": "time_series",
"legendFormat": "{{ user }}",
"range": true,
"refId": "A"
}
],
"title": "user_connections",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "${datasource}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"barWidthFactor": 0.6,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "auto",
"showValues": false,
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": 0
},
{
"color": "red",
"value": 80
}
]
},
"unit": "none"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
"y": 10
},
"id": 8,
"options": {
"legend": {
"calcs": [],
"displayMode": "list",
"placement": "bottom",
"showLegend": false
},
"tooltip": {
"hideZeros": false,
"mode": "single",
"sort": "none"
}
},
"pluginVersion": "12.4.0-21693836646",
"targets": [
{
"editorMode": "code",
"expr": "sum(telemt_user_connections_current{user=\"$user\"}) by (user)",
"format": "time_series",
"legendFormat": "{{ user }}",
"range": true,
"refId": "A"
}
],
"title": "user_connections_current",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "${datasource}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"barWidthFactor": 0.6,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "auto",
"showValues": false,
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": 0
},
{
"color": "red",
"value": 80
}
]
},
"unit": "binBps"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 0,
"y": 18
},
"id": 9,
"options": {
"legend": {
"calcs": [],
"displayMode": "list",
"placement": "bottom",
"showLegend": false
},
"tooltip": {
"hideZeros": false,
"mode": "single",
"sort": "none"
}
},
"pluginVersion": "12.4.0-21693836646",
"targets": [
{
"editorMode": "code",
"expr": "- sum(rate(telemt_user_octets_from_client{user=\"$user\"}[$__rate_interval])) by (user)",
"format": "time_series",
"legendFormat": "{{ user }} TX",
"range": true,
"refId": "A"
},
{
"datasource": {
"type": "prometheus",
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "sum(rate(telemt_user_octets_to_client{user=\"$user\"}[$__rate_interval])) by (user)",
"format": "time_series",
"legendFormat": "{{ user }} RX",
"range": true,
"refId": "B"
}
],
"title": "user_octets",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "${datasource}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"barWidthFactor": 0.6,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "auto",
"showValues": false,
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": 0
},
{
"color": "red",
"value": 80
}
]
},
"unit": "pps"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
"y": 18
},
"id": 10,
"options": {
"legend": {
"calcs": [],
"displayMode": "list",
"placement": "bottom",
"showLegend": false
},
"tooltip": {
"hideZeros": false,
"mode": "single",
"sort": "none"
}
},
"pluginVersion": "12.4.0-21693836646",
"targets": [
{
"editorMode": "code",
"expr": "- sum(rate(telemt_user_msgs_from_client{user=\"$user\"}[$__rate_interval])) by (user)",
"format": "time_series",
"legendFormat": "{{ user }} TX",
"range": true,
"refId": "A"
},
{
"datasource": {
"type": "prometheus",
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "sum(rate(telemt_user_msgs_to_client{user=\"$user\"}[$__rate_interval])) by (user)",
"format": "time_series",
"legendFormat": "{{ user }} RX",
"range": true,
"refId": "B"
}
],
"title": "user_msgs",
"type": "timeseries"
}
],
"preload": false,
"schemaVersion": 42,
"tags": [],
"templating": {
"list": [
{
"current": {
"text": "docker",
"value": "docker"
},
"datasource": {
"type": "prometheus",
"uid": "${datasource}"
},
"definition": "label_values(telemt_user_connections_total,user)",
"hide": 2,
"multi": true,
"name": "user",
"options": [],
"query": {
"qryType": 1,
"query": "label_values(telemt_user_connections_total,user)",
"refId": "VariableQueryEditor-VariableQuery"
},
"refresh": 1,
"regex": "",
"regexApplyTo": "value",
"sort": 1,
"type": "query"
},
{
"current": {
"text": "VM long-term",
"value": "P7D3016A027385E71"
},
"name": "datasource",
"options": [],
"query": "prometheus",
"refresh": 1,
"regex": "",
"type": "datasource"
}
]
},
"time": {
"from": "now-6h",
"to": "now"
},
"timepicker": {},
"timezone": "browser",
"title": "Telemt MtProto proxy",
"weekStart": ""
}
}