Unix socket listener + reverse proxy improvements

This commit is contained in:
Жора Змейкин
2026-02-14 02:11:13 +03:00
parent 4b5270137b
commit 572e07a7fd
9 changed files with 487 additions and 69 deletions

2
Cargo.lock generated
View File

@@ -1723,7 +1723,7 @@ dependencies = [
[[package]] [[package]]
name = "telemt" name = "telemt"
version = "1.2.0" version = "2.0.0"
dependencies = [ dependencies = [
"aes", "aes",
"base64", "base64",

View File

@@ -164,10 +164,6 @@ then Ctrl+X -> Y -> Enter to save
## Configuration ## Configuration
### Minimal Configuration for First Start ### Minimal Configuration for First Start
```toml ```toml
# === UI ===
# Users to show in the startup log (tg:// links)
show_link = ["hello"]
# === General Settings === # === General Settings ===
[general] [general]
prefer_ipv6 = false prefer_ipv6 = false
@@ -185,9 +181,17 @@ tls = true
port = 443 port = 443
listen_addr_ipv4 = "0.0.0.0" listen_addr_ipv4 = "0.0.0.0"
listen_addr_ipv6 = "::" listen_addr_ipv6 = "::"
# listen_unix_sock = "/var/run/telemt.sock" # Unix socket
# listen_unix_sock_perm = "0666" # Socket file permissions
# metrics_port = 9090 # metrics_port = 9090
# metrics_whitelist = ["127.0.0.1", "::1"] # metrics_whitelist = ["127.0.0.1", "::1"]
# Users to show in the startup log (tg:// links)
[general.links]
show = ["hello"]
# public_host = "proxy.example.com" # Host (IP or domain) for tg:// links
# public_port = 443 # Port for tg:// links (default: server.port)
# Listen on multiple interfaces/IPs (overrides listen_addr_*) # Listen on multiple interfaces/IPs (overrides listen_addr_*)
[[server.listeners]] [[server.listeners]]
ip = "0.0.0.0" ip = "0.0.0.0"

View File

@@ -1,7 +1,3 @@
# === UI ===
# Users to show in the startup log (tg:// links)
show_link = ["hello"]
# === General Settings === # === General Settings ===
[general] [general]
prefer_ipv6 = true prefer_ipv6 = true
@@ -24,9 +20,17 @@ tls = true
port = 443 port = 443
listen_addr_ipv4 = "0.0.0.0" listen_addr_ipv4 = "0.0.0.0"
listen_addr_ipv6 = "::" listen_addr_ipv6 = "::"
# listen_unix_sock = "/var/run/telemt.sock" # Unix socket
# listen_unix_sock_perm = "0666" # Socket file permissions
# metrics_port = 9090 # metrics_port = 9090
# metrics_whitelist = ["127.0.0.1", "::1"] # metrics_whitelist = ["127.0.0.1", "::1"]
# Users to show in the startup log (tg:// links)
[general.links]
show = ["hello"]
# public_host = "proxy.example.com" # Host (IP or domain) for tg:// links
# public_port = 443 # Port for tg:// links (default: server.port)
# Listen on multiple interfaces/IPs (overrides listen_addr_*) # Listen on multiple interfaces/IPs (overrides listen_addr_*)
[[server.listeners]] [[server.listeners]]
ip = "0.0.0.0" ip = "0.0.0.0"

View File

@@ -186,8 +186,6 @@ fn generate_config(username: &str, secret: &str, port: u16, domain: &str) -> Str
r#"# Telemt MTProxy — auto-generated config r#"# Telemt MTProxy — auto-generated config
# Re-run `telemt --init` to regenerate # Re-run `telemt --init` to regenerate
show_link = ["{username}"]
[general] [general]
prefer_ipv6 = false prefer_ipv6 = false
fast_mode = true fast_mode = true
@@ -199,10 +197,17 @@ classic = false
secure = false secure = false
tls = true tls = true
[general.links]
show = ["{username}"]
# public_host = "proxy.example.com"
# public_port = 443
[server] [server]
port = {port} port = {port}
listen_addr_ipv4 = "0.0.0.0" listen_addr_ipv4 = "0.0.0.0"
listen_addr_ipv6 = "::" listen_addr_ipv6 = "::"
# listen_unix_sock = "/var/run/telemt.sock"
# listen_unix_sock_perm = "0666"
[[server.listeners]] [[server.listeners]]
ip = "0.0.0.0" ip = "0.0.0.0"
@@ -220,6 +225,8 @@ client_ack = 300
tls_domain = "{domain}" tls_domain = "{domain}"
mask = true mask = true
mask_port = 443 mask_port = 443
# mask_host = "{domain}"
# mask_unix_sock = "/var/run/nginx.sock"
fake_cert_len = 2048 fake_cert_len = 2048
[access] [access]

View File

@@ -4,7 +4,7 @@ use crate::error::{ProxyError, Result};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr}; use std::net::IpAddr;
use std::path::Path; use std::path::Path;
// ============= Helper Defaults ============= // ============= Helper Defaults =============
@@ -39,9 +39,6 @@ fn default_keepalive() -> u64 {
fn default_ack_timeout() -> u64 { fn default_ack_timeout() -> u64 {
300 300
} }
fn default_listen_addr() -> String {
"0.0.0.0".to_string()
}
fn default_fake_cert_len() -> usize { fn default_fake_cert_len() -> usize {
2048 2048
} }
@@ -164,6 +161,26 @@ pub struct GeneralConfig {
#[serde(default)] #[serde(default)]
pub log_level: LogLevel, pub log_level: LogLevel,
#[serde(default)]
pub links: LinksConfig,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LinksConfig {
/// Users whose tg:// links to show at startup.
#[serde(default)]
pub show: Vec<String>,
/// Public host (IP or domain) for tg:// link generation.
/// Overrides announce_ip / detected IP in links.
#[serde(default)]
pub public_host: Option<String>,
/// Public port for tg:// link generation.
/// Overrides server.port in links.
#[serde(default)]
pub public_port: Option<u16>,
} }
impl Default for GeneralConfig { impl Default for GeneralConfig {
@@ -179,6 +196,7 @@ impl Default for GeneralConfig {
middle_proxy_nat_probe: false, middle_proxy_nat_probe: false,
middle_proxy_nat_stun: None, middle_proxy_nat_stun: None,
log_level: LogLevel::Normal, log_level: LogLevel::Normal,
links: LinksConfig::default(),
} }
} }
} }
@@ -188,8 +206,8 @@ pub struct ServerConfig {
#[serde(default = "default_port")] #[serde(default = "default_port")]
pub port: u16, pub port: u16,
#[serde(default = "default_listen_addr")] #[serde(default)]
pub listen_addr_ipv4: String, pub listen_addr_ipv4: Option<String>,
#[serde(default)] #[serde(default)]
pub listen_addr_ipv6: Option<String>, pub listen_addr_ipv6: Option<String>,
@@ -197,6 +215,11 @@ pub struct ServerConfig {
#[serde(default)] #[serde(default)]
pub listen_unix_sock: Option<String>, pub listen_unix_sock: Option<String>,
/// Unix socket file permissions (octal string, e.g. "0666").
/// Applied after bind. If not set, inherits from process umask.
#[serde(default)]
pub listen_unix_sock_perm: Option<String>,
#[serde(default)] #[serde(default)]
pub metrics_port: Option<u16>, pub metrics_port: Option<u16>,
@@ -211,9 +234,10 @@ impl Default for ServerConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
port: default_port(), port: default_port(),
listen_addr_ipv4: default_listen_addr(), listen_addr_ipv4: None,
listen_addr_ipv6: Some("::".to_string()), listen_addr_ipv6: Some("::".to_string()),
listen_unix_sock: None, listen_unix_sock: None,
listen_unix_sock_perm: None,
metrics_port: None, metrics_port: None,
metrics_whitelist: default_metrics_whitelist(), metrics_whitelist: default_metrics_whitelist(),
listeners: Vec::new(), listeners: Vec::new(),
@@ -502,15 +526,26 @@ pub struct ProxyConfig {
/// If not set, defaults to 2 (matching Telegram's official `default 2;` in proxy-multi.conf). /// If not set, defaults to 2 (matching Telegram's official `default 2;` in proxy-multi.conf).
#[serde(default)] #[serde(default)]
pub default_dc: Option<u8>, pub default_dc: Option<u8>,
/// Non-fatal warnings collected during config loading.
#[serde(skip)]
pub warnings: Vec<String>,
} }
impl ProxyConfig { impl ProxyConfig {
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> { pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
let content = let content = std::fs::read_to_string(path)
std::fs::read_to_string(path).map_err(|e| ProxyError::Config(e.to_string()))?; .map_err(|e| ProxyError::Config(e.to_string()))?;
let mut config: ProxyConfig = // Pre-parse raw TOML to detect defaulted fields
toml::from_str(&content).map_err(|e| ProxyError::Config(e.to_string()))?; let raw: toml::Value = toml::from_str(&content)
.map_err(|e| ProxyError::Config(e.to_string()))?;
let port_explicit = raw.get("server")
.and_then(|s| s.get("port"))
.is_some();
let mut config: ProxyConfig = toml::from_str(&content)
.map_err(|e| ProxyError::Config(e.to_string()))?;
// Validate secrets // Validate secrets
for (user, secret) in &config.access.users { for (user, secret) in &config.access.users {
@@ -562,15 +597,51 @@ impl ProxyConfig {
use rand::Rng; use rand::Rng;
config.censorship.fake_cert_len = rand::rng().gen_range(1024..4096); config.censorship.fake_cert_len = rand::rng().gen_range(1024..4096);
// Migration: Populate listeners if empty // Validate listen_unix_sock
if config.server.listeners.is_empty() { if let Some(ref sock_path) = config.server.listen_unix_sock {
if let Ok(ipv4) = config.server.listen_addr_ipv4.parse::<IpAddr>() { if sock_path.is_empty() {
config.server.listeners.push(ListenerConfig { return Err(ProxyError::Config(
ip: ipv4, "listen_unix_sock cannot be empty".to_string()
announce_ip: None, ));
});
} }
if let Some(ipv6_str) = &config.server.listen_addr_ipv6 { #[cfg(unix)]
if sock_path.len() > 107 {
return Err(ProxyError::Config(
format!("listen_unix_sock path too long: {} bytes (max 107)", sock_path.len())
));
}
#[cfg(not(unix))]
return Err(ProxyError::Config(
"listen_unix_sock is only supported on Unix platforms".to_string()
));
}
// Validate listen_unix_sock_perm
if let Some(ref perm_str) = config.server.listen_unix_sock_perm {
if config.server.listen_unix_sock.is_none() {
return Err(ProxyError::Config(
"listen_unix_sock_perm requires listen_unix_sock to be set".to_string()
));
}
u32::from_str_radix(perm_str, 8).map_err(|_| {
ProxyError::Config(format!(
"listen_unix_sock_perm must be an octal string (e.g. \"0666\"), got \"{}\"",
perm_str
))
})?;
}
// Migration: Populate listeners from legacy listen_addr_* fields.
if config.server.listeners.is_empty() {
if let Some(ref ipv4_str) = config.server.listen_addr_ipv4 {
if let Ok(ipv4) = ipv4_str.parse::<IpAddr>() {
config.server.listeners.push(ListenerConfig {
ip: ipv4,
announce_ip: None,
});
}
}
if let Some(ref ipv6_str) = config.server.listen_addr_ipv6 {
if let Ok(ipv6) = ipv6_str.parse::<IpAddr>() { if let Ok(ipv6) = ipv6_str.parse::<IpAddr>() {
config.server.listeners.push(ListenerConfig { config.server.listeners.push(ListenerConfig {
ip: ipv6, ip: ipv6,
@@ -580,6 +651,20 @@ impl ProxyConfig {
} }
} }
// Validate: at least one listen endpoint must be configured.
if config.server.listeners.is_empty() && config.server.listen_unix_sock.is_none() {
return Err(ProxyError::Config(
"No listen address configured. Set [[server.listeners]], listen_addr_ipv4, or listen_unix_sock".to_string()
));
}
// Migration: show_link → general.links.show
if !config.show_link.is_empty() && config.general.links.show.is_empty() {
let migrated = config.show_link.resolve_users(&config.access.users)
.into_iter().cloned().collect::<Vec<_>>();
config.general.links.show = migrated;
}
// Migration: Populate upstreams if empty (Default Direct) // Migration: Populate upstreams if empty (Default Direct)
if config.upstreams.is_empty() { if config.upstreams.is_empty() {
config.upstreams.push(UpstreamConfig { config.upstreams.push(UpstreamConfig {
@@ -589,6 +674,20 @@ impl ProxyConfig {
}); });
} }
// Warnings for defaulted fields
if !config.server.listeners.is_empty() && !port_explicit {
config.warnings.push(format!(
"[server] port is not set; defaulting to {}",
config.server.port
));
}
if config.server.listen_unix_sock.is_some() && config.general.links.public_port.is_none() {
config.warnings.push(format!(
"[general.links] public_port is not set; using [server] port {} for tg:// links",
config.server.port
));
}
Ok(config) Ok(config)
} }

View File

@@ -4,6 +4,8 @@ use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::net::TcpListener; use tokio::net::TcpListener;
#[cfg(unix)]
use tokio::net::UnixListener;
use tokio::signal; use tokio::signal;
use tokio::sync::Semaphore; use tokio::sync::Semaphore;
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
@@ -20,9 +22,11 @@ mod stream;
mod transport; mod transport;
mod util; mod util;
use crate::config::{LogLevel, ProxyConfig}; use crate::config::{ProxyConfig, LogLevel};
use crate::proxy::{ClientHandler, handle_client_stream};
#[cfg(unix)]
use crate::transport::{create_unix_listener, cleanup_unix_socket};
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use crate::proxy::ClientHandler;
use crate::stats::{ReplayChecker, Stats}; use crate::stats::{ReplayChecker, Stats};
use crate::stream::BufferPool; use crate::stream::BufferPool;
use crate::transport::middle_proxy::MePool; use crate::transport::middle_proxy::MePool;
@@ -97,6 +101,31 @@ fn parse_cli() -> (String, bool, Option<String>) {
(config_path, silent, log_level) (config_path, silent, log_level)
} }
fn print_proxy_links(host: &str, port: u16, config: &ProxyConfig) {
info!("--- Proxy Links ({}) ---", host);
for user_name in &config.general.links.show {
if let Some(secret) = config.access.users.get(user_name) {
info!("User: {}", user_name);
if config.general.modes.classic {
info!(" Classic: tg://proxy?server={}&port={}&secret={}",
host, port, secret);
}
if config.general.modes.secure {
info!(" DD: tg://proxy?server={}&port={}&secret=dd{}",
host, port, secret);
}
if config.general.modes.tls {
let domain_hex = hex::encode(&config.censorship.tls_domain);
info!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}",
host, port, secret, domain_hex);
}
} else {
warn!("User '{}' listed in [general.links] show not found in [access.users]", user_name);
}
}
info!("------------------------");
}
#[tokio::main] #[tokio::main]
async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> { async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let (config_path, cli_silent, cli_log_level) = parse_cli(); let (config_path, cli_silent, cli_log_level) = parse_cli();
@@ -168,6 +197,10 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
warn!("Using default tls_domain. Consider setting a custom domain."); warn!("Using default tls_domain. Consider setting a custom domain.");
} }
for w in &config.warnings {
warn!("{}", w);
}
let prefer_ipv6 = config.general.prefer_ipv6; let prefer_ipv6 = config.general.prefer_ipv6;
let use_middle_proxy = config.general.use_middle_proxy; let use_middle_proxy = config.general.use_middle_proxy;
let config = Arc::new(config); let config = Arc::new(config);
@@ -396,35 +429,12 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
listener_conf.ip listener_conf.ip
}; };
if !config.show_link.is_empty() { // Per-listener links (only when public_host is NOT set)
info!("--- Proxy Links ({}) ---", public_ip); let links = &config.general.links;
for user_name in config.show_link.resolve_users(&config.access.users) { if links.public_host.is_none() && !links.show.is_empty() {
if let Some(secret) = config.access.users.get(user_name) { let link_host = public_ip.to_string();
info!("User: {}", user_name); let link_port = links.public_port.unwrap_or(config.server.port);
if config.general.modes.classic { print_proxy_links(&link_host, link_port, &config);
info!(
" Classic: tg://proxy?server={}&port={}&secret={}",
public_ip, config.server.port, secret
);
}
if config.general.modes.secure {
info!(
" DD: tg://proxy?server={}&port={}&secret=dd{}",
public_ip, config.server.port, secret
);
}
if config.general.modes.tls {
let domain_hex = hex::encode(&config.censorship.tls_domain);
info!(
" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}",
public_ip, config.server.port, secret, domain_hex
);
}
} else {
warn!("User '{}' in show_link not found", user_name);
}
}
info!("------------------------");
} }
listeners.push(listener); listeners.push(listener);
@@ -435,9 +445,109 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
} }
} }
// Unix socket listener
#[cfg(unix)]
let unix_sock_path = if let Some(ref unix_path) = config.server.listen_unix_sock {
match create_unix_listener(unix_path) {
Ok(std_listener) => {
// Set socket file permissions if configured
if let Some(ref perm_str) = config.server.listen_unix_sock_perm {
if let Ok(mode) = u32::from_str_radix(perm_str, 8) {
use std::os::unix::fs::PermissionsExt;
std::fs::set_permissions(
unix_path,
std::fs::Permissions::from_mode(mode),
)?;
}
}
let unix_listener = UnixListener::from_std(std_listener)?;
info!("Listening on unix:{}", unix_path);
let config = config.clone();
let stats = stats.clone();
let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone();
let rng = rng.clone();
let me_pool = me_pool.clone();
let unix_conn_counter = std::sync::Arc::new(
std::sync::atomic::AtomicU64::new(1)
);
tokio::spawn(async move {
loop {
match unix_listener.accept().await {
Ok((stream, _unix_addr)) => {
let conn_id = unix_conn_counter.fetch_add(
1, std::sync::atomic::Ordering::Relaxed
);
let fake_peer = SocketAddr::from(([127, 0, 0, 1], conn_id as u16));
let config = config.clone();
let stats = stats.clone();
let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone();
let rng = rng.clone();
let me_pool = me_pool.clone();
tokio::spawn(async move {
if let Err(e) = handle_client_stream(
stream, fake_peer, config, stats,
upstream_manager, replay_checker, buffer_pool, rng,
me_pool,
).await {
debug!(error = %e, "Unix socket connection error");
}
});
}
Err(e) => {
error!("Unix socket accept error: {}", e);
tokio::time::sleep(Duration::from_millis(100)).await;
}
}
}
});
Some(unix_path.clone())
}
Err(e) => {
error!("Failed to bind to unix:{}: {}", unix_path, e);
None
}
}
} else {
None
};
// Links with explicit public_host (independent of TCP listeners)
let links = &config.general.links;
if let Some(ref public_host) = links.public_host {
if !links.show.is_empty() {
let link_port = links.public_port.unwrap_or(config.server.port);
print_proxy_links(public_host, link_port, &config);
}
}
// Warn if links were configured but couldn't be shown
// (no TCP listeners succeeded and no public_host set)
let links = &config.general.links;
if listeners.is_empty() && links.public_host.is_none() && !links.show.is_empty() {
warn!("Proxy links not shown: no TCP listeners bound. Set [general.links] public_host or fix listener errors above.");
}
if listeners.is_empty() { if listeners.is_empty() {
error!("No listeners. Exiting."); #[cfg(unix)]
std::process::exit(1); if unix_sock_path.is_none() {
error!("No listeners. Exiting.");
std::process::exit(1);
}
#[cfg(not(unix))]
{
error!("No listeners. Exiting.");
std::process::exit(1);
}
} }
// Switch to user-configured log level after startup // Switch to user-configured log level after startup
@@ -500,7 +610,13 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
} }
match signal::ctrl_c().await { match signal::ctrl_c().await {
Ok(()) => info!("Shutting down..."), Ok(()) => {
info!("Shutting down...");
#[cfg(unix)]
if let Some(ref path) = unix_sock_path {
cleanup_unix_socket(path);
}
}
Err(e) => error!("Signal error: {}", e), Err(e) => error!("Signal error: {}", e),
} }

View File

@@ -23,6 +23,149 @@ use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle
use crate::proxy::masking::handle_bad_client; use crate::proxy::masking::handle_bad_client;
use crate::proxy::middle_relay::handle_via_middle_proxy; use crate::proxy::middle_relay::handle_via_middle_proxy;
/// Handle a client connection from any stream type (TCP, Unix socket)
///
/// This is the generic entry point for client handling. Unlike `ClientHandler::new().run()`,
/// it skips TCP-specific socket configuration (TCP_NODELAY, keepalive, TCP_USER_TIMEOUT)
/// which is appropriate for non-TCP streams like Unix sockets.
pub async fn handle_client_stream<S>(
mut stream: S,
peer: SocketAddr,
config: Arc<ProxyConfig>,
stats: Arc<Stats>,
upstream_manager: Arc<UpstreamManager>,
replay_checker: Arc<ReplayChecker>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>,
) -> Result<()>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
stats.increment_connects_all();
debug!(peer = %peer, "New connection (generic stream)");
let handshake_timeout = Duration::from_secs(config.timeouts.client_handshake);
let stats_for_timeout = stats.clone();
// For non-TCP streams, use a synthetic local address
let local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port)
.parse()
.unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap());
let result = timeout(handshake_timeout, async {
let mut first_bytes = [0u8; 5];
stream.read_exact(&mut first_bytes).await?;
let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
debug!(peer = %peer, is_tls = is_tls, "Handshake type detected");
if is_tls {
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
if tls_len < 512 {
debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short");
stats.increment_connects_bad();
let (reader, writer) = tokio::io::split(stream);
handle_bad_client(reader, writer, &first_bytes, &config).await;
return Ok(());
}
let mut handshake = vec![0u8; 5 + tls_len];
handshake[..5].copy_from_slice(&first_bytes);
stream.read_exact(&mut handshake[5..]).await?;
let (read_half, write_half) = tokio::io::split(stream);
let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake(
&handshake, read_half, write_half, peer,
&config, &replay_checker, &rng,
).await {
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &config).await;
return Ok(());
}
HandshakeResult::Error(e) => return Err(e),
};
debug!(peer = %peer, "Reading MTProto handshake through TLS");
let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?;
let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into()
.map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?;
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
&mtproto_handshake, tls_reader, tls_writer, peer,
&config, &replay_checker, true,
).await {
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader: _, writer: _ } => {
stats.increment_connects_bad();
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake");
return Ok(());
}
HandshakeResult::Error(e) => return Err(e),
};
RunningClientHandler::handle_authenticated_static(
crypto_reader, crypto_writer, success,
upstream_manager, stats, config, buffer_pool, rng, me_pool,
local_addr,
).await
} else {
if !config.general.modes.classic && !config.general.modes.secure {
debug!(peer = %peer, "Non-TLS modes disabled");
stats.increment_connects_bad();
let (reader, writer) = tokio::io::split(stream);
handle_bad_client(reader, writer, &first_bytes, &config).await;
return Ok(());
}
let mut handshake = [0u8; HANDSHAKE_LEN];
handshake[..5].copy_from_slice(&first_bytes);
stream.read_exact(&mut handshake[5..]).await?;
let (read_half, write_half) = tokio::io::split(stream);
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
&handshake, read_half, write_half, peer,
&config, &replay_checker, false,
).await {
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &config).await;
return Ok(());
}
HandshakeResult::Error(e) => return Err(e),
};
RunningClientHandler::handle_authenticated_static(
crypto_reader, crypto_writer, success,
upstream_manager, stats, config, buffer_pool, rng, me_pool,
local_addr,
).await
}
}).await;
match result {
Ok(Ok(())) => {
debug!(peer = %peer, "Connection handled successfully");
Ok(())
}
Ok(Err(e)) => {
debug!(peer = %peer, error = %e, "Handshake failed");
Err(e)
}
Err(_) => {
stats_for_timeout.increment_handshake_timeouts();
debug!(peer = %peer, "Handshake timeout");
Err(ProxyError::TgHandshakeTimeout)
}
}
}
pub struct ClientHandler; pub struct ClientHandler;
pub struct RunningClientHandler { pub struct RunningClientHandler {
@@ -269,7 +412,7 @@ impl RunningClientHandler {
/// Two modes: /// Two modes:
/// - Direct: TCP relay to TG DC (existing behavior) /// - Direct: TCP relay to TG DC (existing behavior)
/// - Middle Proxy: RPC multiplex through ME pool (new — supports CDN DCs) /// - Middle Proxy: RPC multiplex through ME pool (new — supports CDN DCs)
async fn handle_authenticated_static<R, W>( pub(crate) async fn handle_authenticated_static<R, W>(
client_reader: CryptoReader<R>, client_reader: CryptoReader<R>,
client_writer: CryptoWriter<W>, client_writer: CryptoWriter<W>,
success: HandshakeSuccess, success: HandshakeSuccess,

View File

@@ -1,13 +1,13 @@
//! Proxy Defs //! Proxy Defs
pub mod client; pub mod client;
pub mod direct_relay; pub(crate) mod direct_relay;
pub mod handshake; pub mod handshake;
pub mod masking; pub mod masking;
pub mod middle_relay; pub(crate) mod middle_relay;
pub mod relay; pub mod relay;
pub use client::ClientHandler; pub use client::{ClientHandler, handle_client_stream};
pub use handshake::*; pub use handshake::*;
pub use masking::*; pub use masking::*;
pub use relay::*; pub use relay::*;

View File

@@ -202,6 +202,51 @@ pub fn create_listener(addr: SocketAddr, options: &ListenOptions) -> Result<Sock
Ok(socket) Ok(socket)
} }
/// Create a Unix socket listener with stale socket detection.
///
/// If the socket file already exists, attempts to connect to it:
/// - If connection succeeds → another instance is running → returns AddrInUse error
/// - If connection fails → stale socket → removes it and binds
#[cfg(unix)]
pub fn create_unix_listener(path: &str) -> Result<std::os::unix::net::UnixListener> {
use std::os::unix::net::UnixListener;
use std::path::Path;
let socket_path = Path::new(path);
if socket_path.exists() {
match std::os::unix::net::UnixStream::connect(socket_path) {
Ok(_) => {
return Err(std::io::Error::new(
std::io::ErrorKind::AddrInUse,
format!("Unix socket {} is already in use by another process", path)
));
}
Err(_) => {
debug!("Removing stale Unix socket: {}", path);
std::fs::remove_file(socket_path)?;
}
}
}
let listener = UnixListener::bind(socket_path)?;
listener.set_nonblocking(true)?;
debug!("Created Unix socket listener at {}", path);
Ok(listener)
}
/// Remove Unix socket file on shutdown
#[cfg(unix)]
pub fn cleanup_unix_socket(path: &str) {
if std::path::Path::new(path).exists() {
match std::fs::remove_file(path) {
Ok(_) => debug!("Cleaned up Unix socket: {}", path),
Err(e) => debug!("Failed to remove Unix socket {}: {}", path, e),
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;