Tschuss Status Quo - Hallo, Zukunft!
This commit is contained in:
Alexey
2025-12-30 05:08:05 +03:00
parent 44169441b4
commit 3d9150a074
33 changed files with 6079 additions and 0 deletions

9
src/transport/mod.rs Normal file
View File

@@ -0,0 +1,9 @@
//! Transport layer: connection pooling, socket utilities, proxy protocol
pub mod pool;
pub mod proxy_protocol;
pub mod socket;
pub use pool::ConnectionPool;
pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol};
pub use socket::*;

338
src/transport/pool.rs Normal file
View File

@@ -0,0 +1,338 @@
//! Connection Pool
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio::time::timeout;
use parking_lot::RwLock;
use tracing::{debug, warn};
use crate::error::{ProxyError, Result};
use super::socket::configure_tcp_socket;
/// A pooled connection with metadata
struct PooledConnection {
stream: TcpStream,
created_at: Instant,
}
/// Internal pool state for a single endpoint
struct PoolInner {
/// Available connections
connections: Vec<PooledConnection>,
/// Number of connections being established
pending: usize,
}
impl PoolInner {
fn new() -> Self {
Self {
connections: Vec::new(),
pending: 0,
}
}
}
/// Connection pool configuration
#[derive(Debug, Clone)]
pub struct PoolConfig {
/// Maximum connections per endpoint
pub max_connections: usize,
/// Connection timeout
pub connect_timeout: Duration,
/// Maximum idle time before connection is dropped
pub max_idle_time: Duration,
/// Enable TCP keepalive
pub keepalive: bool,
/// Keepalive interval
pub keepalive_interval: Duration,
}
impl Default for PoolConfig {
fn default() -> Self {
Self {
max_connections: 64,
connect_timeout: Duration::from_secs(10),
max_idle_time: Duration::from_secs(60),
keepalive: true,
keepalive_interval: Duration::from_secs(40),
}
}
}
/// Thread-safe connection pool
pub struct ConnectionPool {
/// Per-endpoint pools
pools: RwLock<HashMap<SocketAddr, Arc<Mutex<PoolInner>>>>,
/// Configuration
config: PoolConfig,
}
impl ConnectionPool {
/// Create new connection pool with default config
pub fn new() -> Self {
Self::with_config(PoolConfig::default())
}
/// Create connection pool with custom config
pub fn with_config(config: PoolConfig) -> Self {
Self {
pools: RwLock::new(HashMap::new()),
config,
}
}
/// Get or create pool for an endpoint
fn get_or_create_pool(&self, addr: SocketAddr) -> Arc<Mutex<PoolInner>> {
// Fast path with read lock
{
let pools = self.pools.read();
if let Some(pool) = pools.get(&addr) {
return Arc::clone(pool);
}
}
// Slow path with write lock
let mut pools = self.pools.write();
pools.entry(addr)
.or_insert_with(|| Arc::new(Mutex::new(PoolInner::new())))
.clone()
}
/// Get a connection to the specified address
pub async fn get(&self, addr: SocketAddr) -> Result<TcpStream> {
let pool = self.get_or_create_pool(addr);
// Try to get an existing connection
{
let mut inner = pool.lock().await;
// Remove stale connections
let now = Instant::now();
inner.connections.retain(|c| {
now.duration_since(c.created_at) < self.config.max_idle_time
});
// Try to find a usable connection
while let Some(conn) = inner.connections.pop() {
// Check if connection is still alive
if is_connection_alive(&conn.stream) {
debug!(addr = %addr, "Reusing pooled connection");
return Ok(conn.stream);
}
debug!(addr = %addr, "Discarding dead pooled connection");
}
// Check if we can create a new connection
let total = inner.connections.len() + inner.pending;
if total >= self.config.max_connections {
return Err(ProxyError::ConnectionTimeout {
addr: addr.to_string()
});
}
inner.pending += 1;
}
// Create new connection
debug!(addr = %addr, "Creating new connection");
let result = self.create_connection(addr).await;
// Decrement pending count
{
let mut inner = pool.lock().await;
inner.pending = inner.pending.saturating_sub(1);
}
result
}
/// Create a new connection to the address
async fn create_connection(&self, addr: SocketAddr) -> Result<TcpStream> {
let connect_future = TcpStream::connect(addr);
let stream = timeout(self.config.connect_timeout, connect_future)
.await
.map_err(|_| ProxyError::ConnectionTimeout {
addr: addr.to_string()
})?
.map_err(|e| {
if e.kind() == std::io::ErrorKind::ConnectionRefused {
ProxyError::ConnectionRefused { addr: addr.to_string() }
} else {
ProxyError::Io(e)
}
})?;
// Configure socket
configure_tcp_socket(
&stream,
self.config.keepalive,
self.config.keepalive_interval,
)?;
Ok(stream)
}
/// Return a connection to the pool
pub async fn put(&self, addr: SocketAddr, stream: TcpStream) {
let pool = self.get_or_create_pool(addr);
let mut inner = pool.lock().await;
if inner.connections.len() < self.config.max_connections {
inner.connections.push(PooledConnection {
stream,
created_at: Instant::now(),
});
debug!(addr = %addr, pool_size = inner.connections.len(), "Returned connection to pool");
} else {
debug!(addr = %addr, "Pool full, dropping connection");
}
}
/// Close all pooled connections
pub async fn close_all(&self) {
let pools = self.pools.read();
for (addr, pool) in pools.iter() {
let mut inner = pool.lock().await;
let count = inner.connections.len();
inner.connections.clear();
debug!(addr = %addr, count = count, "Closed pooled connections");
}
}
/// Get pool statistics
pub async fn stats(&self) -> PoolStats {
let pools = self.pools.read();
let mut total_connections = 0;
let mut total_pending = 0;
let mut endpoints = 0;
for pool in pools.values() {
let inner = pool.lock().await;
total_connections += inner.connections.len();
total_pending += inner.pending;
endpoints += 1;
}
PoolStats {
endpoints,
total_connections,
total_pending,
}
}
}
impl Default for ConnectionPool {
fn default() -> Self {
Self::new()
}
}
/// Pool statistics
#[derive(Debug, Clone)]
pub struct PoolStats {
pub endpoints: usize,
pub total_connections: usize,
pub total_pending: usize,
}
/// Check if a TCP connection is still alive (non-blocking)
fn is_connection_alive(stream: &TcpStream) -> bool {
// Try a non-blocking read to check connection state
let mut buf = [0u8; 1];
match stream.try_read(&mut buf) {
Ok(0) => false, // Connection closed
Ok(_) => true, // Data available (shouldn't happen, but connection is alive)
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => true, // No data, but alive
Err(_) => false, // Some error, assume dead
}
}
/// Connection pool with custom initialization
pub struct InitializingPool<F> {
pool: ConnectionPool,
init_fn: F,
}
impl<F, Fut> InitializingPool<F>
where
F: Fn(TcpStream, SocketAddr) -> Fut + Send + Sync,
Fut: std::future::Future<Output = Result<TcpStream>> + Send,
{
/// Create pool with initialization function
pub fn new(config: PoolConfig, init_fn: F) -> Self {
Self {
pool: ConnectionPool::with_config(config),
init_fn,
}
}
/// Get an initialized connection
pub async fn get(&self, addr: SocketAddr) -> Result<TcpStream> {
let stream = self.pool.get(addr).await?;
(self.init_fn)(stream, addr).await
}
/// Return connection to pool
pub async fn put(&self, addr: SocketAddr, stream: TcpStream) {
self.pool.put(addr, stream).await;
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::net::TcpListener;
#[tokio::test]
async fn test_pool_basic() {
// Start a test server
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
// Accept connections in background
tokio::spawn(async move {
loop {
let _ = listener.accept().await;
}
});
let pool = ConnectionPool::new();
// Get a connection
let conn1 = pool.get(addr).await.unwrap();
// Return it to pool
pool.put(addr, conn1).await;
// Get again (should reuse)
let _conn2 = pool.get(addr).await.unwrap();
let stats = pool.stats().await;
assert_eq!(stats.endpoints, 1);
}
#[tokio::test]
async fn test_pool_connection_refused() {
let pool = ConnectionPool::with_config(PoolConfig {
connect_timeout: Duration::from_millis(100),
..Default::default()
});
// Try to connect to a port that's not listening
let result = pool.get("127.0.0.1:1".parse().unwrap()).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_pool_stats() {
let pool = ConnectionPool::new();
let stats = pool.stats().await;
assert_eq!(stats.endpoints, 0);
assert_eq!(stats.total_connections, 0);
}
}

View File

@@ -0,0 +1,381 @@
//! HAProxy PROXY protocol V1/V2
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use tokio::io::{AsyncRead, AsyncReadExt};
use crate::error::{ProxyError, Result};
/// PROXY protocol v1 signature
const PROXY_V1_SIGNATURE: &[u8] = b"PROXY ";
/// PROXY protocol v2 signature
const PROXY_V2_SIGNATURE: &[u8] = &[
0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a,
0x51, 0x55, 0x49, 0x54, 0x0a
];
/// Minimum length for v1 detection
const PROXY_V1_MIN_LEN: usize = 6;
/// Minimum length for v2 header
const PROXY_V2_MIN_LEN: usize = 16;
/// Address families for v2
mod address_family {
pub const UNSPEC: u8 = 0x0;
pub const INET: u8 = 0x1;
pub const INET6: u8 = 0x2;
}
/// Information extracted from PROXY protocol header
#[derive(Debug, Clone)]
pub struct ProxyProtocolInfo {
/// Source (client) address
pub src_addr: SocketAddr,
/// Destination address (optional)
pub dst_addr: Option<SocketAddr>,
/// Protocol version used (1 or 2)
pub version: u8,
}
impl ProxyProtocolInfo {
/// Create info with just source address
pub fn new(src_addr: SocketAddr) -> Self {
Self {
src_addr,
dst_addr: None,
version: 0,
}
}
}
/// Parse PROXY protocol header from a stream
///
/// Returns the parsed info or an error if the header is invalid.
/// The stream position is advanced past the header.
pub async fn parse_proxy_protocol<R: AsyncRead + Unpin>(
reader: &mut R,
default_peer: SocketAddr,
) -> Result<ProxyProtocolInfo> {
// Read enough bytes to detect version
let mut header = [0u8; PROXY_V2_MIN_LEN];
reader.read_exact(&mut header[..PROXY_V1_MIN_LEN]).await
.map_err(|_| ProxyError::InvalidProxyProtocol)?;
// Check for v1
if header[..PROXY_V1_MIN_LEN] == PROXY_V1_SIGNATURE[..] {
return parse_v1(reader, default_peer).await;
}
// Read rest for v2 detection
reader.read_exact(&mut header[PROXY_V1_MIN_LEN..]).await
.map_err(|_| ProxyError::InvalidProxyProtocol)?;
// Check for v2
if header[..12] == PROXY_V2_SIGNATURE[..] {
return parse_v2(reader, &header, default_peer).await;
}
Err(ProxyError::InvalidProxyProtocol)
}
/// Parse PROXY protocol v1
async fn parse_v1<R: AsyncRead + Unpin>(
reader: &mut R,
default_peer: SocketAddr,
) -> Result<ProxyProtocolInfo> {
// Read until CRLF (max 107 bytes total for v1)
let mut line = Vec::with_capacity(128);
line.extend_from_slice(PROXY_V1_SIGNATURE);
loop {
let mut byte = [0u8];
reader.read_exact(&mut byte).await
.map_err(|_| ProxyError::InvalidProxyProtocol)?;
line.push(byte[0]);
if line.ends_with(b"\r\n") {
break;
}
if line.len() > 256 {
return Err(ProxyError::InvalidProxyProtocol);
}
}
// Parse the line: PROXY TCP4/TCP6/UNKNOWN src_ip dst_ip src_port dst_port
let line_str = std::str::from_utf8(&line[PROXY_V1_MIN_LEN..line.len() - 2])
.map_err(|_| ProxyError::InvalidProxyProtocol)?;
let parts: Vec<&str> = line_str.split_whitespace().collect();
if parts.is_empty() {
return Err(ProxyError::InvalidProxyProtocol);
}
match parts[0] {
"TCP4" | "TCP6" if parts.len() >= 5 => {
let src_ip: IpAddr = parts[1].parse()
.map_err(|_| ProxyError::InvalidProxyProtocol)?;
let dst_ip: IpAddr = parts[2].parse()
.map_err(|_| ProxyError::InvalidProxyProtocol)?;
let src_port: u16 = parts[3].parse()
.map_err(|_| ProxyError::InvalidProxyProtocol)?;
let dst_port: u16 = parts[4].parse()
.map_err(|_| ProxyError::InvalidProxyProtocol)?;
Ok(ProxyProtocolInfo {
src_addr: SocketAddr::new(src_ip, src_port),
dst_addr: Some(SocketAddr::new(dst_ip, dst_port)),
version: 1,
})
}
"UNKNOWN" => {
// UNKNOWN means no address info, use default
Ok(ProxyProtocolInfo {
src_addr: default_peer,
dst_addr: None,
version: 1,
})
}
_ => Err(ProxyError::InvalidProxyProtocol),
}
}
/// Parse PROXY protocol v2
async fn parse_v2<R: AsyncRead + Unpin>(
reader: &mut R,
header: &[u8; PROXY_V2_MIN_LEN],
default_peer: SocketAddr,
) -> Result<ProxyProtocolInfo> {
let version_command = header[12];
let version = version_command >> 4;
let command = version_command & 0x0f;
// Must be version 2
if version != 2 {
return Err(ProxyError::InvalidProxyProtocol);
}
let family_protocol = header[13];
let addr_len = u16::from_be_bytes([header[14], header[15]]) as usize;
// Read address data
let mut addr_data = vec![0u8; addr_len];
if addr_len > 0 {
reader.read_exact(&mut addr_data).await
.map_err(|_| ProxyError::InvalidProxyProtocol)?;
}
// LOCAL command (0x0) - use default peer
if command == 0 {
return Ok(ProxyProtocolInfo {
src_addr: default_peer,
dst_addr: None,
version: 2,
});
}
// PROXY command (0x1) - parse addresses
if command != 1 {
return Err(ProxyError::InvalidProxyProtocol);
}
let family = family_protocol >> 4;
match family {
address_family::INET if addr_len >= 12 => {
// IPv4: 4 + 4 + 2 + 2 = 12 bytes
let src_ip = Ipv4Addr::new(
addr_data[0], addr_data[1],
addr_data[2], addr_data[3]
);
let dst_ip = Ipv4Addr::new(
addr_data[4], addr_data[5],
addr_data[6], addr_data[7]
);
let src_port = u16::from_be_bytes([addr_data[8], addr_data[9]]);
let dst_port = u16::from_be_bytes([addr_data[10], addr_data[11]]);
Ok(ProxyProtocolInfo {
src_addr: SocketAddr::new(IpAddr::V4(src_ip), src_port),
dst_addr: Some(SocketAddr::new(IpAddr::V4(dst_ip), dst_port)),
version: 2,
})
}
address_family::INET6 if addr_len >= 36 => {
// IPv6: 16 + 16 + 2 + 2 = 36 bytes
let src_ip = Ipv6Addr::from(
<[u8; 16]>::try_from(&addr_data[0..16]).unwrap()
);
let dst_ip = Ipv6Addr::from(
<[u8; 16]>::try_from(&addr_data[16..32]).unwrap()
);
let src_port = u16::from_be_bytes([addr_data[32], addr_data[33]]);
let dst_port = u16::from_be_bytes([addr_data[34], addr_data[35]]);
Ok(ProxyProtocolInfo {
src_addr: SocketAddr::new(IpAddr::V6(src_ip), src_port),
dst_addr: Some(SocketAddr::new(IpAddr::V6(dst_ip), dst_port)),
version: 2,
})
}
address_family::UNSPEC => {
Ok(ProxyProtocolInfo {
src_addr: default_peer,
dst_addr: None,
version: 2,
})
}
_ => Err(ProxyError::InvalidProxyProtocol),
}
}
/// Builder for PROXY protocol v1 header
pub struct ProxyProtocolV1Builder {
family: &'static str,
src_addr: Option<SocketAddr>,
dst_addr: Option<SocketAddr>,
}
impl ProxyProtocolV1Builder {
pub fn new() -> Self {
Self {
family: "UNKNOWN",
src_addr: None,
dst_addr: None,
}
}
pub fn tcp4(mut self, src: SocketAddr, dst: SocketAddr) -> Self {
self.family = "TCP4";
self.src_addr = Some(src);
self.dst_addr = Some(dst);
self
}
pub fn tcp6(mut self, src: SocketAddr, dst: SocketAddr) -> Self {
self.family = "TCP6";
self.src_addr = Some(src);
self.dst_addr = Some(dst);
self
}
pub fn build(&self) -> Vec<u8> {
match (self.src_addr, self.dst_addr) {
(Some(src), Some(dst)) => {
format!(
"PROXY {} {} {} {} {}\r\n",
self.family,
src.ip(),
dst.ip(),
src.port(),
dst.port()
).into_bytes()
}
_ => b"PROXY UNKNOWN\r\n".to_vec(),
}
}
}
impl Default for ProxyProtocolV1Builder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Cursor;
#[tokio::test]
async fn test_parse_v1_tcp4() {
let header = b"PROXY TCP4 192.168.1.1 10.0.0.1 12345 443\r\n";
let mut cursor = Cursor::new(&header[PROXY_V1_MIN_LEN..]);
let default = "0.0.0.0:0".parse().unwrap();
// Simulate that we've already read the signature
let info = parse_v1(&mut cursor, default).await.unwrap();
assert_eq!(info.version, 1);
assert_eq!(info.src_addr.ip().to_string(), "192.168.1.1");
assert_eq!(info.src_addr.port(), 12345);
assert!(info.dst_addr.is_some());
}
#[tokio::test]
async fn test_parse_v1_unknown() {
let header = b"PROXY UNKNOWN\r\n";
let mut cursor = Cursor::new(&header[PROXY_V1_MIN_LEN..]);
let default: SocketAddr = "1.2.3.4:5678".parse().unwrap();
let info = parse_v1(&mut cursor, default).await.unwrap();
assert_eq!(info.version, 1);
assert_eq!(info.src_addr, default);
}
#[tokio::test]
async fn test_parse_v2_tcp4() {
// v2 header for TCP4
let mut header = [0u8; 16];
header[..12].copy_from_slice(PROXY_V2_SIGNATURE);
header[12] = 0x21; // v2, PROXY command
header[13] = 0x11; // AF_INET, STREAM
header[14] = 0x00;
header[15] = 0x0c; // 12 bytes of address data
let addr_data = [
192, 168, 1, 1, // src IP
10, 0, 0, 1, // dst IP
0x30, 0x39, // src port (12345)
0x01, 0xbb, // dst port (443)
];
let mut cursor = Cursor::new(addr_data.to_vec());
let default = "0.0.0.0:0".parse().unwrap();
let info = parse_v2(&mut cursor, &header, default).await.unwrap();
assert_eq!(info.version, 2);
assert_eq!(info.src_addr.ip().to_string(), "192.168.1.1");
assert_eq!(info.src_addr.port(), 12345);
}
#[tokio::test]
async fn test_parse_v2_local() {
let mut header = [0u8; 16];
header[..12].copy_from_slice(PROXY_V2_SIGNATURE);
header[12] = 0x20; // v2, LOCAL command
header[13] = 0x00;
header[14] = 0x00;
header[15] = 0x00; // 0 bytes of address data
let mut cursor = Cursor::new(Vec::new());
let default: SocketAddr = "1.2.3.4:5678".parse().unwrap();
let info = parse_v2(&mut cursor, &header, default).await.unwrap();
assert_eq!(info.version, 2);
assert_eq!(info.src_addr, default);
}
#[test]
fn test_v1_builder() {
let src: SocketAddr = "192.168.1.1:12345".parse().unwrap();
let dst: SocketAddr = "10.0.0.1:443".parse().unwrap();
let header = ProxyProtocolV1Builder::new()
.tcp4(src, dst)
.build();
let expected = b"PROXY TCP4 192.168.1.1 10.0.0.1 12345 443\r\n";
assert_eq!(header, expected);
}
#[test]
fn test_v1_builder_unknown() {
let header = ProxyProtocolV1Builder::new().build();
assert_eq!(header, b"PROXY UNKNOWN\r\n");
}
}

230
src/transport/socket.rs Normal file
View File

@@ -0,0 +1,230 @@
//! TCP Socket Configuration
use std::io::Result;
use std::net::SocketAddr;
use std::time::Duration;
use tokio::net::TcpStream;
use socket2::{Socket, TcpKeepalive, Domain, Type, Protocol};
use tracing::debug;
/// Configure TCP socket with recommended settings for proxy use
pub fn configure_tcp_socket(
stream: &TcpStream,
keepalive: bool,
keepalive_interval: Duration,
) -> Result<()> {
let socket = socket2::SockRef::from(stream);
// Disable Nagle's algorithm for lower latency
socket.set_nodelay(true)?;
// Set keepalive if enabled
if keepalive {
let keepalive = TcpKeepalive::new()
.with_time(keepalive_interval);
// Platform-specific keepalive settings
#[cfg(any(target_os = "linux", target_os = "macos", target_os = "ios"))]
let keepalive = keepalive.with_interval(keepalive_interval);
socket.set_tcp_keepalive(&keepalive)?;
}
// Set buffer sizes
set_buffer_sizes(&socket, 65536, 65536)?;
Ok(())
}
/// Set socket buffer sizes
fn set_buffer_sizes(socket: &socket2::SockRef, recv: usize, send: usize) -> Result<()> {
// These may fail on some systems, so we ignore errors
let _ = socket.set_recv_buffer_size(recv);
let _ = socket.set_send_buffer_size(send);
Ok(())
}
/// Configure socket for accepting client connections
pub fn configure_client_socket(
stream: &TcpStream,
keepalive_secs: u64,
ack_timeout_secs: u64,
) -> Result<()> {
let socket = socket2::SockRef::from(stream);
// Disable Nagle's algorithm
socket.set_nodelay(true)?;
// Set keepalive
let keepalive = TcpKeepalive::new()
.with_time(Duration::from_secs(keepalive_secs));
#[cfg(any(target_os = "linux", target_os = "macos", target_os = "ios"))]
let keepalive = keepalive.with_interval(Duration::from_secs(keepalive_secs));
socket.set_tcp_keepalive(&keepalive)?;
// Set TCP user timeout (Linux only)
#[cfg(target_os = "linux")]
{
use std::os::unix::io::AsRawFd;
let fd = stream.as_raw_fd();
let timeout_ms = (ack_timeout_secs * 1000) as libc::c_int;
unsafe {
libc::setsockopt(
fd,
libc::IPPROTO_TCP,
libc::TCP_USER_TIMEOUT,
&timeout_ms as *const _ as *const libc::c_void,
std::mem::size_of::<libc::c_int>() as libc::socklen_t,
);
}
}
Ok(())
}
/// Set socket to send RST on close (for masking)
pub fn set_linger_zero(stream: &TcpStream) -> Result<()> {
let socket = socket2::SockRef::from(stream);
socket.set_linger(Some(Duration::ZERO))?;
Ok(())
}
/// Create a new TCP socket for outgoing connections
pub fn create_outgoing_socket(addr: SocketAddr) -> Result<Socket> {
let domain = if addr.is_ipv4() {
Domain::IPV4
} else {
Domain::IPV6
};
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
// Set non-blocking
socket.set_nonblocking(true)?;
// Disable Nagle
socket.set_nodelay(true)?;
Ok(socket)
}
/// Get local address of a socket
pub fn get_local_addr(stream: &TcpStream) -> Option<SocketAddr> {
stream.local_addr().ok()
}
/// Get peer address of a socket
pub fn get_peer_addr(stream: &TcpStream) -> Option<SocketAddr> {
stream.peer_addr().ok()
}
/// Check if address is IPv6
pub fn is_ipv6(addr: &SocketAddr) -> bool {
addr.is_ipv6()
}
/// Parse IPv4-mapped IPv6 address to IPv4
pub fn normalize_ip(addr: SocketAddr) -> SocketAddr {
match addr {
SocketAddr::V6(v6) => {
if let Some(v4) = v6.ip().to_ipv4_mapped() {
SocketAddr::new(std::net::IpAddr::V4(v4), v6.port())
} else {
addr
}
}
_ => addr,
}
}
/// Socket options for server listening
#[derive(Debug, Clone)]
pub struct ListenOptions {
/// Enable SO_REUSEADDR
pub reuse_addr: bool,
/// Enable SO_REUSEPORT (Linux/BSD)
pub reuse_port: bool,
/// Backlog size
pub backlog: u32,
/// IPv6 only (disable dual-stack)
pub ipv6_only: bool,
}
impl Default for ListenOptions {
fn default() -> Self {
Self {
reuse_addr: true,
reuse_port: true,
backlog: 1024,
ipv6_only: false,
}
}
}
/// Create a listening socket with the specified options
pub fn create_listener(addr: SocketAddr, options: &ListenOptions) -> Result<Socket> {
let domain = if addr.is_ipv4() {
Domain::IPV4
} else {
Domain::IPV6
};
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
if options.reuse_addr {
socket.set_reuse_address(true)?;
}
#[cfg(unix)]
if options.reuse_port {
socket.set_reuse_port(true)?;
}
if addr.is_ipv6() && options.ipv6_only {
socket.set_only_v6(true)?;
}
socket.set_nonblocking(true)?;
socket.bind(&addr.into())?;
socket.listen(options.backlog as i32)?;
debug!(addr = %addr, "Created listening socket");
Ok(socket)
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::net::TcpListener;
#[tokio::test]
async fn test_configure_socket() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let stream = TcpStream::connect(addr).await.unwrap();
configure_tcp_socket(&stream, true, Duration::from_secs(30)).unwrap();
}
#[test]
fn test_normalize_ip() {
// IPv4 stays IPv4
let v4: SocketAddr = "192.168.1.1:8080".parse().unwrap();
assert_eq!(normalize_ip(v4), v4);
// Pure IPv6 stays IPv6
let v6: SocketAddr = "[::1]:8080".parse().unwrap();
assert_eq!(normalize_ip(v6), v6);
}
#[test]
fn test_listen_options_default() {
let opts = ListenOptions::default();
assert!(opts.reuse_addr);
assert!(opts.reuse_port);
assert_eq!(opts.backlog, 1024);
}
}