- Fixed tests that failed to compile due to mismatched generic parameters of HandshakeResult:
- Changed `HandshakeResult<i32>` to `HandshakeResult<i32, (), ()>`
- Changed `HandshakeResult::BadClient` to `HandshakeResult::BadClient { reader: (), writer: () }`
- Added Zeroize for all structures holding key material:
- AesCbc – key and IV are zeroized on drop
- SecureRandomInner – PRNG output buffer is zeroized on drop; local key copy in constructor is zeroized immediately after being passed to the cipher
- ObfuscationParams – all four key‑material fields are zeroized on drop
- HandshakeSuccess – all four key‑material fields are zeroized on drop
- Added protocol‑requirement documentation for legacy hashes (CodeQL suppression) in hash.rs (MD5/SHA‑1)
- Added documentation for zeroize limitations of AesCtr (opaque cipher state) in aes.rs
- Implemented silent‑mode logging and refactored initialization:
- Added LogLevel enum to config and CLI flags --silent / --log-level
- Added parse_cli() to handle --silent, --log-level, --help
- Restructured main.rs initialization order: CLI → config load → determine log level → init tracing
- Errors before tracing initialization are printed via eprintln!
- Proxy links (tg://) are printed via println! – always visible regardless of log level
- Configuration summary and operational messages are logged via info! (suppressed in silent mode)
- Connection processing errors are lowered to debug! (hidden in silent mode)
- Warning about default tls_domain moved to main (after tracing init)
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
416 lines
12 KiB
Rust
416 lines
12 KiB
Rust
//! Error Types
|
|
|
|
use std::fmt;
|
|
use std::net::SocketAddr;
|
|
use thiserror::Error;
|
|
|
|
// ============= Stream Errors =============
|
|
|
|
/// Errors specific to stream I/O operations
|
|
#[derive(Debug)]
|
|
pub enum StreamError {
|
|
/// Partial read: got fewer bytes than expected
|
|
PartialRead {
|
|
expected: usize,
|
|
got: usize,
|
|
},
|
|
/// Partial write: wrote fewer bytes than expected
|
|
PartialWrite {
|
|
expected: usize,
|
|
written: usize,
|
|
},
|
|
/// Stream is in poisoned state and cannot be used
|
|
Poisoned {
|
|
reason: String,
|
|
},
|
|
/// Buffer overflow: attempted to buffer more than allowed
|
|
BufferOverflow {
|
|
limit: usize,
|
|
attempted: usize,
|
|
},
|
|
/// Invalid frame format
|
|
InvalidFrame {
|
|
details: String,
|
|
},
|
|
/// Unexpected end of stream
|
|
UnexpectedEof,
|
|
/// Underlying I/O error
|
|
Io(std::io::Error),
|
|
}
|
|
|
|
impl fmt::Display for StreamError {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
match self {
|
|
Self::PartialRead { expected, got } => {
|
|
write!(f, "partial read: expected {} bytes, got {}", expected, got)
|
|
}
|
|
Self::PartialWrite { expected, written } => {
|
|
write!(f, "partial write: expected {} bytes, wrote {}", expected, written)
|
|
}
|
|
Self::Poisoned { reason } => {
|
|
write!(f, "stream poisoned: {}", reason)
|
|
}
|
|
Self::BufferOverflow { limit, attempted } => {
|
|
write!(f, "buffer overflow: limit {}, attempted {}", limit, attempted)
|
|
}
|
|
Self::InvalidFrame { details } => {
|
|
write!(f, "invalid frame: {}", details)
|
|
}
|
|
Self::UnexpectedEof => {
|
|
write!(f, "unexpected end of stream")
|
|
}
|
|
Self::Io(e) => {
|
|
write!(f, "I/O error: {}", e)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
impl std::error::Error for StreamError {
|
|
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
|
match self {
|
|
Self::Io(e) => Some(e),
|
|
_ => None,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl From<std::io::Error> for StreamError {
|
|
fn from(err: std::io::Error) -> Self {
|
|
Self::Io(err)
|
|
}
|
|
}
|
|
|
|
impl From<StreamError> for std::io::Error {
|
|
fn from(err: StreamError) -> Self {
|
|
match err {
|
|
StreamError::Io(e) => e,
|
|
StreamError::UnexpectedEof => {
|
|
std::io::Error::new(std::io::ErrorKind::UnexpectedEof, err)
|
|
}
|
|
StreamError::Poisoned { .. } => {
|
|
std::io::Error::new(std::io::ErrorKind::Other, err)
|
|
}
|
|
StreamError::BufferOverflow { .. } => {
|
|
std::io::Error::new(std::io::ErrorKind::OutOfMemory, err)
|
|
}
|
|
StreamError::InvalidFrame { .. } => {
|
|
std::io::Error::new(std::io::ErrorKind::InvalidData, err)
|
|
}
|
|
StreamError::PartialRead { .. } | StreamError::PartialWrite { .. } => {
|
|
std::io::Error::new(std::io::ErrorKind::Other, err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// ============= Recoverable Trait =============
|
|
|
|
/// Trait for errors that may be recoverable
|
|
pub trait Recoverable {
|
|
/// Check if error is recoverable (can retry operation)
|
|
fn is_recoverable(&self) -> bool;
|
|
|
|
/// Check if connection can continue after this error
|
|
fn can_continue(&self) -> bool;
|
|
}
|
|
|
|
impl Recoverable for StreamError {
|
|
fn is_recoverable(&self) -> bool {
|
|
match self {
|
|
Self::PartialRead { .. } | Self::PartialWrite { .. } => true,
|
|
Self::Io(e) => matches!(
|
|
e.kind(),
|
|
std::io::ErrorKind::WouldBlock
|
|
| std::io::ErrorKind::Interrupted
|
|
| std::io::ErrorKind::TimedOut
|
|
),
|
|
Self::Poisoned { .. }
|
|
| Self::BufferOverflow { .. }
|
|
| Self::InvalidFrame { .. }
|
|
| Self::UnexpectedEof => false,
|
|
}
|
|
}
|
|
|
|
fn can_continue(&self) -> bool {
|
|
match self {
|
|
Self::Poisoned { .. } => false,
|
|
Self::UnexpectedEof => false,
|
|
Self::BufferOverflow { .. } => false,
|
|
_ => true,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Recoverable for std::io::Error {
|
|
fn is_recoverable(&self) -> bool {
|
|
matches!(
|
|
self.kind(),
|
|
std::io::ErrorKind::WouldBlock
|
|
| std::io::ErrorKind::Interrupted
|
|
| std::io::ErrorKind::TimedOut
|
|
)
|
|
}
|
|
|
|
fn can_continue(&self) -> bool {
|
|
!matches!(
|
|
self.kind(),
|
|
std::io::ErrorKind::BrokenPipe
|
|
| std::io::ErrorKind::ConnectionReset
|
|
| std::io::ErrorKind::ConnectionAborted
|
|
| std::io::ErrorKind::NotConnected
|
|
)
|
|
}
|
|
}
|
|
|
|
// ============= Main Proxy Errors =============
|
|
|
|
#[derive(Error, Debug)]
|
|
pub enum ProxyError {
|
|
// ============= Crypto Errors =============
|
|
|
|
#[error("Crypto error: {0}")]
|
|
Crypto(String),
|
|
|
|
#[error("Invalid key length: expected {expected}, got {got}")]
|
|
InvalidKeyLength { expected: usize, got: usize },
|
|
|
|
// ============= Stream Errors =============
|
|
|
|
#[error("Stream error: {0}")]
|
|
Stream(#[from] StreamError),
|
|
|
|
// ============= Protocol Errors =============
|
|
|
|
#[error("Invalid handshake: {0}")]
|
|
InvalidHandshake(String),
|
|
|
|
#[error("Invalid protocol tag: {0:02x?}")]
|
|
InvalidProtoTag([u8; 4]),
|
|
|
|
#[error("Invalid TLS record: type={record_type}, version={version:02x?}")]
|
|
InvalidTlsRecord { record_type: u8, version: [u8; 2] },
|
|
|
|
#[error("Replay attack detected from {addr}")]
|
|
ReplayAttack { addr: SocketAddr },
|
|
|
|
#[error("Time skew detected: client={client_time}, server={server_time}")]
|
|
TimeSkew { client_time: u32, server_time: u32 },
|
|
|
|
#[error("Invalid message length: {len} (min={min}, max={max})")]
|
|
InvalidMessageLength { len: usize, min: usize, max: usize },
|
|
|
|
#[error("Checksum mismatch: expected={expected:08x}, got={got:08x}")]
|
|
ChecksumMismatch { expected: u32, got: u32 },
|
|
|
|
#[error("Sequence number mismatch: expected={expected}, got={got}")]
|
|
SeqNoMismatch { expected: i32, got: i32 },
|
|
|
|
#[error("TLS handshake failed: {reason}")]
|
|
TlsHandshakeFailed { reason: String },
|
|
|
|
#[error("Telegram handshake timeout")]
|
|
TgHandshakeTimeout,
|
|
|
|
// ============= Network Errors =============
|
|
|
|
#[error("Connection timeout to {addr}")]
|
|
ConnectionTimeout { addr: String },
|
|
|
|
#[error("Connection refused by {addr}")]
|
|
ConnectionRefused { addr: String },
|
|
|
|
#[error("IO error: {0}")]
|
|
Io(#[from] std::io::Error),
|
|
|
|
// ============= Proxy Protocol Errors =============
|
|
|
|
#[error("Invalid proxy protocol header")]
|
|
InvalidProxyProtocol,
|
|
|
|
#[error("Proxy error: {0}")]
|
|
Proxy(String),
|
|
|
|
// ============= Config Errors =============
|
|
|
|
#[error("Config error: {0}")]
|
|
Config(String),
|
|
|
|
#[error("Invalid secret for user {user}: {reason}")]
|
|
InvalidSecret { user: String, reason: String },
|
|
|
|
// ============= User Errors =============
|
|
|
|
#[error("User {user} expired")]
|
|
UserExpired { user: String },
|
|
|
|
#[error("User {user} exceeded connection limit")]
|
|
ConnectionLimitExceeded { user: String },
|
|
|
|
#[error("User {user} exceeded data quota")]
|
|
DataQuotaExceeded { user: String },
|
|
|
|
#[error("Unknown user")]
|
|
UnknownUser,
|
|
|
|
#[error("Rate limited")]
|
|
RateLimited,
|
|
|
|
// ============= General Errors =============
|
|
|
|
#[error("Internal error: {0}")]
|
|
Internal(String),
|
|
}
|
|
|
|
impl Recoverable for ProxyError {
|
|
fn is_recoverable(&self) -> bool {
|
|
match self {
|
|
Self::Stream(e) => e.is_recoverable(),
|
|
Self::Io(e) => e.is_recoverable(),
|
|
Self::ConnectionTimeout { .. } => true,
|
|
Self::RateLimited => true,
|
|
_ => false,
|
|
}
|
|
}
|
|
|
|
fn can_continue(&self) -> bool {
|
|
match self {
|
|
Self::Stream(e) => e.can_continue(),
|
|
Self::Io(e) => e.can_continue(),
|
|
_ => false,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Convenient Result type alias
|
|
pub type Result<T> = std::result::Result<T, ProxyError>;
|
|
|
|
/// Result type for stream operations
|
|
pub type StreamResult<T> = std::result::Result<T, StreamError>;
|
|
|
|
/// Result with optional bad client handling
|
|
#[derive(Debug)]
|
|
pub enum HandshakeResult<T, R, W> {
|
|
/// Handshake succeeded
|
|
Success(T),
|
|
/// Client failed validation, needs masking. Returns ownership of streams.
|
|
BadClient { reader: R, writer: W },
|
|
/// Error occurred
|
|
Error(ProxyError),
|
|
}
|
|
|
|
impl<T, R, W> HandshakeResult<T, R, W> {
|
|
/// Check if successful
|
|
pub fn is_success(&self) -> bool {
|
|
matches!(self, HandshakeResult::Success(_))
|
|
}
|
|
|
|
/// Check if bad client
|
|
pub fn is_bad_client(&self) -> bool {
|
|
matches!(self, HandshakeResult::BadClient { .. })
|
|
}
|
|
|
|
/// Map the success value
|
|
pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> HandshakeResult<U, R, W> {
|
|
match self {
|
|
HandshakeResult::Success(v) => HandshakeResult::Success(f(v)),
|
|
HandshakeResult::BadClient { reader, writer } => HandshakeResult::BadClient { reader, writer },
|
|
HandshakeResult::Error(e) => HandshakeResult::Error(e),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<T, R, W> From<ProxyError> for HandshakeResult<T, R, W> {
|
|
fn from(err: ProxyError) -> Self {
|
|
HandshakeResult::Error(err)
|
|
}
|
|
}
|
|
|
|
impl<T, R, W> From<std::io::Error> for HandshakeResult<T, R, W> {
|
|
fn from(err: std::io::Error) -> Self {
|
|
HandshakeResult::Error(ProxyError::Io(err))
|
|
}
|
|
}
|
|
|
|
impl<T, R, W> From<StreamError> for HandshakeResult<T, R, W> {
|
|
fn from(err: StreamError) -> Self {
|
|
HandshakeResult::Error(ProxyError::Stream(err))
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
#[test]
|
|
fn test_stream_error_display() {
|
|
let err = StreamError::PartialRead { expected: 100, got: 50 };
|
|
assert!(err.to_string().contains("100"));
|
|
assert!(err.to_string().contains("50"));
|
|
|
|
let err = StreamError::Poisoned { reason: "test".into() };
|
|
assert!(err.to_string().contains("test"));
|
|
}
|
|
|
|
#[test]
|
|
fn test_stream_error_recoverable() {
|
|
assert!(StreamError::PartialRead { expected: 10, got: 5 }.is_recoverable());
|
|
assert!(StreamError::PartialWrite { expected: 10, written: 5 }.is_recoverable());
|
|
assert!(!StreamError::Poisoned { reason: "x".into() }.is_recoverable());
|
|
assert!(!StreamError::UnexpectedEof.is_recoverable());
|
|
}
|
|
|
|
#[test]
|
|
fn test_stream_error_can_continue() {
|
|
assert!(!StreamError::Poisoned { reason: "x".into() }.can_continue());
|
|
assert!(!StreamError::UnexpectedEof.can_continue());
|
|
assert!(StreamError::PartialRead { expected: 10, got: 5 }.can_continue());
|
|
}
|
|
|
|
#[test]
|
|
fn test_stream_error_to_io_error() {
|
|
let stream_err = StreamError::UnexpectedEof;
|
|
let io_err: std::io::Error = stream_err.into();
|
|
assert_eq!(io_err.kind(), std::io::ErrorKind::UnexpectedEof);
|
|
}
|
|
|
|
#[test]
|
|
fn test_handshake_result() {
|
|
let success: HandshakeResult<i32, (), ()> = HandshakeResult::Success(42);
|
|
assert!(success.is_success());
|
|
assert!(!success.is_bad_client());
|
|
|
|
let bad: HandshakeResult<i32, (), ()> = HandshakeResult::BadClient { reader: (), writer: () };
|
|
assert!(!bad.is_success());
|
|
assert!(bad.is_bad_client());
|
|
}
|
|
|
|
#[test]
|
|
fn test_handshake_result_map() {
|
|
let success: HandshakeResult<i32, (), ()> = HandshakeResult::Success(42);
|
|
let mapped = success.map(|x| x * 2);
|
|
|
|
match mapped {
|
|
HandshakeResult::Success(v) => assert_eq!(v, 84),
|
|
_ => panic!("Expected success"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_proxy_error_recoverable() {
|
|
let err = ProxyError::RateLimited;
|
|
assert!(err.is_recoverable());
|
|
|
|
let err = ProxyError::InvalidHandshake("bad".into());
|
|
assert!(!err.is_recoverable());
|
|
}
|
|
|
|
#[test]
|
|
fn test_error_display() {
|
|
let err = ProxyError::ConnectionTimeout { addr: "1.2.3.4:443".into() };
|
|
assert!(err.to_string().contains("1.2.3.4:443"));
|
|
|
|
let err = ProxyError::InvalidProxyProtocol;
|
|
assert!(err.to_string().contains("proxy protocol"));
|
|
}
|
|
} |