Compare commits
22 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4fa6867056 | ||
|
|
54ea6efdd0 | ||
|
|
27ac32a901 | ||
|
|
829f53c123 | ||
|
|
43eae6127d | ||
|
|
a03212c8cc | ||
|
|
2613969a7c | ||
|
|
be1b2db867 | ||
|
|
8fbee8701b | ||
|
|
952d160870 | ||
|
|
91ae6becde | ||
|
|
e1f576e4fe | ||
|
|
a7556cabdc | ||
|
|
b2e8d16bb1 | ||
|
|
d95e762812 | ||
|
|
384f927fc3 | ||
|
|
1b7c09ae18 | ||
|
|
85cb4092d5 | ||
|
|
5016160ac3 | ||
|
|
4f007f3128 | ||
|
|
7746a1177c | ||
|
|
2bb2a2983f |
4
.github/workflows/rust.yml
vendored
4
.github/workflows/rust.yml
vendored
@@ -10,8 +10,8 @@ env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
build-and-test:
|
||||
name: Build & Test
|
||||
build:
|
||||
name: Build
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
|
||||
175
README.md
175
README.md
@@ -5,13 +5,21 @@
|
||||
# GOTO
|
||||
- [Features](#features)
|
||||
- [Quick Start Guide](#quick-start-guide)
|
||||
- [Build](#build)
|
||||
- [How to use?](#how-to-use)
|
||||
- [Systemd Method](#telemt-via-systemd)
|
||||
- [Configuration](#configuration)
|
||||
- [Minimal Configuration](#minimal-configuration-for-first-start)
|
||||
- [Advanced](#advanced)
|
||||
- [Adtag](#adtag)
|
||||
- [Listening and Announce IPs](#listening-and-announce-ips)
|
||||
- [Upstream Manager](#upstream-manager)
|
||||
- [IP](#bind-on-ip)
|
||||
- [SOCKS](#socks45-as-upstream)
|
||||
- [FAQ](#faq)
|
||||
- [Telegram Calls](#telegram-calls-via-mtproxy)
|
||||
- [DPI](#how-does-dpi-see-mtproxy-tls)
|
||||
- [Whitelist on Network Level](#whitelist-on-ip)
|
||||
- [Build](#build)
|
||||
- [Why Rust?](#why-rust)
|
||||
|
||||
## Features
|
||||
@@ -27,25 +35,27 @@
|
||||
- Extensive logging via `trace` and `debug` with `RUST_LOG` method
|
||||
|
||||
## Quick Start Guide
|
||||
|
||||
### Build
|
||||
**This software is designed for Debian-based OS: in addition to Debian, these are Ubuntu, Mint, Kali, MX and many other Linux**
|
||||
1. Download release
|
||||
```bash
|
||||
# Cloning repo
|
||||
git clone https://github.com/telemt/telemt
|
||||
# Changing Directory to telemt
|
||||
cd telemt
|
||||
# Starting Release Build
|
||||
cargo build --release
|
||||
# Move to /bin
|
||||
mv ./target/release/telemt /bin
|
||||
# Make executable
|
||||
chmod +x /bin/telemt
|
||||
# Lets go!
|
||||
telemt config.toml
|
||||
wget https://github.com/telemt/telemt/releases/latest/download/telemt
|
||||
```
|
||||
2. Move to Bin Folder
|
||||
```bash
|
||||
mv telemt /bin
|
||||
```
|
||||
4. Make Executable
|
||||
```bash
|
||||
chmod +x /bin/telemt
|
||||
```
|
||||
5. Go to [How to use?](#how-to-use) section for for further steps
|
||||
|
||||
## How to use?
|
||||
### Telemt via Systemd
|
||||
**This instruction "assume" that you:**
|
||||
- logged in as root or executed `su -` / `sudo su`
|
||||
- you already have an assembled and executable `telemt` in /bin folder as a result of the [Quick Start Guide](#quick-start-guide) or [Build](#build)
|
||||
|
||||
**0. Check port and generate secrets**
|
||||
|
||||
The port you have selected for use should be MISSING from the list, when:
|
||||
@@ -57,6 +67,14 @@ Generate 16 bytes/32 characters HEX with OpenSSL or another way:
|
||||
```bash
|
||||
openssl rand -hex 16
|
||||
```
|
||||
OR
|
||||
```bash
|
||||
xxd -l 16 -p /dev/urandom
|
||||
```
|
||||
OR
|
||||
```bash
|
||||
python3 -c 'import os; print(os.urandom(16).hex())'
|
||||
```
|
||||
|
||||
**1. Place your config to /etc/telemt.toml**
|
||||
|
||||
@@ -64,28 +82,8 @@ Open nano
|
||||
```bash
|
||||
nano /etc/telemt.toml
|
||||
```
|
||||
```bash
|
||||
port = 443 # Listening port
|
||||
paste your config from [Configuration](#configuration) section
|
||||
|
||||
[users]
|
||||
hello = "00000000000000000000000000000000" # Replace the secret with one generated before
|
||||
|
||||
[modes]
|
||||
classic = false # Plain obfuscated mode
|
||||
secure = false # dd-prefix mode
|
||||
tls = true # Fake TLS - ee-prefix
|
||||
|
||||
tls_domain = "petrovich.ru" # Domain for ee-secret and masking
|
||||
mask = true # Enable masking of bad traffic
|
||||
mask_host = "petrovich.ru" # Optional override for mask destination
|
||||
mask_port = 443 # Port for masking
|
||||
|
||||
prefer_ipv6 = false # Try IPv6 DCs first if true
|
||||
fast_mode = true # Use "fast" obfuscation variant
|
||||
|
||||
client_keepalive = 600 # Seconds
|
||||
client_ack_timeout = 300 # Seconds
|
||||
```
|
||||
then Ctrl+X -> Y -> Enter to save
|
||||
|
||||
**2. Create service on /etc/systemd/system/telemt.service**
|
||||
@@ -117,9 +115,79 @@ then Ctrl+X -> Y -> Enter to save
|
||||
|
||||
**5.** In Shell type `systemctl enable telemt` - then telemt will start with system startup, after the network is up
|
||||
|
||||
## Configuration
|
||||
### Minimal Configuration for First Start
|
||||
```toml
|
||||
port = 443 # Listening port
|
||||
show_links = ["tele", "hello"] # Specify users, for whom will be displayed the links
|
||||
|
||||
[users]
|
||||
tele = "00000000000000000000000000000000" # Replace the secret with one generated before
|
||||
hello = "00000000000000000000000000000000" # Replace the secret with one generated before
|
||||
|
||||
[modes]
|
||||
classic = false # Plain obfuscated mode
|
||||
secure = false # dd-prefix mode
|
||||
tls = true # Fake TLS - ee-prefix
|
||||
|
||||
tls_domain = "petrovich.ru" # Domain for ee-secret and masking
|
||||
mask = true # Enable masking of bad traffic
|
||||
mask_host = "petrovich.ru" # Optional override for mask destination
|
||||
mask_port = 443 # Port for masking
|
||||
|
||||
prefer_ipv6 = false # Try IPv6 DCs first if true
|
||||
fast_mode = true # Use "fast" obfuscation variant
|
||||
|
||||
client_keepalive = 600 # Seconds
|
||||
client_ack_timeout = 300 # Seconds
|
||||
```
|
||||
### Advanced
|
||||
#### Adtag
|
||||
To use channel advertising and usage statistics from Telegram, get Adtag from [@mtproxybot](https://t.me/mtproxybot), add this parameter to the end of config.toml and specify it
|
||||
```toml
|
||||
ad_tag = "00000000000000000000000000000000" # Replace zeros to your adtag from @mtproxybot
|
||||
```
|
||||
#### Listening and Announce IPs
|
||||
To specify listening address and/or address in links, add to the end of config.toml:
|
||||
```toml
|
||||
[[listeners]]
|
||||
ip = "0.0.0.0" # 0.0.0.0 = all IPs; your IP = specific listening
|
||||
announce_ip = "1.2.3.4" # IP in links; comment with # if not used
|
||||
```
|
||||
#### Upstream Manager
|
||||
To specify upstream, add to the end of config.toml:
|
||||
##### Bind on IP
|
||||
```toml
|
||||
[[upstreams]]
|
||||
type = "direct"
|
||||
weight = 1
|
||||
enabled = true
|
||||
interface = "192.168.1.100" # Change to your outgoing IP
|
||||
```
|
||||
##### SOCKS4/5 as Upstream
|
||||
- Without Auth:
|
||||
```toml
|
||||
[[upstreams]]
|
||||
type = "socks5" # Specify SOCKS4 or SOCKS5
|
||||
address = "1.2.3.4:1234" # SOCKS-server Address
|
||||
weight = 1 # Set Weight for Scenarios
|
||||
enabled = true
|
||||
```
|
||||
|
||||
- With Auth:
|
||||
```toml
|
||||
[[upstreams]]
|
||||
type = "socks5" # Specify SOCKS4 or SOCKS5
|
||||
address = "1.2.3.4:1234" # SOCKS-server Address
|
||||
username = "user" # Username for Auth on SOCKS-server
|
||||
password = "pass" # Password for Auth on SOCKS-server
|
||||
weight = 1 # Set Weight for Scenarios
|
||||
enabled = true
|
||||
```
|
||||
|
||||
## FAQ
|
||||
### Telegram Calls via MTProxy
|
||||
- Telegram architecture does **NOT allow calls via MTProxy**, but only via SOCKS5, which cannot be obfuscated
|
||||
- Telegram architecture **does NOT allow calls via MTProxy**, but only via SOCKS5, which cannot be obfuscated
|
||||
### How does DPI see MTProxy TLS?
|
||||
- DPI sees MTProxy in Fake TLS (ee) mode as TLS 1.3
|
||||
- the SNI you specify sends both the client and the server;
|
||||
@@ -127,11 +195,34 @@ then Ctrl+X -> Y -> Enter to save
|
||||
- high entropy, which is normal for AES-encrypted traffic;
|
||||
### Whitelist on IP
|
||||
- MTProxy cannot work when there is:
|
||||
- no IP connectivity to the target host
|
||||
- no IP connectivity to the target host: Russian Whitelist on Mobile Networks - "Белый список"
|
||||
- OR all TCP traffic is blocked
|
||||
- OR all TLS traffic is blocked,
|
||||
- OR high entropy/encrypted traffic is blocked: content filters at universities and critical infrastructure
|
||||
- OR all TLS traffic is blocked
|
||||
- OR specified port is blocked: use 443 to make it "like real"
|
||||
- OR provided SNI is blocked: use "officially approved"/innocuous name
|
||||
- like most protocols on the Internet;
|
||||
- this situation is observed in China behind the Great Chinese Firewall and in Russia on mobile networks
|
||||
- these situations are observed:
|
||||
- in China behind the Great Firewall
|
||||
- in Russia on mobile networks, less in wired networks
|
||||
- in Iran during "activity"
|
||||
|
||||
|
||||
## Build
|
||||
```bash
|
||||
# Cloning repo
|
||||
git clone https://github.com/telemt/telemt
|
||||
# Changing Directory to telemt
|
||||
cd telemt
|
||||
# Starting Release Build
|
||||
cargo build --release
|
||||
# Move to /bin
|
||||
mv ./target/release/telemt /bin
|
||||
# Make executable
|
||||
chmod +x /bin/telemt
|
||||
# Lets go!
|
||||
telemt config.toml
|
||||
```
|
||||
|
||||
## Why Rust?
|
||||
- Long-running reliability and idempotent behavior
|
||||
@@ -140,6 +231,10 @@ then Ctrl+X -> Y -> Enter to save
|
||||
- Memory safety and reduced attack surface
|
||||
- Tokio's asynchronous architecture
|
||||
|
||||
## Issues
|
||||
- ✅ [SOCKS5 as Upstream](https://github.com/telemt/telemt/issues/1) -> added Upstream Management
|
||||
- ⌛ [iOS - Media Upload Hanging-in-Loop](https://github.com/telemt/telemt/issues/2)
|
||||
|
||||
## Roadmap
|
||||
- Public IP in links
|
||||
- Config Reload-on-fly
|
||||
|
||||
@@ -18,6 +18,7 @@ pub struct ProxyModes {
|
||||
}
|
||||
|
||||
fn default_true() -> bool { true }
|
||||
fn default_weight() -> u16 { 1 }
|
||||
|
||||
impl Default for ProxyModes {
|
||||
fn default() -> Self {
|
||||
@@ -25,6 +26,48 @@ impl Default for ProxyModes {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(tag = "type", rename_all = "lowercase")]
|
||||
pub enum UpstreamType {
|
||||
Direct {
|
||||
#[serde(default)]
|
||||
interface: Option<String>, // Bind to specific IP/Interface
|
||||
},
|
||||
Socks4 {
|
||||
address: String, // IP:Port of SOCKS server
|
||||
#[serde(default)]
|
||||
interface: Option<String>, // Bind to specific IP/Interface for connection to SOCKS
|
||||
#[serde(default)]
|
||||
user_id: Option<String>,
|
||||
},
|
||||
Socks5 {
|
||||
address: String,
|
||||
#[serde(default)]
|
||||
interface: Option<String>,
|
||||
#[serde(default)]
|
||||
username: Option<String>,
|
||||
#[serde(default)]
|
||||
password: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct UpstreamConfig {
|
||||
#[serde(flatten)]
|
||||
pub upstream_type: UpstreamType,
|
||||
#[serde(default = "default_weight")]
|
||||
pub weight: u16,
|
||||
#[serde(default = "default_true")]
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ListenerConfig {
|
||||
pub ip: IpAddr,
|
||||
#[serde(default)]
|
||||
pub announce_ip: Option<IpAddr>, // IP to show in tg:// links
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ProxyConfig {
|
||||
#[serde(default = "default_port")]
|
||||
@@ -104,15 +147,28 @@ pub struct ProxyConfig {
|
||||
|
||||
#[serde(default = "default_fake_cert_len")]
|
||||
pub fake_cert_len: usize,
|
||||
|
||||
// New fields
|
||||
#[serde(default)]
|
||||
pub upstreams: Vec<UpstreamConfig>,
|
||||
|
||||
#[serde(default)]
|
||||
pub listeners: Vec<ListenerConfig>,
|
||||
|
||||
#[serde(default)]
|
||||
pub show_link: Vec<String>,
|
||||
}
|
||||
|
||||
fn default_port() -> u16 { 443 }
|
||||
fn default_tls_domain() -> String { "www.google.com".to_string() }
|
||||
fn default_mask_port() -> u16 { 443 }
|
||||
fn default_replay_check_len() -> usize { 65536 }
|
||||
fn default_handshake_timeout() -> u64 { 10 }
|
||||
// CHANGED: Increased handshake timeout for bad mobile networks
|
||||
fn default_handshake_timeout() -> u64 { 15 }
|
||||
fn default_connect_timeout() -> u64 { 10 }
|
||||
fn default_keepalive() -> u64 { 600 }
|
||||
// CHANGED: Reduced keepalive from 600s to 60s.
|
||||
// Mobile NATs often drop idle connections after 60-120s.
|
||||
fn default_keepalive() -> u64 { 60 }
|
||||
fn default_ack_timeout() -> u64 { 300 }
|
||||
fn default_listen_addr() -> String { "0.0.0.0".to_string() }
|
||||
fn default_fake_cert_len() -> usize { 2048 }
|
||||
@@ -156,6 +212,9 @@ impl Default for ProxyConfig {
|
||||
metrics_port: None,
|
||||
metrics_whitelist: default_metrics_whitelist(),
|
||||
fake_cert_len: default_fake_cert_len(),
|
||||
upstreams: Vec::new(),
|
||||
listeners: Vec::new(),
|
||||
show_link: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -187,6 +246,33 @@ impl ProxyConfig {
|
||||
use rand::Rng;
|
||||
config.fake_cert_len = rand::thread_rng().gen_range(1024..4096);
|
||||
|
||||
// Migration: Populate listeners if empty
|
||||
if config.listeners.is_empty() {
|
||||
if let Ok(ipv4) = config.listen_addr_ipv4.parse::<IpAddr>() {
|
||||
config.listeners.push(ListenerConfig {
|
||||
ip: ipv4,
|
||||
announce_ip: None,
|
||||
});
|
||||
}
|
||||
if let Some(ipv6_str) = &config.listen_addr_ipv6 {
|
||||
if let Ok(ipv6) = ipv6_str.parse::<IpAddr>() {
|
||||
config.listeners.push(ListenerConfig {
|
||||
ip: ipv6,
|
||||
announce_ip: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Migration: Populate upstreams if empty (Default Direct)
|
||||
if config.upstreams.is_empty() {
|
||||
config.upstreams.push(UpstreamConfig {
|
||||
upstream_type: UpstreamType::Direct { interface: None },
|
||||
weight: 1,
|
||||
enabled: true,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
@@ -202,26 +288,3 @@ impl ProxyConfig {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_default_config() {
|
||||
let config = ProxyConfig::default();
|
||||
assert_eq!(config.port, 443);
|
||||
assert!(config.modes.tls);
|
||||
assert_eq!(config.client_keepalive, 600);
|
||||
assert_eq!(config.client_ack_timeout, 300);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_validate() {
|
||||
let mut config = ProxyConfig::default();
|
||||
assert!(config.validate().is_ok());
|
||||
|
||||
config.users.clear();
|
||||
assert!(config.validate().is_err());
|
||||
}
|
||||
}
|
||||
@@ -235,6 +235,9 @@ pub enum ProxyError {
|
||||
#[error("Invalid proxy protocol header")]
|
||||
InvalidProxyProtocol,
|
||||
|
||||
#[error("Proxy error: {0}")]
|
||||
Proxy(String),
|
||||
|
||||
// ============= Config Errors =============
|
||||
|
||||
#[error("Config error: {0}")]
|
||||
|
||||
302
src/main.rs
302
src/main.rs
@@ -1,158 +1,196 @@
|
||||
//! Telemt - MTProxy on Rust
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::signal;
|
||||
use tracing::{info, error, Level};
|
||||
use tracing_subscriber::{FmtSubscriber, EnvFilter};
|
||||
use tracing::{info, error, warn};
|
||||
use tracing_subscriber::{fmt, EnvFilter};
|
||||
|
||||
mod error;
|
||||
mod config;
|
||||
mod crypto;
|
||||
mod error;
|
||||
mod protocol;
|
||||
mod proxy;
|
||||
mod stats;
|
||||
mod stream;
|
||||
mod transport;
|
||||
mod proxy;
|
||||
mod config;
|
||||
mod stats;
|
||||
mod util;
|
||||
|
||||
use config::ProxyConfig;
|
||||
use stats::{Stats, ReplayChecker};
|
||||
use transport::ConnectionPool;
|
||||
use proxy::ClientHandler;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::proxy::ClientHandler;
|
||||
use crate::stats::{Stats, ReplayChecker};
|
||||
use crate::transport::{create_listener, ListenOptions, UpstreamManager};
|
||||
use crate::util::ip::detect_ip;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
||||
// Initialize logging with env filter
|
||||
// Use RUST_LOG=debug or RUST_LOG=trace for more details
|
||||
let filter = EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| EnvFilter::new("info"));
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Initialize logging
|
||||
fmt()
|
||||
.with_env_filter(EnvFilter::from_default_env().add_directive("info".parse().unwrap()))
|
||||
.init();
|
||||
|
||||
let subscriber = FmtSubscriber::builder()
|
||||
.with_env_filter(filter)
|
||||
.with_target(true)
|
||||
.with_thread_ids(false)
|
||||
.with_file(false)
|
||||
.with_line_number(false)
|
||||
.finish();
|
||||
|
||||
tracing::subscriber::set_global_default(subscriber)?;
|
||||
|
||||
// Load configuration
|
||||
let config_path = std::env::args()
|
||||
.nth(1)
|
||||
.unwrap_or_else(|| "config.toml".to_string());
|
||||
|
||||
info!("Loading configuration from {}", config_path);
|
||||
|
||||
let config = ProxyConfig::load(&config_path).unwrap_or_else(|e| {
|
||||
error!("Failed to load config: {}", e);
|
||||
info!("Using default configuration");
|
||||
ProxyConfig::default()
|
||||
});
|
||||
|
||||
if let Err(e) = config.validate() {
|
||||
error!("Invalid configuration: {}", e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let config = Arc::new(config);
|
||||
|
||||
info!("Starting MTProto Proxy on port {}", config.port);
|
||||
info!("Fast mode: {}", config.fast_mode);
|
||||
info!("Modes: classic={}, secure={}, tls={}",
|
||||
config.modes.classic, config.modes.secure, config.modes.tls);
|
||||
|
||||
// Initialize components
|
||||
let stats = Arc::new(Stats::new());
|
||||
let replay_checker = Arc::new(ReplayChecker::new(config.replay_check_len));
|
||||
let pool = Arc::new(ConnectionPool::new());
|
||||
|
||||
// Create handler
|
||||
let handler = Arc::new(ClientHandler::new(
|
||||
Arc::clone(&config),
|
||||
Arc::clone(&stats),
|
||||
Arc::clone(&replay_checker),
|
||||
Arc::clone(&pool),
|
||||
));
|
||||
|
||||
// Start listener
|
||||
let addr: SocketAddr = format!("{}:{}", config.listen_addr_ipv4, config.port)
|
||||
.parse()?;
|
||||
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
info!("Listening on {}", addr);
|
||||
|
||||
// Print proxy links
|
||||
print_proxy_links(&config);
|
||||
|
||||
info!("Use RUST_LOG=debug or RUST_LOG=trace for more detailed logging");
|
||||
|
||||
// Main accept loop
|
||||
let accept_loop = async {
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok((stream, peer)) => {
|
||||
let handler = Arc::clone(&handler);
|
||||
tokio::spawn(async move {
|
||||
handler.handle(stream, peer).await;
|
||||
});
|
||||
}
|
||||
// Load config
|
||||
let config_path = std::env::args().nth(1).unwrap_or_else(|| "config.toml".to_string());
|
||||
let config = match ProxyConfig::load(&config_path) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Accept error: {}", e);
|
||||
}
|
||||
// If config doesn't exist, try to create default
|
||||
if std::path::Path::new(&config_path).exists() {
|
||||
error!("Failed to load config: {}", e);
|
||||
std::process::exit(1);
|
||||
} else {
|
||||
let default = ProxyConfig::default();
|
||||
let toml = toml::to_string_pretty(&default).unwrap();
|
||||
std::fs::write(&config_path, toml).unwrap();
|
||||
info!("Created default config at {}", config_path);
|
||||
default
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Graceful shutdown
|
||||
tokio::select! {
|
||||
_ = accept_loop => {}
|
||||
_ = signal::ctrl_c() => {
|
||||
info!("Shutting down...");
|
||||
config.validate()?;
|
||||
|
||||
let config = Arc::new(config);
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
// CHANGED: Initialize global ReplayChecker here instead of per-connection
|
||||
let replay_checker = Arc::new(ReplayChecker::new(config.replay_check_len));
|
||||
|
||||
// Initialize Upstream Manager
|
||||
let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone()));
|
||||
|
||||
// Start Health Checks
|
||||
let um_clone = upstream_manager.clone();
|
||||
tokio::spawn(async move {
|
||||
um_clone.run_health_checks().await;
|
||||
});
|
||||
|
||||
// Detect public IP if needed (once at startup)
|
||||
let detected_ip = detect_ip().await;
|
||||
|
||||
// Start Listeners
|
||||
let mut listeners = Vec::new();
|
||||
|
||||
for listener_conf in &config.listeners {
|
||||
let addr = SocketAddr::new(listener_conf.ip, config.port);
|
||||
let options = ListenOptions {
|
||||
ipv6_only: listener_conf.ip.is_ipv6(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
match create_listener(addr, &options) {
|
||||
Ok(socket) => {
|
||||
let listener = TcpListener::from_std(socket.into())?;
|
||||
info!("Listening on {}", addr);
|
||||
|
||||
// Determine public IP for tg:// links
|
||||
// 1. Use explicit announce_ip if set
|
||||
// 2. If listening on 0.0.0.0 or ::, use detected public IP
|
||||
// 3. Otherwise use the bind IP
|
||||
let public_ip = if let Some(ip) = listener_conf.announce_ip {
|
||||
ip
|
||||
} else if listener_conf.ip.is_unspecified() {
|
||||
// Try to use detected IP of the same family
|
||||
if listener_conf.ip.is_ipv4() {
|
||||
detected_ip.ipv4.unwrap_or(listener_conf.ip)
|
||||
} else {
|
||||
detected_ip.ipv6.unwrap_or(listener_conf.ip)
|
||||
}
|
||||
} else {
|
||||
listener_conf.ip
|
||||
};
|
||||
|
||||
// Show links for configured users
|
||||
if !config.show_link.is_empty() {
|
||||
info!("--- Proxy Links for {} ---", public_ip);
|
||||
for user_name in &config.show_link {
|
||||
if let Some(secret) = config.users.get(user_name) {
|
||||
info!("User: {}", user_name);
|
||||
|
||||
// Classic
|
||||
if config.modes.classic {
|
||||
info!(" Classic: tg://proxy?server={}&port={}&secret={}",
|
||||
public_ip, config.port, secret);
|
||||
}
|
||||
|
||||
// DD (Secure)
|
||||
if config.modes.secure {
|
||||
info!(" DD: tg://proxy?server={}&port={}&secret=dd{}",
|
||||
public_ip, config.port, secret);
|
||||
}
|
||||
|
||||
// EE-TLS (FakeTLS)
|
||||
if config.modes.tls {
|
||||
let domain_hex = hex::encode(&config.tls_domain);
|
||||
info!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}",
|
||||
public_ip, config.port, secret, domain_hex);
|
||||
}
|
||||
} else {
|
||||
warn!("User '{}' specified in show_link not found in users list", user_name);
|
||||
}
|
||||
}
|
||||
info!("-----------------------------------");
|
||||
}
|
||||
|
||||
listeners.push(listener);
|
||||
},
|
||||
Err(e) => {
|
||||
error!("Failed to bind to {}: {}", addr, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
pool.close_all().await;
|
||||
if listeners.is_empty() {
|
||||
error!("No listeners could be started. Exiting.");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// Accept loop
|
||||
for listener in listeners {
|
||||
let config = config.clone();
|
||||
let stats = stats.clone();
|
||||
let upstream_manager = upstream_manager.clone();
|
||||
let replay_checker = replay_checker.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok((stream, peer_addr)) => {
|
||||
let config = config.clone();
|
||||
let stats = stats.clone();
|
||||
let upstream_manager = upstream_manager.clone();
|
||||
let replay_checker = replay_checker.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = ClientHandler::new(
|
||||
stream,
|
||||
peer_addr,
|
||||
config,
|
||||
stats,
|
||||
upstream_manager,
|
||||
replay_checker // Pass global checker
|
||||
).run().await {
|
||||
// Log only relevant errors
|
||||
// debug!("Connection error: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Accept error: {}", e);
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Wait for signal
|
||||
match signal::ctrl_c().await {
|
||||
Ok(()) => info!("Shutting down..."),
|
||||
Err(e) => error!("Signal error: {}", e),
|
||||
}
|
||||
|
||||
info!("Goodbye!");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn print_proxy_links(config: &ProxyConfig) {
|
||||
println!("\n=== Proxy Links ===\n");
|
||||
|
||||
for (user, secret) in &config.users {
|
||||
if config.modes.tls {
|
||||
let tls_secret = format!(
|
||||
"ee{}{}",
|
||||
secret,
|
||||
hex::encode(config.tls_domain.as_bytes())
|
||||
);
|
||||
println!(
|
||||
"{} (TLS): tg://proxy?server=IP&port={}&secret={}",
|
||||
user, config.port, tls_secret
|
||||
);
|
||||
}
|
||||
|
||||
if config.modes.secure {
|
||||
println!(
|
||||
"{} (Secure): tg://proxy?server=IP&port={}&secret=dd{}",
|
||||
user, config.port, secret
|
||||
);
|
||||
}
|
||||
|
||||
if config.modes.classic {
|
||||
println!(
|
||||
"{} (Classic): tg://proxy?server=IP&port={}&secret={}",
|
||||
user, config.port, secret
|
||||
);
|
||||
}
|
||||
|
||||
println!();
|
||||
}
|
||||
|
||||
println!("===================\n");
|
||||
}
|
||||
@@ -167,7 +167,10 @@ pub const DEFAULT_ACK_TIMEOUT_SECS: u64 = 300;
|
||||
// ============= Buffer Sizes =============
|
||||
|
||||
/// Default buffer size
|
||||
pub const DEFAULT_BUFFER_SIZE: usize = 65536;
|
||||
/// CHANGED: Reduced from 64KB to 16KB to match TLS record size and align with
|
||||
/// the new buffering strategy for better iOS upload performance.
|
||||
pub const DEFAULT_BUFFER_SIZE: usize = 16384;
|
||||
|
||||
/// Small buffer size for bad client handling
|
||||
pub const SMALL_BUFFER_SIZE: usize = 8192;
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ use crate::error::{ProxyError, Result, HandshakeResult};
|
||||
use crate::protocol::constants::*;
|
||||
use crate::protocol::tls;
|
||||
use crate::stats::{Stats, ReplayChecker};
|
||||
use crate::transport::{ConnectionPool, configure_client_socket};
|
||||
use crate::transport::{configure_client_socket, UpstreamManager};
|
||||
use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter};
|
||||
use crate::crypto::AesCtr;
|
||||
|
||||
@@ -24,39 +24,54 @@ use super::handshake::{
|
||||
use super::relay::relay_bidirectional;
|
||||
use super::masking::handle_bad_client;
|
||||
|
||||
/// Client connection handler
|
||||
pub struct ClientHandler {
|
||||
/// Client connection handler (builder struct)
|
||||
pub struct ClientHandler;
|
||||
|
||||
/// Running client handler with stream and context
|
||||
pub struct RunningClientHandler {
|
||||
stream: TcpStream,
|
||||
peer: SocketAddr,
|
||||
config: Arc<ProxyConfig>,
|
||||
stats: Arc<Stats>,
|
||||
replay_checker: Arc<ReplayChecker>,
|
||||
pool: Arc<ConnectionPool>,
|
||||
upstream_manager: Arc<UpstreamManager>,
|
||||
}
|
||||
|
||||
impl ClientHandler {
|
||||
/// Create new client handler
|
||||
/// Create new client handler instance
|
||||
pub fn new(
|
||||
stream: TcpStream,
|
||||
peer: SocketAddr,
|
||||
config: Arc<ProxyConfig>,
|
||||
stats: Arc<Stats>,
|
||||
replay_checker: Arc<ReplayChecker>,
|
||||
pool: Arc<ConnectionPool>,
|
||||
) -> Self {
|
||||
Self {
|
||||
upstream_manager: Arc<UpstreamManager>,
|
||||
replay_checker: Arc<ReplayChecker>, // CHANGED: Accept global checker
|
||||
) -> RunningClientHandler {
|
||||
// CHANGED: Removed local creation of ReplayChecker.
|
||||
// It is now passed from main.rs to ensure global replay protection.
|
||||
|
||||
RunningClientHandler {
|
||||
stream,
|
||||
peer,
|
||||
config,
|
||||
stats,
|
||||
replay_checker,
|
||||
pool,
|
||||
upstream_manager,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a client connection
|
||||
pub async fn handle(&self, stream: TcpStream, peer: SocketAddr) {
|
||||
impl RunningClientHandler {
|
||||
/// Run the client handler
|
||||
pub async fn run(mut self) -> Result<()> {
|
||||
self.stats.increment_connects_all();
|
||||
|
||||
let peer = self.peer;
|
||||
debug!(peer = %peer, "New connection");
|
||||
|
||||
// Configure socket
|
||||
if let Err(e) = configure_client_socket(
|
||||
&stream,
|
||||
&self.stream,
|
||||
self.config.client_keepalive,
|
||||
self.config.client_ack_timeout,
|
||||
) {
|
||||
@@ -66,49 +81,56 @@ impl ClientHandler {
|
||||
// Perform handshake with timeout
|
||||
let handshake_timeout = Duration::from_secs(self.config.client_handshake_timeout);
|
||||
|
||||
// Clone stats for error handling block
|
||||
let stats = self.stats.clone();
|
||||
|
||||
let result = timeout(
|
||||
handshake_timeout,
|
||||
self.do_handshake(stream, peer)
|
||||
self.do_handshake()
|
||||
).await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(())) => {
|
||||
debug!(peer = %peer, "Connection handled successfully");
|
||||
Ok(())
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
debug!(peer = %peer, error = %e, "Handshake failed");
|
||||
Err(e)
|
||||
}
|
||||
Err(_) => {
|
||||
self.stats.increment_handshake_timeouts();
|
||||
stats.increment_handshake_timeouts();
|
||||
debug!(peer = %peer, "Handshake timeout");
|
||||
Err(ProxyError::TgHandshakeTimeout)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform handshake and relay
|
||||
async fn do_handshake(&self, mut stream: TcpStream, peer: SocketAddr) -> Result<()> {
|
||||
async fn do_handshake(mut self) -> Result<()> {
|
||||
// Read first bytes to determine handshake type
|
||||
let mut first_bytes = [0u8; 5];
|
||||
stream.read_exact(&mut first_bytes).await?;
|
||||
self.stream.read_exact(&mut first_bytes).await?;
|
||||
|
||||
let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
|
||||
let peer = self.peer;
|
||||
|
||||
debug!(peer = %peer, is_tls = is_tls, first_bytes = %hex::encode(&first_bytes), "Handshake type detected");
|
||||
|
||||
if is_tls {
|
||||
self.handle_tls_client(stream, peer, first_bytes).await
|
||||
self.handle_tls_client(first_bytes).await
|
||||
} else {
|
||||
self.handle_direct_client(stream, peer, first_bytes).await
|
||||
self.handle_direct_client(first_bytes).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle TLS-wrapped client
|
||||
async fn handle_tls_client(
|
||||
&self,
|
||||
mut stream: TcpStream,
|
||||
peer: SocketAddr,
|
||||
mut self,
|
||||
first_bytes: [u8; 5],
|
||||
) -> Result<()> {
|
||||
let peer = self.peer;
|
||||
|
||||
// Read TLS handshake length
|
||||
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
|
||||
|
||||
@@ -117,17 +139,22 @@ impl ClientHandler {
|
||||
if tls_len < 512 {
|
||||
debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short");
|
||||
self.stats.increment_connects_bad();
|
||||
handle_bad_client(stream, &first_bytes, &self.config).await;
|
||||
handle_bad_client(self.stream, &first_bytes, &self.config).await;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Read full TLS handshake
|
||||
let mut handshake = vec![0u8; 5 + tls_len];
|
||||
handshake[..5].copy_from_slice(&first_bytes);
|
||||
stream.read_exact(&mut handshake[5..]).await?;
|
||||
self.stream.read_exact(&mut handshake[5..]).await?;
|
||||
|
||||
// Extract fields before consuming self.stream
|
||||
let config = self.config.clone();
|
||||
let replay_checker = self.replay_checker.clone();
|
||||
let stats = self.stats.clone();
|
||||
|
||||
// Split stream for reading/writing
|
||||
let (read_half, write_half) = stream.into_split();
|
||||
let (read_half, write_half) = self.stream.into_split();
|
||||
|
||||
// Handle TLS handshake
|
||||
let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake(
|
||||
@@ -135,12 +162,12 @@ impl ClientHandler {
|
||||
read_half,
|
||||
write_half,
|
||||
peer,
|
||||
&self.config,
|
||||
&self.replay_checker,
|
||||
&config,
|
||||
&replay_checker,
|
||||
).await {
|
||||
HandshakeResult::Success(result) => result,
|
||||
HandshakeResult::BadClient => {
|
||||
self.stats.increment_connects_bad();
|
||||
stats.increment_connects_bad();
|
||||
return Ok(());
|
||||
}
|
||||
HandshakeResult::Error(e) => return Err(e),
|
||||
@@ -158,44 +185,62 @@ impl ClientHandler {
|
||||
tls_reader,
|
||||
tls_writer,
|
||||
peer,
|
||||
&self.config,
|
||||
&self.replay_checker,
|
||||
&config,
|
||||
&replay_checker,
|
||||
true,
|
||||
).await {
|
||||
HandshakeResult::Success(result) => result,
|
||||
HandshakeResult::BadClient => {
|
||||
self.stats.increment_connects_bad();
|
||||
stats.increment_connects_bad();
|
||||
return Ok(());
|
||||
}
|
||||
HandshakeResult::Error(e) => return Err(e),
|
||||
};
|
||||
|
||||
// Handle authenticated client
|
||||
self.handle_authenticated_inner(crypto_reader, crypto_writer, success).await
|
||||
// We can't use self.handle_authenticated_inner because self is partially moved
|
||||
// So we call it as an associated function or method on a new struct,
|
||||
// or just inline the logic / use a static method.
|
||||
// Since handle_authenticated_inner needs self.upstream_manager and self.stats,
|
||||
// we should pass them explicitly.
|
||||
|
||||
Self::handle_authenticated_static(
|
||||
crypto_reader,
|
||||
crypto_writer,
|
||||
success,
|
||||
self.upstream_manager,
|
||||
self.stats,
|
||||
self.config
|
||||
).await
|
||||
}
|
||||
|
||||
/// Handle direct (non-TLS) client
|
||||
async fn handle_direct_client(
|
||||
&self,
|
||||
mut stream: TcpStream,
|
||||
peer: SocketAddr,
|
||||
mut self,
|
||||
first_bytes: [u8; 5],
|
||||
) -> Result<()> {
|
||||
let peer = self.peer;
|
||||
|
||||
// Check if non-TLS modes are enabled
|
||||
if !self.config.modes.classic && !self.config.modes.secure {
|
||||
debug!(peer = %peer, "Non-TLS modes disabled");
|
||||
self.stats.increment_connects_bad();
|
||||
handle_bad_client(stream, &first_bytes, &self.config).await;
|
||||
handle_bad_client(self.stream, &first_bytes, &self.config).await;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Read rest of handshake
|
||||
let mut handshake = [0u8; HANDSHAKE_LEN];
|
||||
handshake[..5].copy_from_slice(&first_bytes);
|
||||
stream.read_exact(&mut handshake[5..]).await?;
|
||||
self.stream.read_exact(&mut handshake[5..]).await?;
|
||||
|
||||
// Extract fields
|
||||
let config = self.config.clone();
|
||||
let replay_checker = self.replay_checker.clone();
|
||||
let stats = self.stats.clone();
|
||||
|
||||
// Split stream
|
||||
let (read_half, write_half) = stream.into_split();
|
||||
let (read_half, write_half) = self.stream.into_split();
|
||||
|
||||
// Handle MTProto handshake
|
||||
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
|
||||
@@ -203,27 +248,36 @@ impl ClientHandler {
|
||||
read_half,
|
||||
write_half,
|
||||
peer,
|
||||
&self.config,
|
||||
&self.replay_checker,
|
||||
&config,
|
||||
&replay_checker,
|
||||
false,
|
||||
).await {
|
||||
HandshakeResult::Success(result) => result,
|
||||
HandshakeResult::BadClient => {
|
||||
self.stats.increment_connects_bad();
|
||||
stats.increment_connects_bad();
|
||||
return Ok(());
|
||||
}
|
||||
HandshakeResult::Error(e) => return Err(e),
|
||||
};
|
||||
|
||||
self.handle_authenticated_inner(crypto_reader, crypto_writer, success).await
|
||||
Self::handle_authenticated_static(
|
||||
crypto_reader,
|
||||
crypto_writer,
|
||||
success,
|
||||
self.upstream_manager,
|
||||
self.stats,
|
||||
self.config
|
||||
).await
|
||||
}
|
||||
|
||||
/// Handle authenticated client - connect to Telegram and relay
|
||||
async fn handle_authenticated_inner<R, W>(
|
||||
&self,
|
||||
/// Static version of handle_authenticated_inner to avoid ownership issues
|
||||
async fn handle_authenticated_static<R, W>(
|
||||
client_reader: CryptoReader<R>,
|
||||
client_writer: CryptoWriter<W>,
|
||||
success: HandshakeSuccess,
|
||||
upstream_manager: Arc<UpstreamManager>,
|
||||
stats: Arc<Stats>,
|
||||
config: Arc<ProxyConfig>,
|
||||
) -> Result<()>
|
||||
where
|
||||
R: AsyncRead + Unpin + Send + 'static,
|
||||
@@ -232,13 +286,13 @@ impl ClientHandler {
|
||||
let user = &success.user;
|
||||
|
||||
// Check user limits
|
||||
if let Err(e) = self.check_user_limits(user) {
|
||||
if let Err(e) = Self::check_user_limits_static(user, &config, &stats) {
|
||||
warn!(user = %user, error = %e, "User limit exceeded");
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
// Get datacenter address
|
||||
let dc_addr = self.get_dc_addr(success.dc_idx)?;
|
||||
let dc_addr = Self::get_dc_addr_static(success.dc_idx, &config)?;
|
||||
|
||||
info!(
|
||||
user = %user,
|
||||
@@ -246,39 +300,40 @@ impl ClientHandler {
|
||||
dc = success.dc_idx,
|
||||
dc_addr = %dc_addr,
|
||||
proto = ?success.proto_tag,
|
||||
fast_mode = self.config.fast_mode,
|
||||
fast_mode = config.fast_mode,
|
||||
"Connecting to Telegram"
|
||||
);
|
||||
|
||||
// Connect to Telegram
|
||||
let tg_stream = self.pool.get(dc_addr).await?;
|
||||
// Connect to Telegram via UpstreamManager
|
||||
let tg_stream = upstream_manager.connect(dc_addr).await?;
|
||||
|
||||
debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected to Telegram, performing handshake");
|
||||
|
||||
// Perform Telegram handshake and get crypto streams
|
||||
let (tg_reader, tg_writer) = self.do_tg_handshake(
|
||||
let (tg_reader, tg_writer) = Self::do_tg_handshake_static(
|
||||
tg_stream,
|
||||
&success,
|
||||
&config,
|
||||
).await?;
|
||||
|
||||
debug!(peer = %success.peer, "Telegram handshake complete, starting relay");
|
||||
|
||||
// Update stats
|
||||
self.stats.increment_user_connects(user);
|
||||
self.stats.increment_user_curr_connects(user);
|
||||
stats.increment_user_connects(user);
|
||||
stats.increment_user_curr_connects(user);
|
||||
|
||||
// Relay traffic - передаём Arc::clone(&self.stats)
|
||||
// Relay traffic
|
||||
let relay_result = relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
tg_reader,
|
||||
tg_writer,
|
||||
user,
|
||||
Arc::clone(&self.stats),
|
||||
Arc::clone(&stats),
|
||||
).await;
|
||||
|
||||
// Update stats
|
||||
self.stats.decrement_user_curr_connects(user);
|
||||
stats.decrement_user_curr_connects(user);
|
||||
|
||||
match &relay_result {
|
||||
Ok(()) => debug!(user = %user, peer = %success.peer, "Relay completed normally"),
|
||||
@@ -288,26 +343,26 @@ impl ClientHandler {
|
||||
relay_result
|
||||
}
|
||||
|
||||
/// Check user limits (expiration, connection count, data quota)
|
||||
fn check_user_limits(&self, user: &str) -> Result<()> {
|
||||
/// Check user limits (static version)
|
||||
fn check_user_limits_static(user: &str, config: &ProxyConfig, stats: &Stats) -> Result<()> {
|
||||
// Check expiration
|
||||
if let Some(expiration) = self.config.user_expirations.get(user) {
|
||||
if let Some(expiration) = config.user_expirations.get(user) {
|
||||
if chrono::Utc::now() > *expiration {
|
||||
return Err(ProxyError::UserExpired { user: user.to_string() });
|
||||
}
|
||||
}
|
||||
|
||||
// Check connection limit
|
||||
if let Some(limit) = self.config.user_max_tcp_conns.get(user) {
|
||||
let current = self.stats.get_user_curr_connects(user);
|
||||
if let Some(limit) = config.user_max_tcp_conns.get(user) {
|
||||
let current = stats.get_user_curr_connects(user);
|
||||
if current >= *limit as u64 {
|
||||
return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string() });
|
||||
}
|
||||
}
|
||||
|
||||
// Check data quota
|
||||
if let Some(quota) = self.config.user_data_quota.get(user) {
|
||||
let used = self.stats.get_user_total_octets(user);
|
||||
if let Some(quota) = config.user_data_quota.get(user) {
|
||||
let used = stats.get_user_total_octets(user);
|
||||
if used >= *quota {
|
||||
return Err(ProxyError::DataQuotaExceeded { user: user.to_string() });
|
||||
}
|
||||
@@ -316,11 +371,11 @@ impl ClientHandler {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get datacenter address by index
|
||||
fn get_dc_addr(&self, dc_idx: i16) -> Result<SocketAddr> {
|
||||
/// Get datacenter address by index (static version)
|
||||
fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
|
||||
let idx = (dc_idx.abs() - 1) as usize;
|
||||
|
||||
let datacenters = if self.config.prefer_ipv6 {
|
||||
let datacenters = if config.prefer_ipv6 {
|
||||
&*TG_DATACENTERS_V6
|
||||
} else {
|
||||
&*TG_DATACENTERS_V4
|
||||
@@ -333,19 +388,18 @@ impl ClientHandler {
|
||||
))
|
||||
}
|
||||
|
||||
/// Perform handshake with Telegram server
|
||||
/// Returns crypto reader and writer for TG connection
|
||||
async fn do_tg_handshake(
|
||||
&self,
|
||||
/// Perform handshake with Telegram server (static version)
|
||||
async fn do_tg_handshake_static(
|
||||
mut stream: TcpStream,
|
||||
success: &HandshakeSuccess,
|
||||
config: &ProxyConfig,
|
||||
) -> Result<(CryptoReader<tokio::net::tcp::OwnedReadHalf>, CryptoWriter<tokio::net::tcp::OwnedWriteHalf>)> {
|
||||
// Generate nonce with keys for TG
|
||||
let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce(
|
||||
success.proto_tag,
|
||||
&success.dec_key, // Client's dec key
|
||||
success.dec_iv,
|
||||
self.config.fast_mode,
|
||||
config.fast_mode,
|
||||
);
|
||||
|
||||
// Encrypt nonce
|
||||
|
||||
@@ -13,7 +13,7 @@ const MASK_BUFFER_SIZE: usize = 8192;
|
||||
|
||||
/// Handle a bad client by forwarding to mask host
|
||||
pub async fn handle_bad_client(
|
||||
mut client: TcpStream,
|
||||
client: TcpStream,
|
||||
initial_data: &[u8],
|
||||
config: &ProxyConfig,
|
||||
) {
|
||||
|
||||
@@ -1,13 +1,21 @@
|
||||
//! Bidirectional Relay
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
|
||||
use tracing::{debug, trace, warn};
|
||||
use tokio::time::Instant;
|
||||
use tracing::{debug, trace, warn, info};
|
||||
use crate::error::Result;
|
||||
use crate::stats::Stats;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
const BUFFER_SIZE: usize = 65536;
|
||||
// CHANGED: Reduced from 128KB to 16KB to match TLS record size and prevent bufferbloat.
|
||||
// This is critical for iOS clients to maintain proper TCP flow control during uploads.
|
||||
const BUFFER_SIZE: usize = 16384;
|
||||
|
||||
// Activity timeout for iOS compatibility (30 minutes)
|
||||
// iOS does not support TCP_USER_TIMEOUT, so we implement application-level timeout
|
||||
const ACTIVITY_TIMEOUT_SECS: u64 = 1800;
|
||||
|
||||
/// Relay data bidirectionally between client and server
|
||||
pub async fn relay_bidirectional<CR, CW, SR, SW>(
|
||||
@@ -36,15 +44,40 @@ where
|
||||
let c2s_bytes_clone = Arc::clone(&c2s_bytes);
|
||||
let s2c_bytes_clone = Arc::clone(&s2c_bytes);
|
||||
|
||||
// Client -> Server task
|
||||
// Activity timeout for iOS compatibility
|
||||
let activity_timeout = Duration::from_secs(ACTIVITY_TIMEOUT_SECS);
|
||||
|
||||
// Client -> Server task with activity timeout
|
||||
let c2s = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; BUFFER_SIZE];
|
||||
let mut total_bytes = 0u64;
|
||||
let mut msg_count = 0u64;
|
||||
let mut last_activity = Instant::now();
|
||||
let mut last_log = Instant::now();
|
||||
|
||||
loop {
|
||||
match client_reader.read(&mut buf).await {
|
||||
Ok(0) => {
|
||||
// Read with timeout to prevent infinite hang on iOS
|
||||
let read_result = tokio::time::timeout(
|
||||
activity_timeout,
|
||||
client_reader.read(&mut buf)
|
||||
).await;
|
||||
|
||||
match read_result {
|
||||
// Timeout - no activity for too long
|
||||
Err(_) => {
|
||||
warn!(
|
||||
user = %user_c2s,
|
||||
total_bytes = total_bytes,
|
||||
msgs = msg_count,
|
||||
idle_secs = last_activity.elapsed().as_secs(),
|
||||
"Activity timeout (C->S) - no data received"
|
||||
);
|
||||
let _ = server_writer.shutdown().await;
|
||||
break;
|
||||
}
|
||||
|
||||
// Read successful
|
||||
Ok(Ok(0)) => {
|
||||
debug!(
|
||||
user = %user_c2s,
|
||||
total_bytes = total_bytes,
|
||||
@@ -54,9 +87,11 @@ where
|
||||
let _ = server_writer.shutdown().await;
|
||||
break;
|
||||
}
|
||||
Ok(n) => {
|
||||
|
||||
Ok(Ok(n)) => {
|
||||
total_bytes += n as u64;
|
||||
msg_count += 1;
|
||||
last_activity = Instant::now();
|
||||
c2s_bytes_clone.store(total_bytes, Ordering::Relaxed);
|
||||
|
||||
stats_c2s.add_user_octets_from(&user_c2s, n as u64);
|
||||
@@ -70,6 +105,19 @@ where
|
||||
"C->S data"
|
||||
);
|
||||
|
||||
// Log activity every 10 seconds for large transfers
|
||||
if last_log.elapsed() > Duration::from_secs(10) {
|
||||
let rate = total_bytes as f64 / last_log.elapsed().as_secs_f64();
|
||||
info!(
|
||||
user = %user_c2s,
|
||||
total_bytes = total_bytes,
|
||||
msgs = msg_count,
|
||||
rate_kbps = (rate / 1024.0) as u64,
|
||||
"C->S transfer in progress"
|
||||
);
|
||||
last_log = Instant::now();
|
||||
}
|
||||
|
||||
if let Err(e) = server_writer.write_all(&buf[..n]).await {
|
||||
debug!(user = %user_c2s, error = %e, "Failed to write to server");
|
||||
break;
|
||||
@@ -79,7 +127,8 @@ where
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
|
||||
Ok(Err(e)) => {
|
||||
debug!(user = %user_c2s, error = %e, total_bytes = total_bytes, "Client read error");
|
||||
break;
|
||||
}
|
||||
@@ -87,15 +136,37 @@ where
|
||||
}
|
||||
});
|
||||
|
||||
// Server -> Client task
|
||||
// Server -> Client task with activity timeout
|
||||
let s2c = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; BUFFER_SIZE];
|
||||
let mut total_bytes = 0u64;
|
||||
let mut msg_count = 0u64;
|
||||
let mut last_activity = Instant::now();
|
||||
let mut last_log = Instant::now();
|
||||
|
||||
loop {
|
||||
match server_reader.read(&mut buf).await {
|
||||
Ok(0) => {
|
||||
// Read with timeout to prevent infinite hang on iOS
|
||||
let read_result = tokio::time::timeout(
|
||||
activity_timeout,
|
||||
server_reader.read(&mut buf)
|
||||
).await;
|
||||
|
||||
match read_result {
|
||||
// Timeout - no activity for too long
|
||||
Err(_) => {
|
||||
warn!(
|
||||
user = %user_s2c,
|
||||
total_bytes = total_bytes,
|
||||
msgs = msg_count,
|
||||
idle_secs = last_activity.elapsed().as_secs(),
|
||||
"Activity timeout (S->C) - no data received"
|
||||
);
|
||||
let _ = client_writer.shutdown().await;
|
||||
break;
|
||||
}
|
||||
|
||||
// Read successful
|
||||
Ok(Ok(0)) => {
|
||||
debug!(
|
||||
user = %user_s2c,
|
||||
total_bytes = total_bytes,
|
||||
@@ -105,9 +176,11 @@ where
|
||||
let _ = client_writer.shutdown().await;
|
||||
break;
|
||||
}
|
||||
Ok(n) => {
|
||||
|
||||
Ok(Ok(n)) => {
|
||||
total_bytes += n as u64;
|
||||
msg_count += 1;
|
||||
last_activity = Instant::now();
|
||||
s2c_bytes_clone.store(total_bytes, Ordering::Relaxed);
|
||||
|
||||
stats_s2c.add_user_octets_to(&user_s2c, n as u64);
|
||||
@@ -121,6 +194,19 @@ where
|
||||
"S->C data"
|
||||
);
|
||||
|
||||
// Log activity every 10 seconds for large transfers
|
||||
if last_log.elapsed() > Duration::from_secs(10) {
|
||||
let rate = total_bytes as f64 / last_log.elapsed().as_secs_f64();
|
||||
info!(
|
||||
user = %user_s2c,
|
||||
total_bytes = total_bytes,
|
||||
msgs = msg_count,
|
||||
rate_kbps = (rate / 1024.0) as u64,
|
||||
"S->C transfer in progress"
|
||||
);
|
||||
last_log = Instant::now();
|
||||
}
|
||||
|
||||
if let Err(e) = client_writer.write_all(&buf[..n]).await {
|
||||
debug!(user = %user_s2c, error = %e, "Failed to write to client");
|
||||
break;
|
||||
@@ -130,7 +216,8 @@ where
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
|
||||
Ok(Err(e)) => {
|
||||
debug!(user = %user_s2c, error = %e, total_bytes = total_bytes, "Server read error");
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -11,8 +11,9 @@ use std::sync::Arc;
|
||||
|
||||
// ============= Configuration =============
|
||||
|
||||
/// Default buffer size (64KB - good for MTProto)
|
||||
pub const DEFAULT_BUFFER_SIZE: usize = 64 * 1024;
|
||||
/// Default buffer size
|
||||
/// CHANGED: Reduced from 64KB to 16KB to match TLS record size and prevent bufferbloat.
|
||||
pub const DEFAULT_BUFFER_SIZE: usize = 16 * 1024;
|
||||
|
||||
/// Default maximum number of pooled buffers
|
||||
pub const DEFAULT_MAX_BUFFERS: usize = 1024;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -149,9 +149,9 @@ pub trait FrameCodec: Send + Sync {
|
||||
/// Create a frame codec for the given protocol tag
|
||||
pub fn create_codec(proto_tag: ProtoTag) -> Box<dyn FrameCodec> {
|
||||
match proto_tag {
|
||||
ProtoTag::Abridged => Box::new(super::frame_codec::AbridgedCodec::new()),
|
||||
ProtoTag::Intermediate => Box::new(super::frame_codec::IntermediateCodec::new()),
|
||||
ProtoTag::Secure => Box::new(super::frame_codec::SecureCodec::new()),
|
||||
ProtoTag::Abridged => Box::new(crate::stream::frame_codec::AbridgedCodec::new()),
|
||||
ProtoTag::Intermediate => Box::new(crate::stream::frame_codec::IntermediateCodec::new()),
|
||||
ProtoTag::Secure => Box::new(crate::stream::frame_codec::SecureCodec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,17 +1,36 @@
|
||||
//! Fake TLS 1.3 stream wrappers
|
||||
//!
|
||||
//! This module provides stateful async stream wrappers that handle
|
||||
//! TLS record framing with proper partial read/write handling.
|
||||
//! This module provides stateful async stream wrappers that handle TLS record
|
||||
//! framing with proper partial read/write handling.
|
||||
//!
|
||||
//! These are "fake" TLS streams - they wrap data in valid TLS 1.3
|
||||
//! Application Data records but don't perform actual TLS encryption.
|
||||
//! The actual encryption is handled by the crypto layer underneath.
|
||||
//! These are "fake" TLS streams:
|
||||
//! - We wrap raw bytes into syntactically valid TLS 1.3 records (Application Data).
|
||||
//! - We DO NOT perform real TLS handshake/encryption.
|
||||
//! - Real crypto for MTProto is handled by the crypto layer underneath.
|
||||
//!
|
||||
//! Why do we need this?
|
||||
//! Telegram MTProto proxy "FakeTLS" mode uses a TLS-looking outer layer for
|
||||
//! domain fronting / traffic camouflage. iOS Telegram clients are known to
|
||||
//! produce slightly different TLS record sizing patterns than Android/Desktop,
|
||||
//! including records that exceed 16384 payload bytes by a small overhead.
|
||||
//!
|
||||
//! Key design principles:
|
||||
//! - Explicit state machines for all async operations
|
||||
//! - Never lose data on partial reads
|
||||
//! - Atomic TLS record formation for writes
|
||||
//! - Proper handling of all TLS record types
|
||||
//!
|
||||
//! Important nuance (Telegram FakeTLS):
|
||||
//! - The TLS spec limits "plaintext fragments" to 2^14 (16384) bytes.
|
||||
//! - However, the on-the-wire record length can exceed 16384 because TLS 1.3
|
||||
//! uses AEAD and can include tag/overhead/padding.
|
||||
//! - Telegram FakeTLS clients (notably iOS) may send Application Data records
|
||||
//! with length up to 16384 + 24 bytes. We accept that as MAX_TLS_CHUNK_SIZE.
|
||||
//!
|
||||
//! If you reject those (e.g. validate length <= 16384), you will see errors like:
|
||||
//! "TLS record too large: 16408 bytes"
|
||||
//! and uploads from iOS will break (media/file sending), while small traffic
|
||||
//! may still work.
|
||||
|
||||
use bytes::{Bytes, BytesMut, BufMut};
|
||||
use std::io::{self, Error, ErrorKind, Result};
|
||||
@@ -20,25 +39,29 @@ use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt, ReadBuf};
|
||||
|
||||
use crate::protocol::constants::{
|
||||
TLS_VERSION, TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER,
|
||||
TLS_RECORD_HANDSHAKE, TLS_RECORD_ALERT, MAX_TLS_RECORD_SIZE,
|
||||
TLS_VERSION,
|
||||
TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER,
|
||||
TLS_RECORD_HANDSHAKE, TLS_RECORD_ALERT,
|
||||
MAX_TLS_CHUNK_SIZE,
|
||||
};
|
||||
use super::state::{StreamState, HeaderBuffer, YieldBuffer, WriteBuffer};
|
||||
|
||||
// ============= Constants =============
|
||||
|
||||
/// TLS record header size
|
||||
/// TLS record header size (type + version + length)
|
||||
const TLS_HEADER_SIZE: usize = 5;
|
||||
|
||||
/// Maximum TLS record payload size (16KB as per TLS spec)
|
||||
/// Maximum TLS fragment size per spec (plaintext fragment).
|
||||
/// We use this for *outgoing* chunking, because we build plain ApplicationData records.
|
||||
const MAX_TLS_PAYLOAD: usize = 16384;
|
||||
|
||||
/// Maximum pending write buffer
|
||||
/// Maximum pending write buffer for one record remainder.
|
||||
/// Note: we never queue unlimited amount of data here; state holds at most one record.
|
||||
const MAX_PENDING_WRITE: usize = 64 * 1024;
|
||||
|
||||
// ============= TLS Record Types =============
|
||||
|
||||
/// Parsed TLS record header
|
||||
/// Parsed TLS record header (5 bytes)
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct TlsRecordHeader {
|
||||
/// Record type (0x17 = Application Data, 0x14 = Change Cipher, etc.)
|
||||
@@ -50,22 +73,27 @@ struct TlsRecordHeader {
|
||||
}
|
||||
|
||||
impl TlsRecordHeader {
|
||||
/// Parse header from 5 bytes
|
||||
/// Parse header from exactly 5 bytes.
|
||||
///
|
||||
/// This currently never returns None, but is kept as Option to allow future
|
||||
/// stricter parsing rules without changing callers.
|
||||
fn parse(header: &[u8; 5]) -> Option<Self> {
|
||||
let record_type = header[0];
|
||||
let version = [header[1], header[2]];
|
||||
let length = u16::from_be_bytes([header[3], header[4]]);
|
||||
|
||||
Some(Self {
|
||||
record_type,
|
||||
version,
|
||||
length,
|
||||
})
|
||||
Some(Self { record_type, version, length })
|
||||
}
|
||||
|
||||
/// Validate the header
|
||||
/// Validate the header.
|
||||
///
|
||||
/// Nuances:
|
||||
/// - We accept TLS 1.0 header version for ClientHello-like records (0x03 0x01),
|
||||
/// and TLS 1.2/1.3 style version bytes for the rest (we use TLS_VERSION = 0x03 0x03).
|
||||
/// - For Application Data, Telegram FakeTLS may send payload length up to
|
||||
/// MAX_TLS_CHUNK_SIZE (16384 + 24).
|
||||
/// - For other record types we keep stricter bounds to avoid memory abuse.
|
||||
fn validate(&self) -> Result<()> {
|
||||
// Check version (accept TLS 1.0 for ClientHello, TLS 1.2/1.3 for others)
|
||||
// Version: accept TLS 1.0 header (ClientHello quirk) and TLS_VERSION (0x0303).
|
||||
if self.version != [0x03, 0x01] && self.version != TLS_VERSION {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
@@ -73,27 +101,36 @@ impl TlsRecordHeader {
|
||||
));
|
||||
}
|
||||
|
||||
// Check length
|
||||
if self.length as usize > MAX_TLS_RECORD_SIZE {
|
||||
let len = self.length as usize;
|
||||
|
||||
// Length checks depend on record type.
|
||||
// Telegram FakeTLS: ApplicationData length may be 16384 + 24.
|
||||
match self.record_type {
|
||||
TLS_RECORD_APPLICATION => {
|
||||
if len > MAX_TLS_CHUNK_SIZE {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!("TLS record too large: {} bytes", self.length),
|
||||
format!("TLS record too large: {} bytes (max {})", len, MAX_TLS_CHUNK_SIZE),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// ChangeCipherSpec/Alert/Handshake should never be that large for our usage
|
||||
// (post-handshake we don't expect Handshake at all).
|
||||
// Keep strict to reduce attack surface.
|
||||
_ => {
|
||||
if len > MAX_TLS_PAYLOAD {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!("TLS control record too large: {} bytes (max {})", len, MAX_TLS_PAYLOAD),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if this is an application data record
|
||||
fn is_application_data(&self) -> bool {
|
||||
self.record_type == TLS_RECORD_APPLICATION
|
||||
}
|
||||
|
||||
/// Check if this is a change cipher spec record (should be skipped)
|
||||
fn is_change_cipher_spec(&self) -> bool {
|
||||
self.record_type == TLS_RECORD_CHANGE_CIPHER
|
||||
}
|
||||
|
||||
/// Build header bytes
|
||||
fn to_bytes(&self) -> [u8; 5] {
|
||||
[
|
||||
@@ -120,25 +157,20 @@ enum TlsReaderState {
|
||||
header: HeaderBuffer<TLS_HEADER_SIZE>,
|
||||
},
|
||||
|
||||
/// Reading the TLS record body
|
||||
/// Reading the TLS record body (payload)
|
||||
ReadingBody {
|
||||
/// Parsed record type
|
||||
record_type: u8,
|
||||
/// Total body length
|
||||
length: usize,
|
||||
/// Buffer for body data
|
||||
buffer: BytesMut,
|
||||
},
|
||||
|
||||
/// Have decrypted data ready to yield to caller
|
||||
/// Have buffered data ready to yield to caller
|
||||
Yielding {
|
||||
/// Buffer containing data to yield
|
||||
buffer: YieldBuffer,
|
||||
},
|
||||
|
||||
/// Stream encountered an error and cannot be used
|
||||
Poisoned {
|
||||
/// The error that caused poisoning
|
||||
error: Option<io::Error>,
|
||||
},
|
||||
}
|
||||
@@ -165,12 +197,13 @@ impl StreamState for TlsReaderState {
|
||||
|
||||
// ============= FakeTlsReader =============
|
||||
|
||||
/// Reader that unwraps TLS 1.3 records with proper state machine
|
||||
/// Reader that unwraps TLS records (FakeTLS).
|
||||
///
|
||||
/// This reader handles partial reads correctly by maintaining internal state
|
||||
/// and never losing any data that has been read from upstream.
|
||||
/// This wrapper is responsible ONLY for TLS record framing and skipping
|
||||
/// non-data records (like CCS). It does not decrypt TLS: payload bytes are passed
|
||||
/// as-is to upper layers (crypto stream).
|
||||
///
|
||||
/// # State Machine
|
||||
/// State machine overview:
|
||||
///
|
||||
/// ┌──────────┐ ┌───────────────┐
|
||||
/// │ Idle │ -----------------> │ ReadingHeader │
|
||||
@@ -178,103 +211,69 @@ impl StreamState for TlsReaderState {
|
||||
/// ▲ │
|
||||
/// │ header complete
|
||||
/// │ │
|
||||
/// │ │
|
||||
/// │ ▼
|
||||
/// │ ┌───────────────┐
|
||||
/// │ skip record │ ReadingBody │
|
||||
/// │ <-------- (CCS) -------- │ │
|
||||
/// │ └───────┬───────┘
|
||||
/// │ │
|
||||
/// │ body complete
|
||||
/// │ drained │
|
||||
/// │ <-----------------┐ │
|
||||
/// │ │ ┌───────────────┐
|
||||
/// │ └----- │ Yielding │
|
||||
/// │ ▼
|
||||
/// │ ┌───────────────┐
|
||||
/// │ │ Yielding │
|
||||
/// │ └───────────────┘
|
||||
/// │
|
||||
/// │ errors /w any state
|
||||
/// │
|
||||
/// │ errors / w any state
|
||||
/// ▼
|
||||
/// ┌───────────────────────────────────────────────┐
|
||||
/// │ Poisoned │
|
||||
/// └───────────────────────────────────────────────┘
|
||||
///
|
||||
/// NOTE: We must correctly handle partial reads from upstream:
|
||||
/// - do not assume header arrives in one poll
|
||||
/// - do not assume body arrives in one poll
|
||||
/// - never lose already-read bytes
|
||||
pub struct FakeTlsReader<R> {
|
||||
/// Upstream reader
|
||||
upstream: R,
|
||||
/// Current state
|
||||
state: TlsReaderState,
|
||||
}
|
||||
|
||||
impl<R> FakeTlsReader<R> {
|
||||
/// Create new fake TLS reader
|
||||
pub fn new(upstream: R) -> Self {
|
||||
Self {
|
||||
upstream,
|
||||
state: TlsReaderState::Idle,
|
||||
}
|
||||
Self { upstream, state: TlsReaderState::Idle }
|
||||
}
|
||||
|
||||
/// Get reference to upstream
|
||||
pub fn get_ref(&self) -> &R {
|
||||
&self.upstream
|
||||
}
|
||||
pub fn get_ref(&self) -> &R { &self.upstream }
|
||||
pub fn get_mut(&mut self) -> &mut R { &mut self.upstream }
|
||||
pub fn into_inner(self) -> R { self.upstream }
|
||||
|
||||
/// Get mutable reference to upstream
|
||||
pub fn get_mut(&mut self) -> &mut R {
|
||||
&mut self.upstream
|
||||
}
|
||||
pub fn is_poisoned(&self) -> bool { self.state.is_poisoned() }
|
||||
pub fn state_name(&self) -> &'static str { self.state.state_name() }
|
||||
|
||||
/// Consume and return upstream
|
||||
pub fn into_inner(self) -> R {
|
||||
self.upstream
|
||||
}
|
||||
|
||||
/// Check if stream is in poisoned state
|
||||
pub fn is_poisoned(&self) -> bool {
|
||||
self.state.is_poisoned()
|
||||
}
|
||||
|
||||
/// Get current state name (for debugging)
|
||||
pub fn state_name(&self) -> &'static str {
|
||||
self.state.state_name()
|
||||
}
|
||||
|
||||
/// Transition to poisoned state
|
||||
fn poison(&mut self, error: io::Error) {
|
||||
self.state = TlsReaderState::Poisoned { error: Some(error) };
|
||||
}
|
||||
|
||||
/// Take error from poisoned state
|
||||
fn take_poison_error(&mut self) -> io::Error {
|
||||
match &mut self.state {
|
||||
TlsReaderState::Poisoned { error } => {
|
||||
error.take().unwrap_or_else(|| {
|
||||
TlsReaderState::Poisoned { error } => error.take().unwrap_or_else(|| {
|
||||
io::Error::new(ErrorKind::Other, "stream previously poisoned")
|
||||
})
|
||||
}
|
||||
}),
|
||||
_ => io::Error::new(ErrorKind::Other, "stream not poisoned"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of polling for header completion
|
||||
enum HeaderPollResult {
|
||||
/// Need more data
|
||||
Pending,
|
||||
/// EOF at record boundary (clean close)
|
||||
Eof,
|
||||
/// Header complete, parsed successfully
|
||||
Complete(TlsRecordHeader),
|
||||
/// Error occurred
|
||||
Error(io::Error),
|
||||
}
|
||||
|
||||
/// Result of polling for body completion
|
||||
enum BodyPollResult {
|
||||
/// Need more data
|
||||
Pending,
|
||||
/// Body complete
|
||||
Complete(Bytes),
|
||||
/// Error occurred
|
||||
Error(io::Error),
|
||||
}
|
||||
|
||||
@@ -291,7 +290,7 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
|
||||
let state = std::mem::replace(&mut this.state, TlsReaderState::Idle);
|
||||
|
||||
match state {
|
||||
// Poisoned state - return error
|
||||
// Poisoned state: always return the stored error
|
||||
TlsReaderState::Poisoned { error } => {
|
||||
this.state = TlsReaderState::Poisoned { error: None };
|
||||
let err = error.unwrap_or_else(|| {
|
||||
@@ -300,20 +299,18 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
|
||||
return Poll::Ready(Err(err));
|
||||
}
|
||||
|
||||
// Have buffered data to yield
|
||||
// Yield buffered plaintext to caller
|
||||
TlsReaderState::Yielding { mut buffer } => {
|
||||
if buf.remaining() == 0 {
|
||||
this.state = TlsReaderState::Yielding { buffer };
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
// Copy as much as possible to output
|
||||
let to_copy = buffer.remaining().min(buf.remaining());
|
||||
let dst = buf.initialize_unfilled_to(to_copy);
|
||||
let copied = buffer.copy_to(dst);
|
||||
buf.advance(copied);
|
||||
|
||||
// If buffer is drained, transition to Idle
|
||||
if buffer.is_empty() {
|
||||
this.state = TlsReaderState::Idle;
|
||||
} else {
|
||||
@@ -323,23 +320,21 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
// Ready to read a new TLS record
|
||||
// Start reading new record
|
||||
TlsReaderState::Idle => {
|
||||
if buf.remaining() == 0 {
|
||||
this.state = TlsReaderState::Idle;
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
// Start reading header
|
||||
this.state = TlsReaderState::ReadingHeader {
|
||||
header: HeaderBuffer::new(),
|
||||
};
|
||||
// Continue to ReadingHeader
|
||||
// loop continues and will handle ReadingHeader
|
||||
}
|
||||
|
||||
// Reading TLS record header
|
||||
// Read TLS header (5 bytes)
|
||||
TlsReaderState::ReadingHeader { mut header } => {
|
||||
// Poll to fill header
|
||||
let result = poll_read_header(&mut this.upstream, cx, &mut header);
|
||||
|
||||
match result {
|
||||
@@ -348,6 +343,7 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
|
||||
return Poll::Pending;
|
||||
}
|
||||
HeaderPollResult::Eof => {
|
||||
// Clean EOF at record boundary
|
||||
this.state = TlsReaderState::Idle;
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
@@ -356,15 +352,12 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
|
||||
return Poll::Ready(Err(e));
|
||||
}
|
||||
HeaderPollResult::Complete(parsed) => {
|
||||
// Validate header
|
||||
if let Err(e) = parsed.validate() {
|
||||
this.poison(Error::new(e.kind(), e.to_string()));
|
||||
return Poll::Ready(Err(e));
|
||||
}
|
||||
|
||||
let length = parsed.length as usize;
|
||||
|
||||
// Transition to reading body
|
||||
this.state = TlsReaderState::ReadingBody {
|
||||
record_type: parsed.record_type,
|
||||
length,
|
||||
@@ -374,7 +367,7 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
|
||||
}
|
||||
}
|
||||
|
||||
// Reading TLS record body
|
||||
// Read TLS payload
|
||||
TlsReaderState::ReadingBody { record_type, length, mut buffer } => {
|
||||
let result = poll_read_body(&mut this.upstream, cx, &mut buffer, length);
|
||||
|
||||
@@ -388,15 +381,15 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
|
||||
return Poll::Ready(Err(e));
|
||||
}
|
||||
BodyPollResult::Complete(data) => {
|
||||
// Handle different record types
|
||||
match record_type {
|
||||
TLS_RECORD_CHANGE_CIPHER => {
|
||||
// Skip Change Cipher Spec, read next record
|
||||
// CCS is expected in some clients, ignore it.
|
||||
this.state = TlsReaderState::Idle;
|
||||
continue;
|
||||
}
|
||||
|
||||
TLS_RECORD_APPLICATION => {
|
||||
// Application data - yield to caller
|
||||
// This is what we actually want.
|
||||
if data.is_empty() {
|
||||
this.state = TlsReaderState::Idle;
|
||||
continue;
|
||||
@@ -405,25 +398,26 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
|
||||
this.state = TlsReaderState::Yielding {
|
||||
buffer: YieldBuffer::new(data),
|
||||
};
|
||||
// Continue to yield
|
||||
// loop continues and will yield immediately
|
||||
}
|
||||
|
||||
TLS_RECORD_ALERT => {
|
||||
// TLS Alert - treat as EOF
|
||||
// Treat TLS alert as EOF-like termination.
|
||||
this.state = TlsReaderState::Idle;
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
TLS_RECORD_HANDSHAKE => {
|
||||
let err = Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
"unexpected TLS handshake record"
|
||||
);
|
||||
// After FakeTLS handshake is done, we do not expect any Handshake records.
|
||||
let err = Error::new(ErrorKind::InvalidData, "unexpected TLS handshake record");
|
||||
this.poison(Error::new(err.kind(), err.to_string()));
|
||||
return Poll::Ready(Err(err));
|
||||
}
|
||||
|
||||
_ => {
|
||||
let err = Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!("unknown TLS record type: 0x{:02x}", record_type)
|
||||
format!("unknown TLS record type: 0x{:02x}", record_type),
|
||||
);
|
||||
this.poison(Error::new(err.kind(), err.to_string()));
|
||||
return Poll::Ready(Err(err));
|
||||
@@ -459,8 +453,10 @@ fn poll_read_header<R: AsyncRead + Unpin>(
|
||||
} else {
|
||||
return HeaderPollResult::Error(Error::new(
|
||||
ErrorKind::UnexpectedEof,
|
||||
format!("unexpected EOF in TLS header (got {} of 5 bytes)",
|
||||
header.as_slice().len())
|
||||
format!(
|
||||
"unexpected EOF in TLS header (got {} of 5 bytes)",
|
||||
header.as_slice().len()
|
||||
),
|
||||
));
|
||||
}
|
||||
}
|
||||
@@ -469,14 +465,10 @@ fn poll_read_header<R: AsyncRead + Unpin>(
|
||||
}
|
||||
}
|
||||
|
||||
// Parse header
|
||||
let header_bytes = *header.as_array();
|
||||
match TlsRecordHeader::parse(&header_bytes) {
|
||||
Some(h) => HeaderPollResult::Complete(h),
|
||||
None => HeaderPollResult::Error(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
"failed to parse TLS header"
|
||||
)),
|
||||
None => HeaderPollResult::Error(Error::new(ErrorKind::InvalidData, "failed to parse TLS header")),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -487,10 +479,12 @@ fn poll_read_body<R: AsyncRead + Unpin>(
|
||||
buffer: &mut BytesMut,
|
||||
target_len: usize,
|
||||
) -> BodyPollResult {
|
||||
// NOTE: This implementation uses a temporary Vec to avoid tricky borrow/lifetime
|
||||
// issues with BytesMut spare capacity and ReadBuf across polls.
|
||||
// It's safe and correct; optimization is possible if needed.
|
||||
while buffer.len() < target_len {
|
||||
let remaining = target_len - buffer.len();
|
||||
|
||||
// Read into a temporary buffer
|
||||
let mut temp = vec![0u8; remaining.min(8192)];
|
||||
let mut read_buf = ReadBuf::new(&mut temp);
|
||||
|
||||
@@ -502,8 +496,11 @@ fn poll_read_body<R: AsyncRead + Unpin>(
|
||||
if n == 0 {
|
||||
return BodyPollResult::Error(Error::new(
|
||||
ErrorKind::UnexpectedEof,
|
||||
format!("unexpected EOF in TLS body (got {} of {} bytes)",
|
||||
buffer.len(), target_len)
|
||||
format!(
|
||||
"unexpected EOF in TLS body (got {} of {} bytes)",
|
||||
buffer.len(),
|
||||
target_len
|
||||
),
|
||||
));
|
||||
}
|
||||
buffer.extend_from_slice(&temp[..n]);
|
||||
@@ -515,10 +512,9 @@ fn poll_read_body<R: AsyncRead + Unpin>(
|
||||
}
|
||||
|
||||
impl<R: AsyncRead + Unpin> FakeTlsReader<R> {
|
||||
/// Read exactly n bytes through TLS layer
|
||||
/// Read exactly n bytes through TLS layer.
|
||||
///
|
||||
/// This is a convenience method that accumulates data across
|
||||
/// multiple TLS records until exactly n bytes are available.
|
||||
/// This accumulates data across multiple TLS ApplicationData records.
|
||||
pub async fn read_exact(&mut self, n: usize) -> Result<Bytes> {
|
||||
if self.is_poisoned() {
|
||||
return Err(self.take_poison_error());
|
||||
@@ -533,7 +529,7 @@ impl<R: AsyncRead + Unpin> FakeTlsReader<R> {
|
||||
if read == 0 {
|
||||
return Err(Error::new(
|
||||
ErrorKind::UnexpectedEof,
|
||||
format!("expected {} bytes, got {}", n, result.len())
|
||||
format!("expected {} bytes, got {}", n, result.len()),
|
||||
));
|
||||
}
|
||||
|
||||
@@ -546,23 +542,19 @@ impl<R: AsyncRead + Unpin> FakeTlsReader<R> {
|
||||
|
||||
// ============= FakeTlsWriter State =============
|
||||
|
||||
/// State machine states for FakeTlsWriter
|
||||
#[derive(Debug)]
|
||||
enum TlsWriterState {
|
||||
/// Ready to accept new data
|
||||
Idle,
|
||||
|
||||
/// Writing a complete TLS record
|
||||
/// Writing a complete TLS record (header + body), possibly partially
|
||||
WritingRecord {
|
||||
/// Complete record (header + body) to write
|
||||
record: WriteBuffer,
|
||||
/// Original payload size (for return value calculation)
|
||||
payload_size: usize,
|
||||
},
|
||||
|
||||
/// Stream encountered an error and cannot be used
|
||||
Poisoned {
|
||||
/// The error that caused poisoning
|
||||
error: Option<io::Error>,
|
||||
},
|
||||
}
|
||||
@@ -587,94 +579,46 @@ impl StreamState for TlsWriterState {
|
||||
|
||||
// ============= FakeTlsWriter =============
|
||||
|
||||
/// Writer that wraps data in TLS 1.3 records with proper state machine
|
||||
/// Writer that wraps bytes into TLS 1.3 Application Data records.
|
||||
///
|
||||
/// This writer handles partial writes correctly by:
|
||||
/// - Building complete TLS records before writing
|
||||
/// - Maintaining internal state for partial record writes
|
||||
/// - Never splitting a record mid-write to upstream
|
||||
///
|
||||
/// # State Machine
|
||||
///
|
||||
/// ┌──────────┐ write ┌─────────────────┐
|
||||
/// │ Idle │ -------------> │ WritingRecord │
|
||||
/// │ │ <------------- │ │
|
||||
/// └──────────┘ complete └─────────────────┘
|
||||
/// │ │
|
||||
/// │ < errors > │
|
||||
/// │ │
|
||||
/// ┌─────────────────────────────────────────────┐
|
||||
/// │ Poisoned │
|
||||
/// └─────────────────────────────────────────────┘
|
||||
///
|
||||
/// # Record Formation
|
||||
///
|
||||
/// Data is chunked into records of at most MAX_TLS_PAYLOAD bytes.
|
||||
/// Each record has a 5-byte header prepended.
|
||||
/// We chunk outgoing data into records of <= 16384 payload bytes (MAX_TLS_PAYLOAD).
|
||||
/// We do not try to mimic AEAD overhead on the wire; Telegram clients accept it.
|
||||
/// If you want to be more camouflage-accurate later, you could add optional padding
|
||||
/// to produce records sized closer to MAX_TLS_CHUNK_SIZE.
|
||||
pub struct FakeTlsWriter<W> {
|
||||
/// Upstream writer
|
||||
upstream: W,
|
||||
/// Current state
|
||||
state: TlsWriterState,
|
||||
}
|
||||
|
||||
impl<W> FakeTlsWriter<W> {
|
||||
/// Create new fake TLS writer
|
||||
pub fn new(upstream: W) -> Self {
|
||||
Self {
|
||||
upstream,
|
||||
state: TlsWriterState::Idle,
|
||||
}
|
||||
Self { upstream, state: TlsWriterState::Idle }
|
||||
}
|
||||
|
||||
/// Get reference to upstream
|
||||
pub fn get_ref(&self) -> &W {
|
||||
&self.upstream
|
||||
}
|
||||
pub fn get_ref(&self) -> &W { &self.upstream }
|
||||
pub fn get_mut(&mut self) -> &mut W { &mut self.upstream }
|
||||
pub fn into_inner(self) -> W { self.upstream }
|
||||
|
||||
/// Get mutable reference to upstream
|
||||
pub fn get_mut(&mut self) -> &mut W {
|
||||
&mut self.upstream
|
||||
}
|
||||
pub fn is_poisoned(&self) -> bool { self.state.is_poisoned() }
|
||||
pub fn state_name(&self) -> &'static str { self.state.state_name() }
|
||||
|
||||
/// Consume and return upstream
|
||||
pub fn into_inner(self) -> W {
|
||||
self.upstream
|
||||
}
|
||||
|
||||
/// Check if stream is in poisoned state
|
||||
pub fn is_poisoned(&self) -> bool {
|
||||
self.state.is_poisoned()
|
||||
}
|
||||
|
||||
/// Get current state name (for debugging)
|
||||
pub fn state_name(&self) -> &'static str {
|
||||
self.state.state_name()
|
||||
}
|
||||
|
||||
/// Check if there's a pending record to write
|
||||
pub fn has_pending(&self) -> bool {
|
||||
matches!(&self.state, TlsWriterState::WritingRecord { record, .. } if !record.is_empty())
|
||||
}
|
||||
|
||||
/// Transition to poisoned state
|
||||
fn poison(&mut self, error: io::Error) {
|
||||
self.state = TlsWriterState::Poisoned { error: Some(error) };
|
||||
}
|
||||
|
||||
/// Take error from poisoned state
|
||||
fn take_poison_error(&mut self) -> io::Error {
|
||||
match &mut self.state {
|
||||
TlsWriterState::Poisoned { error } => {
|
||||
error.take().unwrap_or_else(|| {
|
||||
TlsWriterState::Poisoned { error } => error.take().unwrap_or_else(|| {
|
||||
io::Error::new(ErrorKind::Other, "stream previously poisoned")
|
||||
})
|
||||
}
|
||||
}),
|
||||
_ => io::Error::new(ErrorKind::Other, "stream not poisoned"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a TLS Application Data record
|
||||
fn build_record(data: &[u8]) -> BytesMut {
|
||||
let header = TlsRecordHeader {
|
||||
record_type: TLS_RECORD_APPLICATION,
|
||||
@@ -689,18 +633,13 @@ impl<W> FakeTlsWriter<W> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of flushing pending record
|
||||
enum FlushResult {
|
||||
/// All data flushed, returns payload size
|
||||
Complete(usize),
|
||||
/// Need to wait for upstream
|
||||
Pending,
|
||||
/// Error occurred
|
||||
Error(io::Error),
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> FakeTlsWriter<W> {
|
||||
/// Try to flush pending record to upstream (standalone logic)
|
||||
fn poll_flush_record_inner(
|
||||
upstream: &mut W,
|
||||
cx: &mut Context<'_>,
|
||||
@@ -710,19 +649,14 @@ impl<W: AsyncWrite + Unpin> FakeTlsWriter<W> {
|
||||
let data = record.pending();
|
||||
match Pin::new(&mut *upstream).poll_write(cx, data) {
|
||||
Poll::Pending => return FlushResult::Pending,
|
||||
|
||||
Poll::Ready(Err(e)) => return FlushResult::Error(e),
|
||||
|
||||
Poll::Ready(Ok(0)) => {
|
||||
return FlushResult::Error(Error::new(
|
||||
ErrorKind::WriteZero,
|
||||
"upstream returned 0 bytes written"
|
||||
"upstream returned 0 bytes written",
|
||||
));
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(n)) => {
|
||||
record.advance(n);
|
||||
}
|
||||
Poll::Ready(Ok(n)) => record.advance(n),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -738,7 +672,7 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
||||
) -> Poll<Result<usize>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
// Take ownership of state
|
||||
// Take ownership of state to avoid borrow conflicts.
|
||||
let state = std::mem::replace(&mut this.state, TlsWriterState::Idle);
|
||||
|
||||
match state {
|
||||
@@ -751,7 +685,7 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
||||
}
|
||||
|
||||
TlsWriterState::WritingRecord { mut record, payload_size } => {
|
||||
// Continue flushing existing record
|
||||
// Finish writing previous record before accepting new bytes.
|
||||
match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) {
|
||||
FlushResult::Pending => {
|
||||
this.state = TlsWriterState::WritingRecord { record, payload_size };
|
||||
@@ -763,7 +697,7 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
||||
}
|
||||
FlushResult::Complete(_) => {
|
||||
this.state = TlsWriterState::Idle;
|
||||
// Fall through to handle new write
|
||||
// continue to accept new buf below
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -782,19 +716,18 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
||||
let chunk_size = buf.len().min(MAX_TLS_PAYLOAD);
|
||||
let chunk = &buf[..chunk_size];
|
||||
|
||||
// Build the complete record
|
||||
// Build the complete record (header + payload)
|
||||
let record_data = Self::build_record(chunk);
|
||||
|
||||
// Try to write directly first
|
||||
match Pin::new(&mut this.upstream).poll_write(cx, &record_data) {
|
||||
Poll::Ready(Ok(n)) if n == record_data.len() => {
|
||||
// Complete record written
|
||||
Poll::Ready(Ok(chunk_size))
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(n)) => {
|
||||
// Partial write - buffer the rest
|
||||
// Partial write of the record: store remainder.
|
||||
let mut write_buffer = WriteBuffer::with_max_size(MAX_PENDING_WRITE);
|
||||
// record_data length is <= 16389, fits MAX_PENDING_WRITE
|
||||
let _ = write_buffer.extend(&record_data[n..]);
|
||||
|
||||
this.state = TlsWriterState::WritingRecord {
|
||||
@@ -802,7 +735,7 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
||||
payload_size: chunk_size,
|
||||
};
|
||||
|
||||
// We've accepted chunk_size bytes from caller
|
||||
// We have accepted chunk_size bytes from caller.
|
||||
Poll::Ready(Ok(chunk_size))
|
||||
}
|
||||
|
||||
@@ -812,7 +745,7 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
||||
}
|
||||
|
||||
Poll::Pending => {
|
||||
// Buffer the entire record
|
||||
// Buffer entire record and report success for this chunk.
|
||||
let mut write_buffer = WriteBuffer::with_max_size(MAX_PENDING_WRITE);
|
||||
let _ = write_buffer.extend(&record_data);
|
||||
|
||||
@@ -821,10 +754,9 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
||||
payload_size: chunk_size,
|
||||
};
|
||||
|
||||
// Wake to try again
|
||||
// Wake to retry flushing soon.
|
||||
cx.waker().wake_by_ref();
|
||||
|
||||
// We've accepted chunk_size bytes from caller
|
||||
Poll::Ready(Ok(chunk_size))
|
||||
}
|
||||
}
|
||||
@@ -833,7 +765,6 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
// Take ownership of state
|
||||
let state = std::mem::replace(&mut this.state, TlsWriterState::Idle);
|
||||
|
||||
match state {
|
||||
@@ -866,48 +797,33 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
||||
}
|
||||
}
|
||||
|
||||
// Flush upstream
|
||||
Pin::new(&mut this.upstream).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
// Take ownership of state
|
||||
let state = std::mem::replace(&mut this.state, TlsWriterState::Idle);
|
||||
|
||||
match state {
|
||||
TlsWriterState::WritingRecord { mut record, payload_size } => {
|
||||
// Try to flush pending (best effort)
|
||||
match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) {
|
||||
FlushResult::Pending => {
|
||||
// Can't complete flush, continue with shutdown anyway
|
||||
TlsWriterState::WritingRecord { mut record, payload_size: _ } => {
|
||||
// Best-effort flush (do not block shutdown forever).
|
||||
let _ = Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record);
|
||||
this.state = TlsWriterState::Idle;
|
||||
}
|
||||
FlushResult::Error(_) => {
|
||||
// Ignore errors during shutdown
|
||||
this.state = TlsWriterState::Idle;
|
||||
}
|
||||
FlushResult::Complete(_) => {
|
||||
this.state = TlsWriterState::Idle;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
this.state = TlsWriterState::Idle;
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown upstream
|
||||
Pin::new(&mut this.upstream).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> FakeTlsWriter<W> {
|
||||
/// Write all data wrapped in TLS records (async method)
|
||||
/// Write all data wrapped in TLS records.
|
||||
///
|
||||
/// This convenience method handles chunking large data into
|
||||
/// multiple TLS records automatically.
|
||||
/// Convenience method that chunks into <= 16384 records.
|
||||
pub async fn write_all_tls(&mut self, data: &[u8]) -> Result<()> {
|
||||
let mut written = 0;
|
||||
while written < data.len() {
|
||||
|
||||
@@ -3,7 +3,11 @@
|
||||
pub mod pool;
|
||||
pub mod proxy_protocol;
|
||||
pub mod socket;
|
||||
pub mod socks;
|
||||
pub mod upstream;
|
||||
|
||||
pub use pool::ConnectionPool;
|
||||
pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol};
|
||||
pub use socket::*;
|
||||
pub use socks::*;
|
||||
pub use upstream::UpstreamManager;
|
||||
@@ -1,7 +1,7 @@
|
||||
//! TCP Socket Configuration
|
||||
|
||||
use std::io::Result;
|
||||
use std::net::SocketAddr;
|
||||
use std::net::{SocketAddr, IpAddr};
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpStream;
|
||||
use socket2::{Socket, TcpKeepalive, Domain, Type, Protocol};
|
||||
@@ -30,20 +30,13 @@ pub fn configure_tcp_socket(
|
||||
socket.set_tcp_keepalive(&keepalive)?;
|
||||
}
|
||||
|
||||
// Set buffer sizes
|
||||
set_buffer_sizes(&socket, 65536, 65536)?;
|
||||
// CHANGED: Removed manual buffer size setting (was 256KB).
|
||||
// Allowing the OS kernel to handle TCP window scaling (Autotuning) is critical
|
||||
// for mobile clients to avoid bufferbloat and stalled connections during uploads.
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set socket buffer sizes
|
||||
fn set_buffer_sizes(socket: &socket2::SockRef, recv: usize, send: usize) -> Result<()> {
|
||||
// These may fail on some systems, so we ignore errors
|
||||
let _ = socket.set_recv_buffer_size(recv);
|
||||
let _ = socket.set_send_buffer_size(send);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Configure socket for accepting client connections
|
||||
pub fn configure_client_socket(
|
||||
stream: &TcpStream,
|
||||
@@ -65,6 +58,8 @@ pub fn configure_client_socket(
|
||||
socket.set_tcp_keepalive(&keepalive)?;
|
||||
|
||||
// Set TCP user timeout (Linux only)
|
||||
// NOTE: iOS does not support TCP_USER_TIMEOUT - application-level timeout
|
||||
// is implemented in relay_bidirectional instead
|
||||
#[cfg(target_os = "linux")]
|
||||
{
|
||||
use std::os::unix::io::AsRawFd;
|
||||
@@ -93,6 +88,11 @@ pub fn set_linger_zero(stream: &TcpStream) -> Result<()> {
|
||||
|
||||
/// Create a new TCP socket for outgoing connections
|
||||
pub fn create_outgoing_socket(addr: SocketAddr) -> Result<Socket> {
|
||||
create_outgoing_socket_bound(addr, None)
|
||||
}
|
||||
|
||||
/// Create a new TCP socket for outgoing connections, optionally bound to a specific interface
|
||||
pub fn create_outgoing_socket_bound(addr: SocketAddr, bind_addr: Option<IpAddr>) -> Result<Socket> {
|
||||
let domain = if addr.is_ipv4() {
|
||||
Domain::IPV4
|
||||
} else {
|
||||
@@ -107,9 +107,16 @@ pub fn create_outgoing_socket(addr: SocketAddr) -> Result<Socket> {
|
||||
// Disable Nagle
|
||||
socket.set_nodelay(true)?;
|
||||
|
||||
if let Some(bind_ip) = bind_addr {
|
||||
let bind_sock_addr = SocketAddr::new(bind_ip, 0);
|
||||
socket.bind(&bind_sock_addr.into())?;
|
||||
debug!("Bound outgoing socket to {}", bind_ip);
|
||||
}
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
|
||||
/// Get local address of a socket
|
||||
pub fn get_local_addr(stream: &TcpStream) -> Option<SocketAddr> {
|
||||
stream.local_addr().ok()
|
||||
|
||||
145
src/transport/socks.rs
Normal file
145
src/transport/socks.rs
Normal file
@@ -0,0 +1,145 @@
|
||||
//! SOCKS4/5 Client Implementation
|
||||
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
use crate::error::{ProxyError, Result};
|
||||
|
||||
pub async fn connect_socks4(
|
||||
stream: &mut TcpStream,
|
||||
target: SocketAddr,
|
||||
user_id: Option<&str>,
|
||||
) -> Result<()> {
|
||||
let ip = match target.ip() {
|
||||
IpAddr::V4(ip) => ip,
|
||||
IpAddr::V6(_) => return Err(ProxyError::Proxy("SOCKS4 does not support IPv6".to_string())),
|
||||
};
|
||||
|
||||
let port = target.port();
|
||||
let user = user_id.unwrap_or("").as_bytes();
|
||||
|
||||
// VN (4) | CD (1) | DSTPORT (2) | DSTIP (4) | USERID (variable) | NULL (1)
|
||||
let mut buf = Vec::with_capacity(9 + user.len());
|
||||
buf.push(4); // VN
|
||||
buf.push(1); // CD (CONNECT)
|
||||
buf.extend_from_slice(&port.to_be_bytes());
|
||||
buf.extend_from_slice(&ip.octets());
|
||||
buf.extend_from_slice(user);
|
||||
buf.push(0); // NULL
|
||||
|
||||
stream.write_all(&buf).await.map_err(|e| ProxyError::Io(e))?;
|
||||
|
||||
// Response: VN (1) | CD (1) | DSTPORT (2) | DSTIP (4)
|
||||
let mut resp = [0u8; 8];
|
||||
stream.read_exact(&mut resp).await.map_err(|e| ProxyError::Io(e))?;
|
||||
|
||||
if resp[1] != 90 {
|
||||
return Err(ProxyError::Proxy(format!("SOCKS4 request rejected: code {}", resp[1])));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn connect_socks5(
|
||||
stream: &mut TcpStream,
|
||||
target: SocketAddr,
|
||||
username: Option<&str>,
|
||||
password: Option<&str>,
|
||||
) -> Result<()> {
|
||||
// 1. Auth negotiation
|
||||
// VER (1) | NMETHODS (1) | METHODS (variable)
|
||||
let mut methods = vec![0u8]; // No auth
|
||||
if username.is_some() {
|
||||
methods.push(2u8); // Username/Password
|
||||
}
|
||||
|
||||
let mut buf = vec![5u8, methods.len() as u8];
|
||||
buf.extend_from_slice(&methods);
|
||||
|
||||
stream.write_all(&buf).await.map_err(|e| ProxyError::Io(e))?;
|
||||
|
||||
let mut resp = [0u8; 2];
|
||||
stream.read_exact(&mut resp).await.map_err(|e| ProxyError::Io(e))?;
|
||||
|
||||
if resp[0] != 5 {
|
||||
return Err(ProxyError::Proxy("Invalid SOCKS5 version".to_string()));
|
||||
}
|
||||
|
||||
match resp[1] {
|
||||
0 => {}, // No auth
|
||||
2 => {
|
||||
// Username/Password auth
|
||||
if let (Some(u), Some(p)) = (username, password) {
|
||||
let u_bytes = u.as_bytes();
|
||||
let p_bytes = p.as_bytes();
|
||||
|
||||
let mut auth_buf = Vec::with_capacity(3 + u_bytes.len() + p_bytes.len());
|
||||
auth_buf.push(1); // VER
|
||||
auth_buf.push(u_bytes.len() as u8);
|
||||
auth_buf.extend_from_slice(u_bytes);
|
||||
auth_buf.push(p_bytes.len() as u8);
|
||||
auth_buf.extend_from_slice(p_bytes);
|
||||
|
||||
stream.write_all(&auth_buf).await.map_err(|e| ProxyError::Io(e))?;
|
||||
|
||||
let mut auth_resp = [0u8; 2];
|
||||
stream.read_exact(&mut auth_resp).await.map_err(|e| ProxyError::Io(e))?;
|
||||
|
||||
if auth_resp[1] != 0 {
|
||||
return Err(ProxyError::Proxy("SOCKS5 authentication failed".to_string()));
|
||||
}
|
||||
} else {
|
||||
return Err(ProxyError::Proxy("SOCKS5 server requires authentication".to_string()));
|
||||
}
|
||||
},
|
||||
_ => return Err(ProxyError::Proxy("Unsupported SOCKS5 auth method".to_string())),
|
||||
}
|
||||
|
||||
// 2. Connection request
|
||||
// VER (1) | CMD (1) | RSV (1) | ATYP (1) | DST.ADDR (variable) | DST.PORT (2)
|
||||
let mut req = vec![5u8, 1u8, 0u8]; // CONNECT
|
||||
|
||||
match target {
|
||||
SocketAddr::V4(v4) => {
|
||||
req.push(1u8); // IPv4
|
||||
req.extend_from_slice(&v4.ip().octets());
|
||||
},
|
||||
SocketAddr::V6(v6) => {
|
||||
req.push(4u8); // IPv6
|
||||
req.extend_from_slice(&v6.ip().octets());
|
||||
},
|
||||
}
|
||||
|
||||
req.extend_from_slice(&target.port().to_be_bytes());
|
||||
|
||||
stream.write_all(&req).await.map_err(|e| ProxyError::Io(e))?;
|
||||
|
||||
// Response
|
||||
let mut head = [0u8; 4];
|
||||
stream.read_exact(&mut head).await.map_err(|e| ProxyError::Io(e))?;
|
||||
|
||||
if head[1] != 0 {
|
||||
return Err(ProxyError::Proxy(format!("SOCKS5 request failed: code {}", head[1])));
|
||||
}
|
||||
|
||||
// Skip address part of response
|
||||
match head[3] {
|
||||
1 => { // IPv4
|
||||
let mut addr = [0u8; 4 + 2];
|
||||
stream.read_exact(&mut addr).await.map_err(|e| ProxyError::Io(e))?;
|
||||
},
|
||||
3 => { // Domain
|
||||
let mut len = [0u8; 1];
|
||||
stream.read_exact(&mut len).await.map_err(|e| ProxyError::Io(e))?;
|
||||
let mut addr = vec![0u8; len[0] as usize + 2];
|
||||
stream.read_exact(&mut addr).await.map_err(|e| ProxyError::Io(e))?;
|
||||
},
|
||||
4 => { // IPv6
|
||||
let mut addr = [0u8; 16 + 2];
|
||||
stream.read_exact(&mut addr).await.map_err(|e| ProxyError::Io(e))?;
|
||||
},
|
||||
_ => return Err(ProxyError::Proxy("Invalid address type in SOCKS5 response".to_string())),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
259
src/transport/upstream.rs
Normal file
259
src/transport/upstream.rs
Normal file
@@ -0,0 +1,259 @@
|
||||
//! Upstream Management
|
||||
|
||||
use std::net::{SocketAddr, IpAddr};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::RwLock;
|
||||
use rand::Rng;
|
||||
use tracing::{debug, warn, error, info};
|
||||
|
||||
use crate::config::{UpstreamConfig, UpstreamType};
|
||||
use crate::error::{Result, ProxyError};
|
||||
use crate::transport::socket::create_outgoing_socket_bound;
|
||||
use crate::transport::socks::{connect_socks4, connect_socks5};
|
||||
|
||||
#[derive(Debug)]
|
||||
struct UpstreamState {
|
||||
config: UpstreamConfig,
|
||||
healthy: bool,
|
||||
fails: u32,
|
||||
last_check: std::time::Instant,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct UpstreamManager {
|
||||
upstreams: Arc<RwLock<Vec<UpstreamState>>>,
|
||||
}
|
||||
|
||||
impl UpstreamManager {
|
||||
pub fn new(configs: Vec<UpstreamConfig>) -> Self {
|
||||
let states = configs.into_iter()
|
||||
.filter(|c| c.enabled)
|
||||
.map(|c| UpstreamState {
|
||||
config: c,
|
||||
healthy: true, // Optimistic start
|
||||
fails: 0,
|
||||
last_check: std::time::Instant::now(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
upstreams: Arc::new(RwLock::new(states)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Select an upstream using Weighted Round Robin (simplified)
|
||||
async fn select_upstream(&self) -> Option<usize> {
|
||||
let upstreams = self.upstreams.read().await;
|
||||
if upstreams.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let healthy_indices: Vec<usize> = upstreams.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, u)| u.healthy)
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
|
||||
if healthy_indices.is_empty() {
|
||||
// If all unhealthy, try any random one
|
||||
return Some(rand::thread_rng().gen_range(0..upstreams.len()));
|
||||
}
|
||||
|
||||
// Weighted selection
|
||||
let total_weight: u32 = healthy_indices.iter()
|
||||
.map(|&i| upstreams[i].config.weight as u32)
|
||||
.sum();
|
||||
|
||||
if total_weight == 0 {
|
||||
return Some(healthy_indices[rand::thread_rng().gen_range(0..healthy_indices.len())]);
|
||||
}
|
||||
|
||||
let mut choice = rand::thread_rng().gen_range(0..total_weight);
|
||||
|
||||
for &idx in &healthy_indices {
|
||||
let weight = upstreams[idx].config.weight as u32;
|
||||
if choice < weight {
|
||||
return Some(idx);
|
||||
}
|
||||
choice -= weight;
|
||||
}
|
||||
|
||||
Some(healthy_indices[0])
|
||||
}
|
||||
|
||||
pub async fn connect(&self, target: SocketAddr) -> Result<TcpStream> {
|
||||
let idx = self.select_upstream().await
|
||||
.ok_or_else(|| ProxyError::Config("No upstreams available".to_string()))?;
|
||||
|
||||
let upstream = {
|
||||
let guard = self.upstreams.read().await;
|
||||
guard[idx].config.clone()
|
||||
};
|
||||
|
||||
match self.connect_via_upstream(&upstream, target).await {
|
||||
Ok(stream) => {
|
||||
// Mark success
|
||||
let mut guard = self.upstreams.write().await;
|
||||
if let Some(u) = guard.get_mut(idx) {
|
||||
if !u.healthy {
|
||||
debug!("Upstream recovered: {:?}", u.config);
|
||||
}
|
||||
u.healthy = true;
|
||||
u.fails = 0;
|
||||
}
|
||||
Ok(stream)
|
||||
},
|
||||
Err(e) => {
|
||||
// Mark failure
|
||||
let mut guard = self.upstreams.write().await;
|
||||
if let Some(u) = guard.get_mut(idx) {
|
||||
u.fails += 1;
|
||||
warn!("Failed to connect via upstream {:?}: {}. Fails: {}", u.config, e, u.fails);
|
||||
if u.fails > 3 {
|
||||
u.healthy = false;
|
||||
warn!("Upstream disabled due to failures: {:?}", u.config);
|
||||
}
|
||||
}
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect_via_upstream(&self, config: &UpstreamConfig, target: SocketAddr) -> Result<TcpStream> {
|
||||
match &config.upstream_type {
|
||||
UpstreamType::Direct { interface } => {
|
||||
let bind_ip = interface.as_ref()
|
||||
.and_then(|s| s.parse::<IpAddr>().ok());
|
||||
|
||||
let socket = create_outgoing_socket_bound(target, bind_ip)?;
|
||||
|
||||
// Non-blocking connect logic
|
||||
socket.set_nonblocking(true)?;
|
||||
match socket.connect(&target.into()) {
|
||||
Ok(()) => {},
|
||||
Err(err) if err.raw_os_error() == Some(115) || err.kind() == std::io::ErrorKind::WouldBlock => {},
|
||||
Err(err) => return Err(ProxyError::Io(err)),
|
||||
}
|
||||
|
||||
let std_stream: std::net::TcpStream = socket.into();
|
||||
let stream = TcpStream::from_std(std_stream)?;
|
||||
|
||||
// Wait for connection to complete
|
||||
stream.writable().await?;
|
||||
if let Some(e) = stream.take_error()? {
|
||||
return Err(ProxyError::Io(e));
|
||||
}
|
||||
|
||||
Ok(stream)
|
||||
},
|
||||
UpstreamType::Socks4 { address, interface, user_id } => {
|
||||
info!("Connecting to target {} via SOCKS4 proxy {}", target, address);
|
||||
|
||||
let proxy_addr: SocketAddr = address.parse()
|
||||
.map_err(|_| ProxyError::Config("Invalid SOCKS4 address".to_string()))?;
|
||||
|
||||
let bind_ip = interface.as_ref()
|
||||
.and_then(|s| s.parse::<IpAddr>().ok());
|
||||
|
||||
let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?;
|
||||
|
||||
// Non-blocking connect logic
|
||||
socket.set_nonblocking(true)?;
|
||||
match socket.connect(&proxy_addr.into()) {
|
||||
Ok(()) => {},
|
||||
Err(err) if err.raw_os_error() == Some(115) || err.kind() == std::io::ErrorKind::WouldBlock => {},
|
||||
Err(err) => return Err(ProxyError::Io(err)),
|
||||
}
|
||||
|
||||
let std_stream: std::net::TcpStream = socket.into();
|
||||
let mut stream = TcpStream::from_std(std_stream)?;
|
||||
|
||||
// Wait for connection to complete
|
||||
stream.writable().await?;
|
||||
if let Some(e) = stream.take_error()? {
|
||||
return Err(ProxyError::Io(e));
|
||||
}
|
||||
|
||||
connect_socks4(&mut stream, target, user_id.as_deref()).await?;
|
||||
Ok(stream)
|
||||
},
|
||||
UpstreamType::Socks5 { address, interface, username, password } => {
|
||||
info!("Connecting to target {} via SOCKS5 proxy {}", target, address);
|
||||
|
||||
let proxy_addr: SocketAddr = address.parse()
|
||||
.map_err(|_| ProxyError::Config("Invalid SOCKS5 address".to_string()))?;
|
||||
|
||||
let bind_ip = interface.as_ref()
|
||||
.and_then(|s| s.parse::<IpAddr>().ok());
|
||||
|
||||
let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?;
|
||||
|
||||
// Non-blocking connect logic
|
||||
socket.set_nonblocking(true)?;
|
||||
match socket.connect(&proxy_addr.into()) {
|
||||
Ok(()) => {},
|
||||
Err(err) if err.raw_os_error() == Some(115) || err.kind() == std::io::ErrorKind::WouldBlock => {},
|
||||
Err(err) => return Err(ProxyError::Io(err)),
|
||||
}
|
||||
|
||||
let std_stream: std::net::TcpStream = socket.into();
|
||||
let mut stream = TcpStream::from_std(std_stream)?;
|
||||
|
||||
// Wait for connection to complete
|
||||
stream.writable().await?;
|
||||
if let Some(e) = stream.take_error()? {
|
||||
return Err(ProxyError::Io(e));
|
||||
}
|
||||
|
||||
connect_socks5(&mut stream, target, username.as_deref(), password.as_deref()).await?;
|
||||
Ok(stream)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Background task to check health
|
||||
pub async fn run_health_checks(&self) {
|
||||
// Simple TCP connect check to a known stable DC (e.g. 149.154.167.50:443 - DC2)
|
||||
let check_target: SocketAddr = "149.154.167.50:443".parse().unwrap();
|
||||
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(60)).await;
|
||||
|
||||
let count = self.upstreams.read().await.len();
|
||||
for i in 0..count {
|
||||
let config = {
|
||||
let guard = self.upstreams.read().await;
|
||||
guard[i].config.clone()
|
||||
};
|
||||
|
||||
let result = tokio::time::timeout(
|
||||
Duration::from_secs(10),
|
||||
self.connect_via_upstream(&config, check_target)
|
||||
).await;
|
||||
|
||||
let mut guard = self.upstreams.write().await;
|
||||
let u = &mut guard[i];
|
||||
|
||||
match result {
|
||||
Ok(Ok(_stream)) => {
|
||||
if !u.healthy {
|
||||
debug!("Upstream recovered: {:?}", u.config);
|
||||
}
|
||||
u.healthy = true;
|
||||
u.fails = 0;
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
debug!("Health check failed for {:?}: {}", u.config, e);
|
||||
// Don't mark unhealthy immediately in background check
|
||||
}
|
||||
Err(_) => {
|
||||
debug!("Health check timeout for {:?}", u.config);
|
||||
}
|
||||
}
|
||||
u.last_check = std::time::Instant::now();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
//! IP Addr Detect
|
||||
|
||||
use std::net::IpAddr;
|
||||
use std::net::{IpAddr, SocketAddr, UdpSocket};
|
||||
use std::time::Duration;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
@@ -40,31 +40,77 @@ const IPV6_URLS: &[&str] = &[
|
||||
"http://api6.ipify.org/",
|
||||
];
|
||||
|
||||
/// Detect local IP address by connecting to a public DNS
|
||||
/// This does not actually send any packets
|
||||
fn get_local_ip(target: &str) -> Option<IpAddr> {
|
||||
let socket = UdpSocket::bind("0.0.0.0:0").ok()?;
|
||||
socket.connect(target).ok()?;
|
||||
socket.local_addr().ok().map(|addr| addr.ip())
|
||||
}
|
||||
|
||||
fn get_local_ipv6(target: &str) -> Option<IpAddr> {
|
||||
let socket = UdpSocket::bind("[::]:0").ok()?;
|
||||
socket.connect(target).ok()?;
|
||||
socket.local_addr().ok().map(|addr| addr.ip())
|
||||
}
|
||||
|
||||
/// Detect public IP addresses
|
||||
pub async fn detect_ip() -> IpInfo {
|
||||
let mut info = IpInfo::default();
|
||||
|
||||
// Detect IPv4
|
||||
// Try to get local interface IP first (default gateway interface)
|
||||
// We connect to Google DNS to find out which interface is used for routing
|
||||
if let Some(ip) = get_local_ip("8.8.8.8:80") {
|
||||
if ip.is_ipv4() && !ip.is_loopback() {
|
||||
info.ipv4 = Some(ip);
|
||||
debug!(ip = %ip, "Detected local IPv4 address via routing");
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ip) = get_local_ipv6("[2001:4860:4860::8888]:80") {
|
||||
if ip.is_ipv6() && !ip.is_loopback() {
|
||||
info.ipv6 = Some(ip);
|
||||
debug!(ip = %ip, "Detected local IPv6 address via routing");
|
||||
}
|
||||
}
|
||||
|
||||
// If local detection failed or returned private IP (and we want public),
|
||||
// or just as a fallback/verification, we might want to check external services.
|
||||
// However, the requirement is: "if IP for listening is not set... it should be IP from interface...
|
||||
// if impossible - request external resources".
|
||||
|
||||
// So if we found a local IP, we might be good. But often servers are behind NAT.
|
||||
// If the local IP is private, we probably want the public IP for the tg:// link.
|
||||
// Let's check if the detected IPs are private.
|
||||
|
||||
let need_external_v4 = info.ipv4.map_or(true, |ip| is_private_ip(ip));
|
||||
let need_external_v6 = info.ipv6.map_or(true, |ip| is_private_ip(ip));
|
||||
|
||||
if need_external_v4 {
|
||||
debug!("Local IPv4 is private or missing, checking external services...");
|
||||
for url in IPV4_URLS {
|
||||
if let Some(ip) = fetch_ip(url).await {
|
||||
if ip.is_ipv4() {
|
||||
info.ipv4 = Some(ip);
|
||||
debug!(ip = %ip, "Detected IPv4 address");
|
||||
debug!(ip = %ip, "Detected public IPv4 address");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Detect IPv6
|
||||
if need_external_v6 {
|
||||
debug!("Local IPv6 is private or missing, checking external services...");
|
||||
for url in IPV6_URLS {
|
||||
if let Some(ip) = fetch_ip(url).await {
|
||||
if ip.is_ipv6() {
|
||||
info.ipv6 = Some(ip);
|
||||
debug!(ip = %ip, "Detected IPv6 address");
|
||||
debug!(ip = %ip, "Detected public IPv6 address");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !info.has_any() {
|
||||
warn!("Failed to detect public IP address");
|
||||
@@ -73,6 +119,17 @@ pub async fn detect_ip() -> IpInfo {
|
||||
info
|
||||
}
|
||||
|
||||
fn is_private_ip(ip: IpAddr) -> bool {
|
||||
match ip {
|
||||
IpAddr::V4(ipv4) => {
|
||||
ipv4.is_private() || ipv4.is_loopback() || ipv4.is_link_local()
|
||||
}
|
||||
IpAddr::V6(ipv6) => {
|
||||
ipv6.is_loopback() || (ipv6.segments()[0] & 0xfe00) == 0xfc00 // Unique Local
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch IP from URL
|
||||
async fn fetch_ip(url: &str) -> Option<IpAddr> {
|
||||
let client = reqwest::Client::builder()
|
||||
|
||||
Reference in New Issue
Block a user