diff --git a/src/config/mod.rs b/src/config/mod.rs index 3bc983e..455d4da 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -18,6 +18,7 @@ pub struct ProxyModes { } fn default_true() -> bool { true } +fn default_weight() -> u16 { 1 } impl Default for ProxyModes { fn default() -> Self { @@ -25,6 +26,48 @@ impl Default for ProxyModes { } } +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum UpstreamType { + Direct { + #[serde(default)] + interface: Option, // Bind to specific IP/Interface + }, + Socks4 { + address: String, // IP:Port of SOCKS server + #[serde(default)] + interface: Option, // Bind to specific IP/Interface for connection to SOCKS + #[serde(default)] + user_id: Option, + }, + Socks5 { + address: String, + #[serde(default)] + interface: Option, + #[serde(default)] + username: Option, + #[serde(default)] + password: Option, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UpstreamConfig { + #[serde(flatten)] + pub upstream_type: UpstreamType, + #[serde(default = "default_weight")] + pub weight: u16, + #[serde(default = "default_true")] + pub enabled: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ListenerConfig { + pub ip: IpAddr, + #[serde(default)] + pub announce_ip: Option, // IP to show in tg:// links +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ProxyConfig { #[serde(default = "default_port")] @@ -104,6 +147,13 @@ pub struct ProxyConfig { #[serde(default = "default_fake_cert_len")] pub fake_cert_len: usize, + + // New fields + #[serde(default)] + pub upstreams: Vec, + + #[serde(default)] + pub listeners: Vec, } fn default_port() -> u16 { 443 } @@ -156,6 +206,8 @@ impl Default for ProxyConfig { metrics_port: None, metrics_whitelist: default_metrics_whitelist(), fake_cert_len: default_fake_cert_len(), + upstreams: Vec::new(), + listeners: Vec::new(), } } } @@ -187,6 +239,33 @@ impl ProxyConfig { use rand::Rng; config.fake_cert_len = rand::thread_rng().gen_range(1024..4096); + // Migration: Populate listeners if empty + if config.listeners.is_empty() { + if let Ok(ipv4) = config.listen_addr_ipv4.parse::() { + config.listeners.push(ListenerConfig { + ip: ipv4, + announce_ip: None, + }); + } + if let Some(ipv6_str) = &config.listen_addr_ipv6 { + if let Ok(ipv6) = ipv6_str.parse::() { + config.listeners.push(ListenerConfig { + ip: ipv6, + announce_ip: None, + }); + } + } + } + + // Migration: Populate upstreams if empty (Default Direct) + if config.upstreams.is_empty() { + config.upstreams.push(UpstreamConfig { + upstream_type: UpstreamType::Direct { interface: None }, + weight: 1, + enabled: true, + }); + } + Ok(config) } @@ -201,27 +280,4 @@ impl ProxyConfig { Ok(()) } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_default_config() { - let config = ProxyConfig::default(); - assert_eq!(config.port, 443); - assert!(config.modes.tls); - assert_eq!(config.client_keepalive, 600); - assert_eq!(config.client_ack_timeout, 300); - } - - #[test] - fn test_config_validate() { - let mut config = ProxyConfig::default(); - assert!(config.validate().is_ok()); - - config.users.clear(); - assert!(config.validate().is_err()); - } } \ No newline at end of file diff --git a/src/error.rs b/src/error.rs index bd49757..d20b8d8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -235,6 +235,9 @@ pub enum ProxyError { #[error("Invalid proxy protocol header")] InvalidProxyProtocol, + #[error("Proxy error: {0}")] + Proxy(String), + // ============= Config Errors ============= #[error("Config error: {0}")] diff --git a/src/main.rs b/src/main.rs index 29e4bc1..dfdb806 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,158 +1,166 @@ //! Telemt - MTProxy on Rust -use std::sync::Arc; use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; use tokio::net::TcpListener; use tokio::signal; -use tracing::{info, error, Level}; -use tracing_subscriber::{FmtSubscriber, EnvFilter}; +use tracing::{info, error, warn}; +use tracing_subscriber::{fmt, EnvFilter}; -mod error; +mod config; mod crypto; +mod error; mod protocol; +mod proxy; +mod stats; mod stream; mod transport; -mod proxy; -mod config; -mod stats; mod util; -use config::ProxyConfig; -use stats::{Stats, ReplayChecker}; -use transport::ConnectionPool; -use proxy::ClientHandler; +use crate::config::ProxyConfig; +use crate::proxy::ClientHandler; +use crate::stats::Stats; +use crate::transport::{create_listener, ListenOptions, UpstreamManager}; +use crate::util::ip::detect_ip; #[tokio::main] -async fn main() -> std::result::Result<(), Box> { - // Initialize logging with env filter - // Use RUST_LOG=debug or RUST_LOG=trace for more details - let filter = EnvFilter::try_from_default_env() - .unwrap_or_else(|_| EnvFilter::new("info")); - - let subscriber = FmtSubscriber::builder() - .with_env_filter(filter) - .with_target(true) - .with_thread_ids(false) - .with_file(false) - .with_line_number(false) - .finish(); - - tracing::subscriber::set_global_default(subscriber)?; - - // Load configuration - let config_path = std::env::args() - .nth(1) - .unwrap_or_else(|| "config.toml".to_string()); - - info!("Loading configuration from {}", config_path); - - let config = ProxyConfig::load(&config_path).unwrap_or_else(|e| { - error!("Failed to load config: {}", e); - info!("Using default configuration"); - ProxyConfig::default() - }); - - if let Err(e) = config.validate() { - error!("Invalid configuration: {}", e); - std::process::exit(1); - } - - let config = Arc::new(config); - - info!("Starting MTProto Proxy on port {}", config.port); - info!("Fast mode: {}", config.fast_mode); - info!("Modes: classic={}, secure={}, tls={}", - config.modes.classic, config.modes.secure, config.modes.tls); - - // Initialize components - let stats = Arc::new(Stats::new()); - let replay_checker = Arc::new(ReplayChecker::new(config.replay_check_len)); - let pool = Arc::new(ConnectionPool::new()); - - // Create handler - let handler = Arc::new(ClientHandler::new( - Arc::clone(&config), - Arc::clone(&stats), - Arc::clone(&replay_checker), - Arc::clone(&pool), - )); - - // Start listener - let addr: SocketAddr = format!("{}:{}", config.listen_addr_ipv4, config.port) - .parse()?; - - let listener = TcpListener::bind(addr).await?; - info!("Listening on {}", addr); - - // Print proxy links - print_proxy_links(&config); - - info!("Use RUST_LOG=debug or RUST_LOG=trace for more detailed logging"); - - // Main accept loop - let accept_loop = async { - loop { - match listener.accept().await { - Ok((stream, peer)) => { - let handler = Arc::clone(&handler); - tokio::spawn(async move { - handler.handle(stream, peer).await; - }); - } - Err(e) => { - error!("Accept error: {}", e); - } +async fn main() -> Result<(), Box> { + // Initialize logging + fmt() + .with_env_filter(EnvFilter::from_default_env().add_directive("info".parse().unwrap())) + .init(); + + // Load config + let config_path = std::env::args().nth(1).unwrap_or_else(|| "config.toml".to_string()); + let config = match ProxyConfig::load(&config_path) { + Ok(c) => c, + Err(e) => { + // If config doesn't exist, try to create default + if std::path::Path::new(&config_path).exists() { + error!("Failed to load config: {}", e); + std::process::exit(1); + } else { + let default = ProxyConfig::default(); + let toml = toml::to_string_pretty(&default).unwrap(); + std::fs::write(&config_path, toml).unwrap(); + info!("Created default config at {}", config_path); + default } } }; - // Graceful shutdown - tokio::select! { - _ = accept_loop => {} - _ = signal::ctrl_c() => { - info!("Shutting down..."); - } - } + config.validate()?; - // Cleanup - pool.close_all().await; + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); - info!("Goodbye!"); - Ok(()) -} + // Initialize Upstream Manager + let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone())); + + // Start Health Checks + let um_clone = upstream_manager.clone(); + tokio::spawn(async move { + um_clone.run_health_checks().await; + }); -fn print_proxy_links(config: &ProxyConfig) { - println!("\n=== Proxy Links ===\n"); + // Detect public IP if needed (once at startup) + let detected_ip = detect_ip().await; + + // Start Listeners + let mut listeners = Vec::new(); - for (user, secret) in &config.users { - if config.modes.tls { - let tls_secret = format!( - "ee{}{}", - secret, - hex::encode(config.tls_domain.as_bytes()) - ); - println!( - "{} (TLS): tg://proxy?server=IP&port={}&secret={}", - user, config.port, tls_secret - ); - } + for listener_conf in &config.listeners { + let addr = SocketAddr::new(listener_conf.ip, config.port); + let options = ListenOptions { + ipv6_only: listener_conf.ip.is_ipv6(), + ..Default::default() + }; - if config.modes.secure { - println!( - "{} (Secure): tg://proxy?server=IP&port={}&secret=dd{}", - user, config.port, secret - ); + match create_listener(addr, &options) { + Ok(socket) => { + let listener = TcpListener::from_std(socket.into())?; + info!("Listening on {}", addr); + + // Determine public IP for tg:// links + // 1. Use explicit announce_ip if set + // 2. If listening on 0.0.0.0 or ::, use detected public IP + // 3. Otherwise use the bind IP + let public_ip = if let Some(ip) = listener_conf.announce_ip { + ip + } else if listener_conf.ip.is_unspecified() { + // Try to use detected IP of the same family + if listener_conf.ip.is_ipv4() { + detected_ip.ipv4.unwrap_or(listener_conf.ip) + } else { + detected_ip.ipv6.unwrap_or(listener_conf.ip) + } + } else { + listener_conf.ip + }; + + for (user, secret) in &config.users { + info!("Link for {}: tg://proxy?server={}&port={}&secret={}", + user, public_ip, config.port, secret); + } + + listeners.push(listener); + }, + Err(e) => { + error!("Failed to bind to {}: {}", addr, e); + } } - - if config.modes.classic { - println!( - "{} (Classic): tg://proxy?server=IP&port={}&secret={}", - user, config.port, secret - ); - } - - println!(); } - println!("===================\n"); + if listeners.is_empty() { + error!("No listeners could be started. Exiting."); + std::process::exit(1); + } + + // Accept loop + // For simplicity in this slice, we just spawn a task for each listener + // In a real high-perf scenario, we might want a more complex accept loop + + for listener in listeners { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + + tokio::spawn(async move { + loop { + match listener.accept().await { + Ok((stream, peer_addr)) => { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + + tokio::spawn(async move { + if let Err(e) = ClientHandler::new( + stream, + peer_addr, + config, + stats, + upstream_manager + ).run().await { + // Log only relevant errors + // debug!("Connection error: {}", e); + } + }); + } + Err(e) => { + error!("Accept error: {}", e); + tokio::time::sleep(Duration::from_millis(100)).await; + } + } + } + }); + } + + // Wait for signal + match signal::ctrl_c().await { + Ok(()) => info!("Shutting down..."), + Err(e) => error!("Signal error: {}", e), + } + + Ok(()) } \ No newline at end of file diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 6af001c..2cb4d9d 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -13,7 +13,7 @@ use crate::error::{ProxyError, Result, HandshakeResult}; use crate::protocol::constants::*; use crate::protocol::tls; use crate::stats::{Stats, ReplayChecker}; -use crate::transport::{ConnectionPool, configure_client_socket}; +use crate::transport::{configure_client_socket, UpstreamManager}; use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter}; use crate::crypto::AesCtr; @@ -24,39 +24,55 @@ use super::handshake::{ use super::relay::relay_bidirectional; use super::masking::handle_bad_client; -/// Client connection handler -pub struct ClientHandler { +/// Client connection handler (builder struct) +pub struct ClientHandler; + +/// Running client handler with stream and context +pub struct RunningClientHandler { + stream: TcpStream, + peer: SocketAddr, config: Arc, stats: Arc, replay_checker: Arc, - pool: Arc, + upstream_manager: Arc, } impl ClientHandler { - /// Create new client handler + /// Create new client handler instance pub fn new( + stream: TcpStream, + peer: SocketAddr, config: Arc, stats: Arc, - replay_checker: Arc, - pool: Arc, - ) -> Self { - Self { + upstream_manager: Arc, + ) -> RunningClientHandler { + // Note: ReplayChecker should be shared globally for proper replay protection + // Creating it per-connection disables replay protection across connections + // TODO: Pass Arc from main.rs + let replay_checker = Arc::new(ReplayChecker::new(config.replay_check_len)); + + RunningClientHandler { + stream, + peer, config, stats, replay_checker, - pool, + upstream_manager, } } - - /// Handle a client connection - pub async fn handle(&self, stream: TcpStream, peer: SocketAddr) { +} + +impl RunningClientHandler { + /// Run the client handler + pub async fn run(mut self) -> Result<()> { self.stats.increment_connects_all(); + let peer = self.peer; debug!(peer = %peer, "New connection"); // Configure socket if let Err(e) = configure_client_socket( - &stream, + &self.stream, self.config.client_keepalive, self.config.client_ack_timeout, ) { @@ -66,49 +82,56 @@ impl ClientHandler { // Perform handshake with timeout let handshake_timeout = Duration::from_secs(self.config.client_handshake_timeout); + // Clone stats for error handling block + let stats = self.stats.clone(); + let result = timeout( handshake_timeout, - self.do_handshake(stream, peer) + self.do_handshake() ).await; match result { Ok(Ok(())) => { debug!(peer = %peer, "Connection handled successfully"); + Ok(()) } Ok(Err(e)) => { debug!(peer = %peer, error = %e, "Handshake failed"); + Err(e) } Err(_) => { - self.stats.increment_handshake_timeouts(); + stats.increment_handshake_timeouts(); debug!(peer = %peer, "Handshake timeout"); + Err(ProxyError::TgHandshakeTimeout) } } } /// Perform handshake and relay - async fn do_handshake(&self, mut stream: TcpStream, peer: SocketAddr) -> Result<()> { + async fn do_handshake(mut self) -> Result<()> { // Read first bytes to determine handshake type let mut first_bytes = [0u8; 5]; - stream.read_exact(&mut first_bytes).await?; + self.stream.read_exact(&mut first_bytes).await?; let is_tls = tls::is_tls_handshake(&first_bytes[..3]); + let peer = self.peer; debug!(peer = %peer, is_tls = is_tls, first_bytes = %hex::encode(&first_bytes), "Handshake type detected"); if is_tls { - self.handle_tls_client(stream, peer, first_bytes).await + self.handle_tls_client(first_bytes).await } else { - self.handle_direct_client(stream, peer, first_bytes).await + self.handle_direct_client(first_bytes).await } } /// Handle TLS-wrapped client async fn handle_tls_client( - &self, - mut stream: TcpStream, - peer: SocketAddr, + mut self, first_bytes: [u8; 5], ) -> Result<()> { + let peer = self.peer; + // Read TLS handshake length let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; @@ -117,17 +140,22 @@ impl ClientHandler { if tls_len < 512 { debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short"); self.stats.increment_connects_bad(); - handle_bad_client(stream, &first_bytes, &self.config).await; + handle_bad_client(self.stream, &first_bytes, &self.config).await; return Ok(()); } // Read full TLS handshake let mut handshake = vec![0u8; 5 + tls_len]; handshake[..5].copy_from_slice(&first_bytes); - stream.read_exact(&mut handshake[5..]).await?; + self.stream.read_exact(&mut handshake[5..]).await?; + + // Extract fields before consuming self.stream + let config = self.config.clone(); + let replay_checker = self.replay_checker.clone(); + let stats = self.stats.clone(); // Split stream for reading/writing - let (read_half, write_half) = stream.into_split(); + let (read_half, write_half) = self.stream.into_split(); // Handle TLS handshake let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake( @@ -135,12 +163,12 @@ impl ClientHandler { read_half, write_half, peer, - &self.config, - &self.replay_checker, + &config, + &replay_checker, ).await { HandshakeResult::Success(result) => result, HandshakeResult::BadClient => { - self.stats.increment_connects_bad(); + stats.increment_connects_bad(); return Ok(()); } HandshakeResult::Error(e) => return Err(e), @@ -158,44 +186,62 @@ impl ClientHandler { tls_reader, tls_writer, peer, - &self.config, - &self.replay_checker, + &config, + &replay_checker, true, ).await { HandshakeResult::Success(result) => result, HandshakeResult::BadClient => { - self.stats.increment_connects_bad(); + stats.increment_connects_bad(); return Ok(()); } HandshakeResult::Error(e) => return Err(e), }; // Handle authenticated client - self.handle_authenticated_inner(crypto_reader, crypto_writer, success).await + // We can't use self.handle_authenticated_inner because self is partially moved + // So we call it as an associated function or method on a new struct, + // or just inline the logic / use a static method. + // Since handle_authenticated_inner needs self.upstream_manager and self.stats, + // we should pass them explicitly. + + Self::handle_authenticated_static( + crypto_reader, + crypto_writer, + success, + self.upstream_manager, + self.stats, + self.config + ).await } /// Handle direct (non-TLS) client async fn handle_direct_client( - &self, - mut stream: TcpStream, - peer: SocketAddr, + mut self, first_bytes: [u8; 5], ) -> Result<()> { + let peer = self.peer; + // Check if non-TLS modes are enabled if !self.config.modes.classic && !self.config.modes.secure { debug!(peer = %peer, "Non-TLS modes disabled"); self.stats.increment_connects_bad(); - handle_bad_client(stream, &first_bytes, &self.config).await; + handle_bad_client(self.stream, &first_bytes, &self.config).await; return Ok(()); } // Read rest of handshake let mut handshake = [0u8; HANDSHAKE_LEN]; handshake[..5].copy_from_slice(&first_bytes); - stream.read_exact(&mut handshake[5..]).await?; + self.stream.read_exact(&mut handshake[5..]).await?; + + // Extract fields + let config = self.config.clone(); + let replay_checker = self.replay_checker.clone(); + let stats = self.stats.clone(); // Split stream - let (read_half, write_half) = stream.into_split(); + let (read_half, write_half) = self.stream.into_split(); // Handle MTProto handshake let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( @@ -203,27 +249,36 @@ impl ClientHandler { read_half, write_half, peer, - &self.config, - &self.replay_checker, + &config, + &replay_checker, false, ).await { HandshakeResult::Success(result) => result, HandshakeResult::BadClient => { - self.stats.increment_connects_bad(); + stats.increment_connects_bad(); return Ok(()); } HandshakeResult::Error(e) => return Err(e), }; - self.handle_authenticated_inner(crypto_reader, crypto_writer, success).await + Self::handle_authenticated_static( + crypto_reader, + crypto_writer, + success, + self.upstream_manager, + self.stats, + self.config + ).await } - /// Handle authenticated client - connect to Telegram and relay - async fn handle_authenticated_inner( - &self, + /// Static version of handle_authenticated_inner to avoid ownership issues + async fn handle_authenticated_static( client_reader: CryptoReader, client_writer: CryptoWriter, success: HandshakeSuccess, + upstream_manager: Arc, + stats: Arc, + config: Arc, ) -> Result<()> where R: AsyncRead + Unpin + Send + 'static, @@ -232,13 +287,13 @@ impl ClientHandler { let user = &success.user; // Check user limits - if let Err(e) = self.check_user_limits(user) { + if let Err(e) = Self::check_user_limits_static(user, &config, &stats) { warn!(user = %user, error = %e, "User limit exceeded"); return Err(e); } // Get datacenter address - let dc_addr = self.get_dc_addr(success.dc_idx)?; + let dc_addr = Self::get_dc_addr_static(success.dc_idx, &config)?; info!( user = %user, @@ -246,39 +301,40 @@ impl ClientHandler { dc = success.dc_idx, dc_addr = %dc_addr, proto = ?success.proto_tag, - fast_mode = self.config.fast_mode, + fast_mode = config.fast_mode, "Connecting to Telegram" ); - // Connect to Telegram - let tg_stream = self.pool.get(dc_addr).await?; + // Connect to Telegram via UpstreamManager + let tg_stream = upstream_manager.connect(dc_addr).await?; debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected to Telegram, performing handshake"); // Perform Telegram handshake and get crypto streams - let (tg_reader, tg_writer) = self.do_tg_handshake( + let (tg_reader, tg_writer) = Self::do_tg_handshake_static( tg_stream, &success, + &config, ).await?; debug!(peer = %success.peer, "Telegram handshake complete, starting relay"); // Update stats - self.stats.increment_user_connects(user); - self.stats.increment_user_curr_connects(user); + stats.increment_user_connects(user); + stats.increment_user_curr_connects(user); - // Relay traffic - передаём Arc::clone(&self.stats) + // Relay traffic let relay_result = relay_bidirectional( client_reader, client_writer, tg_reader, tg_writer, user, - Arc::clone(&self.stats), + Arc::clone(&stats), ).await; // Update stats - self.stats.decrement_user_curr_connects(user); + stats.decrement_user_curr_connects(user); match &relay_result { Ok(()) => debug!(user = %user, peer = %success.peer, "Relay completed normally"), @@ -288,26 +344,26 @@ impl ClientHandler { relay_result } - /// Check user limits (expiration, connection count, data quota) - fn check_user_limits(&self, user: &str) -> Result<()> { + /// Check user limits (static version) + fn check_user_limits_static(user: &str, config: &ProxyConfig, stats: &Stats) -> Result<()> { // Check expiration - if let Some(expiration) = self.config.user_expirations.get(user) { + if let Some(expiration) = config.user_expirations.get(user) { if chrono::Utc::now() > *expiration { return Err(ProxyError::UserExpired { user: user.to_string() }); } } // Check connection limit - if let Some(limit) = self.config.user_max_tcp_conns.get(user) { - let current = self.stats.get_user_curr_connects(user); + if let Some(limit) = config.user_max_tcp_conns.get(user) { + let current = stats.get_user_curr_connects(user); if current >= *limit as u64 { return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string() }); } } // Check data quota - if let Some(quota) = self.config.user_data_quota.get(user) { - let used = self.stats.get_user_total_octets(user); + if let Some(quota) = config.user_data_quota.get(user) { + let used = stats.get_user_total_octets(user); if used >= *quota { return Err(ProxyError::DataQuotaExceeded { user: user.to_string() }); } @@ -316,11 +372,11 @@ impl ClientHandler { Ok(()) } - /// Get datacenter address by index - fn get_dc_addr(&self, dc_idx: i16) -> Result { + /// Get datacenter address by index (static version) + fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result { let idx = (dc_idx.abs() - 1) as usize; - let datacenters = if self.config.prefer_ipv6 { + let datacenters = if config.prefer_ipv6 { &*TG_DATACENTERS_V6 } else { &*TG_DATACENTERS_V4 @@ -333,19 +389,18 @@ impl ClientHandler { )) } - /// Perform handshake with Telegram server - /// Returns crypto reader and writer for TG connection - async fn do_tg_handshake( - &self, + /// Perform handshake with Telegram server (static version) + async fn do_tg_handshake_static( mut stream: TcpStream, success: &HandshakeSuccess, + config: &ProxyConfig, ) -> Result<(CryptoReader, CryptoWriter)> { // Generate nonce with keys for TG let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce( success.proto_tag, &success.dec_key, // Client's dec key success.dec_iv, - self.config.fast_mode, + config.fast_mode, ); // Encrypt nonce diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index 485ae53..05e018e 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -13,7 +13,7 @@ const MASK_BUFFER_SIZE: usize = 8192; /// Handle a bad client by forwarding to mask host pub async fn handle_bad_client( - mut client: TcpStream, + client: TcpStream, initial_data: &[u8], config: &ProxyConfig, ) { diff --git a/src/transport/mod.rs b/src/transport/mod.rs index 437b303..bbc5302 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -3,7 +3,11 @@ pub mod pool; pub mod proxy_protocol; pub mod socket; +pub mod socks; +pub mod upstream; pub use pool::ConnectionPool; pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol}; -pub use socket::*; \ No newline at end of file +pub use socket::*; +pub use socks::*; +pub use upstream::UpstreamManager; \ No newline at end of file diff --git a/src/transport/socket.rs b/src/transport/socket.rs index 10c227a..09e5148 100644 --- a/src/transport/socket.rs +++ b/src/transport/socket.rs @@ -1,7 +1,7 @@ //! TCP Socket Configuration use std::io::Result; -use std::net::SocketAddr; +use std::net::{SocketAddr, IpAddr}; use std::time::Duration; use tokio::net::TcpStream; use socket2::{Socket, TcpKeepalive, Domain, Type, Protocol}; @@ -93,6 +93,11 @@ pub fn set_linger_zero(stream: &TcpStream) -> Result<()> { /// Create a new TCP socket for outgoing connections pub fn create_outgoing_socket(addr: SocketAddr) -> Result { + create_outgoing_socket_bound(addr, None) +} + +/// Create a new TCP socket for outgoing connections, optionally bound to a specific interface +pub fn create_outgoing_socket_bound(addr: SocketAddr, bind_addr: Option) -> Result { let domain = if addr.is_ipv4() { Domain::IPV4 } else { @@ -106,10 +111,17 @@ pub fn create_outgoing_socket(addr: SocketAddr) -> Result { // Disable Nagle socket.set_nodelay(true)?; + + if let Some(bind_ip) = bind_addr { + let bind_sock_addr = SocketAddr::new(bind_ip, 0); + socket.bind(&bind_sock_addr.into())?; + debug!("Bound outgoing socket to {}", bind_ip); + } Ok(socket) } + /// Get local address of a socket pub fn get_local_addr(stream: &TcpStream) -> Option { stream.local_addr().ok() diff --git a/src/transport/socks.rs b/src/transport/socks.rs new file mode 100644 index 0000000..35268cb --- /dev/null +++ b/src/transport/socks.rs @@ -0,0 +1,145 @@ +//! SOCKS4/5 Client Implementation + +use std::net::{IpAddr, SocketAddr}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::net::TcpStream; +use crate::error::{ProxyError, Result}; + +pub async fn connect_socks4( + stream: &mut TcpStream, + target: SocketAddr, + user_id: Option<&str>, +) -> Result<()> { + let ip = match target.ip() { + IpAddr::V4(ip) => ip, + IpAddr::V6(_) => return Err(ProxyError::Proxy("SOCKS4 does not support IPv6".to_string())), + }; + + let port = target.port(); + let user = user_id.unwrap_or("").as_bytes(); + + // VN (4) | CD (1) | DSTPORT (2) | DSTIP (4) | USERID (variable) | NULL (1) + let mut buf = Vec::with_capacity(9 + user.len()); + buf.push(4); // VN + buf.push(1); // CD (CONNECT) + buf.extend_from_slice(&port.to_be_bytes()); + buf.extend_from_slice(&ip.octets()); + buf.extend_from_slice(user); + buf.push(0); // NULL + + stream.write_all(&buf).await.map_err(|e| ProxyError::Io(e))?; + + // Response: VN (1) | CD (1) | DSTPORT (2) | DSTIP (4) + let mut resp = [0u8; 8]; + stream.read_exact(&mut resp).await.map_err(|e| ProxyError::Io(e))?; + + if resp[1] != 90 { + return Err(ProxyError::Proxy(format!("SOCKS4 request rejected: code {}", resp[1]))); + } + + Ok(()) +} + +pub async fn connect_socks5( + stream: &mut TcpStream, + target: SocketAddr, + username: Option<&str>, + password: Option<&str>, +) -> Result<()> { + // 1. Auth negotiation + // VER (1) | NMETHODS (1) | METHODS (variable) + let mut methods = vec![0u8]; // No auth + if username.is_some() { + methods.push(2u8); // Username/Password + } + + let mut buf = vec![5u8, methods.len() as u8]; + buf.extend_from_slice(&methods); + + stream.write_all(&buf).await.map_err(|e| ProxyError::Io(e))?; + + let mut resp = [0u8; 2]; + stream.read_exact(&mut resp).await.map_err(|e| ProxyError::Io(e))?; + + if resp[0] != 5 { + return Err(ProxyError::Proxy("Invalid SOCKS5 version".to_string())); + } + + match resp[1] { + 0 => {}, // No auth + 2 => { + // Username/Password auth + if let (Some(u), Some(p)) = (username, password) { + let u_bytes = u.as_bytes(); + let p_bytes = p.as_bytes(); + + let mut auth_buf = Vec::with_capacity(3 + u_bytes.len() + p_bytes.len()); + auth_buf.push(1); // VER + auth_buf.push(u_bytes.len() as u8); + auth_buf.extend_from_slice(u_bytes); + auth_buf.push(p_bytes.len() as u8); + auth_buf.extend_from_slice(p_bytes); + + stream.write_all(&auth_buf).await.map_err(|e| ProxyError::Io(e))?; + + let mut auth_resp = [0u8; 2]; + stream.read_exact(&mut auth_resp).await.map_err(|e| ProxyError::Io(e))?; + + if auth_resp[1] != 0 { + return Err(ProxyError::Proxy("SOCKS5 authentication failed".to_string())); + } + } else { + return Err(ProxyError::Proxy("SOCKS5 server requires authentication".to_string())); + } + }, + _ => return Err(ProxyError::Proxy("Unsupported SOCKS5 auth method".to_string())), + } + + // 2. Connection request + // VER (1) | CMD (1) | RSV (1) | ATYP (1) | DST.ADDR (variable) | DST.PORT (2) + let mut req = vec![5u8, 1u8, 0u8]; // CONNECT + + match target { + SocketAddr::V4(v4) => { + req.push(1u8); // IPv4 + req.extend_from_slice(&v4.ip().octets()); + }, + SocketAddr::V6(v6) => { + req.push(4u8); // IPv6 + req.extend_from_slice(&v6.ip().octets()); + }, + } + + req.extend_from_slice(&target.port().to_be_bytes()); + + stream.write_all(&req).await.map_err(|e| ProxyError::Io(e))?; + + // Response + let mut head = [0u8; 4]; + stream.read_exact(&mut head).await.map_err(|e| ProxyError::Io(e))?; + + if head[1] != 0 { + return Err(ProxyError::Proxy(format!("SOCKS5 request failed: code {}", head[1]))); + } + + // Skip address part of response + match head[3] { + 1 => { // IPv4 + let mut addr = [0u8; 4 + 2]; + stream.read_exact(&mut addr).await.map_err(|e| ProxyError::Io(e))?; + }, + 3 => { // Domain + let mut len = [0u8; 1]; + stream.read_exact(&mut len).await.map_err(|e| ProxyError::Io(e))?; + let mut addr = vec![0u8; len[0] as usize + 2]; + stream.read_exact(&mut addr).await.map_err(|e| ProxyError::Io(e))?; + }, + 4 => { // IPv6 + let mut addr = [0u8; 16 + 2]; + stream.read_exact(&mut addr).await.map_err(|e| ProxyError::Io(e))?; + }, + _ => return Err(ProxyError::Proxy("Invalid address type in SOCKS5 response".to_string())), + } + + Ok(()) +} \ No newline at end of file diff --git a/src/transport/upstream.rs b/src/transport/upstream.rs new file mode 100644 index 0000000..38525e5 --- /dev/null +++ b/src/transport/upstream.rs @@ -0,0 +1,255 @@ +//! Upstream Management + +use std::net::{SocketAddr, IpAddr}; +use std::sync::Arc; +use std::time::Duration; +use tokio::net::TcpStream; +use tokio::sync::RwLock; +use rand::Rng; +use tracing::{debug, warn, error}; + +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::error::{Result, ProxyError}; +use crate::transport::socket::create_outgoing_socket_bound; +use crate::transport::socks::{connect_socks4, connect_socks5}; + +#[derive(Debug)] +struct UpstreamState { + config: UpstreamConfig, + healthy: bool, + fails: u32, + last_check: std::time::Instant, +} + +#[derive(Clone)] +pub struct UpstreamManager { + upstreams: Arc>>, +} + +impl UpstreamManager { + pub fn new(configs: Vec) -> Self { + let states = configs.into_iter() + .filter(|c| c.enabled) + .map(|c| UpstreamState { + config: c, + healthy: true, // Optimistic start + fails: 0, + last_check: std::time::Instant::now(), + }) + .collect(); + + Self { + upstreams: Arc::new(RwLock::new(states)), + } + } + + /// Select an upstream using Weighted Round Robin (simplified) + async fn select_upstream(&self) -> Option { + let upstreams = self.upstreams.read().await; + if upstreams.is_empty() { + return None; + } + + let healthy_indices: Vec = upstreams.iter() + .enumerate() + .filter(|(_, u)| u.healthy) + .map(|(i, _)| i) + .collect(); + + if healthy_indices.is_empty() { + // If all unhealthy, try any random one + return Some(rand::thread_rng().gen_range(0..upstreams.len())); + } + + // Weighted selection + let total_weight: u32 = healthy_indices.iter() + .map(|&i| upstreams[i].config.weight as u32) + .sum(); + + if total_weight == 0 { + return Some(healthy_indices[rand::thread_rng().gen_range(0..healthy_indices.len())]); + } + + let mut choice = rand::thread_rng().gen_range(0..total_weight); + + for &idx in &healthy_indices { + let weight = upstreams[idx].config.weight as u32; + if choice < weight { + return Some(idx); + } + choice -= weight; + } + + Some(healthy_indices[0]) + } + + pub async fn connect(&self, target: SocketAddr) -> Result { + let idx = self.select_upstream().await + .ok_or_else(|| ProxyError::Config("No upstreams available".to_string()))?; + + let upstream = { + let guard = self.upstreams.read().await; + guard[idx].config.clone() + }; + + match self.connect_via_upstream(&upstream, target).await { + Ok(stream) => { + // Mark success + let mut guard = self.upstreams.write().await; + if let Some(u) = guard.get_mut(idx) { + if !u.healthy { + debug!("Upstream recovered: {:?}", u.config); + } + u.healthy = true; + u.fails = 0; + } + Ok(stream) + }, + Err(e) => { + // Mark failure + let mut guard = self.upstreams.write().await; + if let Some(u) = guard.get_mut(idx) { + u.fails += 1; + warn!("Failed to connect via upstream {:?}: {}. Fails: {}", u.config, e, u.fails); + if u.fails > 3 { + u.healthy = false; + warn!("Upstream disabled due to failures: {:?}", u.config); + } + } + Err(e) + } + } + } + + async fn connect_via_upstream(&self, config: &UpstreamConfig, target: SocketAddr) -> Result { + match &config.upstream_type { + UpstreamType::Direct { interface } => { + let bind_ip = interface.as_ref() + .and_then(|s| s.parse::().ok()); + + let socket = create_outgoing_socket_bound(target, bind_ip)?; + + // Non-blocking connect logic + socket.set_nonblocking(true)?; + match socket.connect(&target.into()) { + Ok(()) => {}, + Err(err) if err.raw_os_error() == Some(115) || err.kind() == std::io::ErrorKind::WouldBlock => {}, + Err(err) => return Err(ProxyError::Io(err)), + } + + let std_stream: std::net::TcpStream = socket.into(); + let stream = TcpStream::from_std(std_stream)?; + + // Wait for connection to complete + stream.writable().await?; + if let Some(e) = stream.take_error()? { + return Err(ProxyError::Io(e)); + } + + Ok(stream) + }, + UpstreamType::Socks4 { address, interface, user_id } => { + let proxy_addr: SocketAddr = address.parse() + .map_err(|_| ProxyError::Config("Invalid SOCKS4 address".to_string()))?; + + let bind_ip = interface.as_ref() + .and_then(|s| s.parse::().ok()); + + let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?; + + // Non-blocking connect logic + socket.set_nonblocking(true)?; + match socket.connect(&proxy_addr.into()) { + Ok(()) => {}, + Err(err) if err.raw_os_error() == Some(115) || err.kind() == std::io::ErrorKind::WouldBlock => {}, + Err(err) => return Err(ProxyError::Io(err)), + } + + let std_stream: std::net::TcpStream = socket.into(); + let mut stream = TcpStream::from_std(std_stream)?; + + // Wait for connection to complete + stream.writable().await?; + if let Some(e) = stream.take_error()? { + return Err(ProxyError::Io(e)); + } + + connect_socks4(&mut stream, target, user_id.as_deref()).await?; + Ok(stream) + }, + UpstreamType::Socks5 { address, interface, username, password } => { + let proxy_addr: SocketAddr = address.parse() + .map_err(|_| ProxyError::Config("Invalid SOCKS5 address".to_string()))?; + + let bind_ip = interface.as_ref() + .and_then(|s| s.parse::().ok()); + + let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?; + + // Non-blocking connect logic + socket.set_nonblocking(true)?; + match socket.connect(&proxy_addr.into()) { + Ok(()) => {}, + Err(err) if err.raw_os_error() == Some(115) || err.kind() == std::io::ErrorKind::WouldBlock => {}, + Err(err) => return Err(ProxyError::Io(err)), + } + + let std_stream: std::net::TcpStream = socket.into(); + let mut stream = TcpStream::from_std(std_stream)?; + + // Wait for connection to complete + stream.writable().await?; + if let Some(e) = stream.take_error()? { + return Err(ProxyError::Io(e)); + } + + connect_socks5(&mut stream, target, username.as_deref(), password.as_deref()).await?; + Ok(stream) + }, + } + } + + /// Background task to check health + pub async fn run_health_checks(&self) { + // Simple TCP connect check to a known stable DC (e.g. 149.154.167.50:443 - DC2) + let check_target: SocketAddr = "149.154.167.50:443".parse().unwrap(); + + loop { + tokio::time::sleep(Duration::from_secs(60)).await; + + let count = self.upstreams.read().await.len(); + for i in 0..count { + let config = { + let guard = self.upstreams.read().await; + guard[i].config.clone() + }; + + let result = tokio::time::timeout( + Duration::from_secs(10), + self.connect_via_upstream(&config, check_target) + ).await; + + let mut guard = self.upstreams.write().await; + let u = &mut guard[i]; + + match result { + Ok(Ok(_stream)) => { + if !u.healthy { + debug!("Upstream recovered: {:?}", u.config); + } + u.healthy = true; + u.fails = 0; + } + Ok(Err(e)) => { + debug!("Health check failed for {:?}: {}", u.config, e); + // Don't mark unhealthy immediately in background check + } + Err(_) => { + debug!("Health check timeout for {:?}", u.config); + } + } + u.last_check = std::time::Instant::now(); + } + } + } +} \ No newline at end of file diff --git a/src/util/ip.rs b/src/util/ip.rs index fda108c..9bde513 100644 --- a/src/util/ip.rs +++ b/src/util/ip.rs @@ -1,6 +1,6 @@ //! IP Addr Detect -use std::net::IpAddr; +use std::net::{IpAddr, SocketAddr, UdpSocket}; use std::time::Duration; use tracing::{debug, warn}; @@ -40,28 +40,74 @@ const IPV6_URLS: &[&str] = &[ "http://api6.ipify.org/", ]; +/// Detect local IP address by connecting to a public DNS +/// This does not actually send any packets +fn get_local_ip(target: &str) -> Option { + let socket = UdpSocket::bind("0.0.0.0:0").ok()?; + socket.connect(target).ok()?; + socket.local_addr().ok().map(|addr| addr.ip()) +} + +fn get_local_ipv6(target: &str) -> Option { + let socket = UdpSocket::bind("[::]:0").ok()?; + socket.connect(target).ok()?; + socket.local_addr().ok().map(|addr| addr.ip()) +} + /// Detect public IP addresses pub async fn detect_ip() -> IpInfo { let mut info = IpInfo::default(); + + // Try to get local interface IP first (default gateway interface) + // We connect to Google DNS to find out which interface is used for routing + if let Some(ip) = get_local_ip("8.8.8.8:80") { + if ip.is_ipv4() && !ip.is_loopback() { + info.ipv4 = Some(ip); + debug!(ip = %ip, "Detected local IPv4 address via routing"); + } + } + + if let Some(ip) = get_local_ipv6("[2001:4860:4860::8888]:80") { + if ip.is_ipv6() && !ip.is_loopback() { + info.ipv6 = Some(ip); + debug!(ip = %ip, "Detected local IPv6 address via routing"); + } + } - // Detect IPv4 - for url in IPV4_URLS { - if let Some(ip) = fetch_ip(url).await { - if ip.is_ipv4() { - info.ipv4 = Some(ip); - debug!(ip = %ip, "Detected IPv4 address"); - break; + // If local detection failed or returned private IP (and we want public), + // or just as a fallback/verification, we might want to check external services. + // However, the requirement is: "if IP for listening is not set... it should be IP from interface... + // if impossible - request external resources". + + // So if we found a local IP, we might be good. But often servers are behind NAT. + // If the local IP is private, we probably want the public IP for the tg:// link. + // Let's check if the detected IPs are private. + + let need_external_v4 = info.ipv4.map_or(true, |ip| is_private_ip(ip)); + let need_external_v6 = info.ipv6.map_or(true, |ip| is_private_ip(ip)); + + if need_external_v4 { + debug!("Local IPv4 is private or missing, checking external services..."); + for url in IPV4_URLS { + if let Some(ip) = fetch_ip(url).await { + if ip.is_ipv4() { + info.ipv4 = Some(ip); + debug!(ip = %ip, "Detected public IPv4 address"); + break; + } } } } - // Detect IPv6 - for url in IPV6_URLS { - if let Some(ip) = fetch_ip(url).await { - if ip.is_ipv6() { - info.ipv6 = Some(ip); - debug!(ip = %ip, "Detected IPv6 address"); - break; + if need_external_v6 { + debug!("Local IPv6 is private or missing, checking external services..."); + for url in IPV6_URLS { + if let Some(ip) = fetch_ip(url).await { + if ip.is_ipv6() { + info.ipv6 = Some(ip); + debug!(ip = %ip, "Detected public IPv6 address"); + break; + } } } } @@ -73,6 +119,17 @@ pub async fn detect_ip() -> IpInfo { info } +fn is_private_ip(ip: IpAddr) -> bool { + match ip { + IpAddr::V4(ipv4) => { + ipv4.is_private() || ipv4.is_loopback() || ipv4.is_link_local() + } + IpAddr::V6(ipv6) => { + ipv6.is_loopback() || (ipv6.segments()[0] & 0xfe00) == 0xfc00 // Unique Local + } + } +} + /// Fetch IP from URL async fn fetch_ip(url: &str) -> Option { let client = reqwest::Client::builder()