Middle Proxy Fixes

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
This commit is contained in:
Alexey
2026-02-13 16:09:33 +03:00
parent e62b41ae64
commit de28655dd2
3 changed files with 301 additions and 275 deletions

View File

@@ -338,176 +338,221 @@
/// - CDN DCs (203+) work because ME knows their internal addresses
/// - We pass raw client MTProto bytes in RPC_PROXY_REQ envelope
/// - ME returns responses in RPC_PROXY_ANS envelope
async fn handle_via_middle_proxy<R, W>(
mut client_reader: CryptoReader<R>,
mut client_writer: CryptoWriter<W>,
success: HandshakeSuccess,
me_pool: Arc<MePool>,
stats: Arc<Stats>,
config: Arc<ProxyConfig>,
_buffer_pool: Arc<BufferPool>,
) -> Result<()>
where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{
let user = success.user.clone();
let peer = success.peer;
info!(
user = %user,
peer = %peer,
dc = success.dc_idx,
proto = ?success.proto_tag,
mode = "middle_proxy",
"Routing via Middle-End"
);
// Register this client connection in ME demux registry
let (conn_id, mut me_rx) = me_pool.registry().register().await;
// Our listening address for RPC_PROXY_REQ metadata
let our_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port)
.parse().unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap());
stats.increment_user_connects(&user);
stats.increment_user_curr_connects(&user);
let proto_flags = proto_flags_for_tag(Some(success.proto_tag as u32));
let is_secure = matches!(success.proto_tag, ProtoTag::Secure);
debug!(user = %user, conn_id, proto_flags = format_args!("0x{:08x}", proto_flags), "ME relay started");
// Bidirectional relay loop: client ↔ ME pool
//
// C→S direction: read intermediate frame from client, send payload via RPC_PROXY_REQ
// S→C direction: receive raw payload from ME, add client intermediate framing
//
// We use tokio::select! to handle both directions concurrently.
// Unlike direct mode (copy_bidirectional on two TCP streams),
// here one side is a channel (mpsc::Receiver), not a stream.
let mut client_closed = false;
let mut server_closed = false;
let result: Result<()> = loop {
tokio::select! {
// C→S: client sends one intermediate frame, we forward payload to ME
read_result = async {
let mut len_buf = [0u8; 4];
client_reader.read_exact(&mut len_buf).await?;
let raw_len = u32::from_le_bytes(len_buf);
let payload_len = (raw_len & 0x7fff_ffff) as usize;
if payload_len > MAX_MSG_LEN {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("client frame too large: {}", payload_len),
));
}
async fn handle_via_middle_proxy<R, W>(
mut client_reader: CryptoReader<R>,
mut client_writer: CryptoWriter<W>,
success: HandshakeSuccess,
me_pool: Arc<MePool>,
stats: Arc<Stats>,
config: Arc<ProxyConfig>,
_buffer_pool: Arc<BufferPool>,
) -> Result<()>
where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{
let user = success.user.clone();
let peer = success.peer;
let mut payload = vec![0u8; payload_len];
client_reader.read_exact(&mut payload).await?;
Ok::<Vec<u8>, std::io::Error>(payload)
}, if !client_closed => {
match read_result {
Ok(payload) => {
trace!(conn_id, bytes = payload.len(), "C frame -> ME payload");
stats.add_user_octets_from(&user, payload.len() as u64);
if let Err(e) = me_pool.send_proxy_req(
conn_id, peer, our_addr, &payload, proto_flags
).await {
break Err(e);
info!(
user = %user,
peer = %peer,
dc = success.dc_idx,
proto = ?success.proto_tag,
mode = "middle_proxy",
"Routing via Middle-End"
);
let (conn_id, mut me_rx) = me_pool.registry().register().await;
let our_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port)
.parse().unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap());
stats.increment_user_connects(&user);
stats.increment_user_curr_connects(&user);
let proto_flags = proto_flags_for_tag(success.proto_tag);
debug!(user = %user, conn_id, proto_flags = format_args!("0x{:08x}", proto_flags), "ME relay started");
// We need to handle framing here.
// Client sends: [Len:4][Payload...] (Intermediate/Secure)
// We must strip Len and send Payload to ME.
// ME sends: [Payload...]
// We must add [Len:4] and send to Client.
// For Secure mode, Len has padding bit (MSB).
let is_secure = success.proto_tag == crate::protocol::constants::ProtoTag::Secure;
let mut client_closed = false;
let mut server_closed = false;
// Split client_reader/writer to use in select!
// CryptoReader/Writer don't support splitting easily without Arc/Mutex or unsafe,
// but here we are in a loop.
// We can't easily split them because they wrap the underlying stream.
// However, we can use a loop with select! on read and rx.
let mut len_buf = [0u8; 4];
let mut reading_len = true;
let mut current_payload_len = 0;
let mut payload_buf = Vec::new();
let result: Result<()> = loop {
tokio::select! {
// C->S: Read length, then payload
res = async {
if reading_len {
client_reader.read_exact(&mut len_buf).await.map(|_| true)
} else {
// Read payload
// We need to read exactly current_payload_len
if payload_buf.len() < current_payload_len {
let needed = current_payload_len - payload_buf.len();
let mut chunk = vec![0u8; needed];
let n = client_reader.read(&mut chunk).await?;
if n == 0 { return Ok(false); } // EOF
payload_buf.extend_from_slice(&chunk[..n]);
Ok(true)
} else {
Ok(true) // Should not happen
}
}
}, if !client_closed => {
match res {
Ok(true) => {
if reading_len {
// Got length
let raw_len = u32::from_le_bytes(len_buf);
// In secure mode, MSB is padding flag. In intermediate, it's just len.
// But wait, standard intermediate doesn't use MSB for padding.
// Secure mode DOES.
// Let's trust the protocol tag.
let len = if is_secure {
raw_len & 0x7FFFFFFF
} else {
raw_len
};
current_payload_len = len as usize;
// Sanity check
if current_payload_len > 16 * 1024 * 1024 {
debug!(conn_id, len=current_payload_len, "Client sent huge frame");
break Err(ProxyError::Proxy("Frame too large".into()));
}
payload_buf.clear();
payload_buf.reserve(current_payload_len);
reading_len = false;
} else {
// Got some payload data
if payload_buf.len() == current_payload_len {
// Full frame received
trace!(conn_id, bytes = current_payload_len, "C->ME (Frame complete)");
stats.add_user_octets_from(&user, current_payload_len as u64);
// Send to ME
// Note: In secure mode, we send the PADDING bytes too?
// Erlang mtp_intermediate: strips 4 bytes len.
// Erlang mtp_secure: strips 4 bytes len.
// The payload includes the padding if it was added?
// Actually, secure layer (mtp_secure.erl) handles padding removal?
// No, mtp_secure just sets padding=>true for intermediate codec.
// The intermediate codec (mtp_intermediate.erl) just extracts the packet.
// The packet passed to RPC is the payload.
// If secure mode adds random padding at the end, it is part of the payload
// that ME receives?
// Let's look at C code.
// ext-server.c: reads packet_len.
// if (packet_len & 0x80000000) -> has padding.
// It reads the full packet.
// Then it passes it to forward_tcp_query.
// So YES, we send the full payload including padding to ME.
if let Err(e) = me_pool.send_proxy_req(
conn_id, peer, our_addr, &payload_buf, proto_flags
).await {
break Err(e);
}
// Reset for next frame
reading_len = true;
}
}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
debug!(conn_id, "Client EOF");
client_closed = true;
if server_closed { break Ok(()); }
// Signal ME to close this connection
let _ = me_pool.send_close(conn_id).await;
}
Ok(false) => {
// EOF
debug!(conn_id, "Client EOF");
client_closed = true;
let _ = me_pool.send_close(conn_id).await;
if server_closed { break Ok(()); }
}
Err(e) => {
debug!(conn_id, error = %e, "Client read error");
break Err(ProxyError::Io(e));
}
}
}
// S->C: ME sends data, we wrap and send to client
me_msg = me_rx.recv(), if !server_closed => {
match me_msg {
Some(MeResponse::Data(data)) => {
trace!(conn_id, bytes = data.len(), "ME->C");
stats.add_user_octets_to(&user, data.len() as u64);
// Wrap in intermediate frame
let len = data.len() as u32;
// For secure mode, we might need to add padding?
// C code: forward_mtproto_packet -> just sends data.
// But wait, C code adds framing in net-tcp-rpc-ext-server.c?
// No, forward_tcp_query sends RPC_PROXY_REQ.
// ME sends RPC_PROXY_ANS.
// The data in ANS is the MTProto packet.
// We need to send it to client.
// If client is Intermediate/Secure, we MUST add the 4-byte length prefix.
// Secure mode: usually we don't ADD padding on response, we just send valid packets.
// But we MUST send the length.
if let Err(e) = client_writer.write_all(&len.to_le_bytes()).await {
break Err(ProxyError::Io(e));
}
Err(e) => {
debug!(conn_id, error = %e, "Client read error");
if let Err(e) = client_writer.write_all(&data).await {
break Err(ProxyError::Io(e));
}
if let Err(e) = client_writer.flush().await {
break Err(ProxyError::Io(e));
}
}
}
// S→C: ME sends response, we forward to client
me_msg = me_rx.recv(), if !server_closed => {
match me_msg {
Some(MeResponse::Data(data)) => {
let mut frame_len = data.len() as u32;
let mut secure_padding = [0u8; 1];
if is_secure && data.len() % 4 == 0 {
frame_len += 1;
secure_padding[0] = 0;
}
trace!(conn_id, bytes = data.len(), frame_len, "ME payload -> C frame");
stats.add_user_octets_to(&user, frame_len as u64 + 4);
if let Err(e) = client_writer.write_all(&frame_len.to_le_bytes()).await {
debug!(conn_id, error = %e, "Client write header error");
break Err(ProxyError::Io(e));
}
if let Err(e) = client_writer.write_all(&data).await {
debug!(conn_id, error = %e, "Client write error");
break Err(ProxyError::Io(e));
}
if frame_len as usize > data.len() {
if let Err(e) = client_writer.write_all(&secure_padding).await {
debug!(conn_id, error = %e, "Client write padding error");
break Err(ProxyError::Io(e));
}
}
if let Err(e) = client_writer.flush().await {
break Err(ProxyError::Io(e));
}
}
Some(MeResponse::Ack(_token)) => {
// QuickACK from ME — could forward to client as obfuscated ACK
// For now, just log
trace!(conn_id, "ME ACK (ignored)");
}
Some(MeResponse::Close) => {
debug!(conn_id, "ME sent CLOSE");
server_closed = true;
if client_closed { break Ok(()); }
}
None => {
// Channel closed — ME connection died
debug!(conn_id, "ME channel closed");
if client_closed { break Ok(()); }
break Err(ProxyError::Proxy("ME connection lost".into()));
}
Some(MeResponse::Ack(_)) => {
trace!(conn_id, "ME ACK");
}
Some(MeResponse::Close) => {
debug!(conn_id, "ME sent CLOSE");
server_closed = true;
if client_closed { break Ok(()); }
// We should probably close client connection too
break Ok(());
}
None => {
debug!(conn_id, "ME channel closed");
server_closed = true;
if client_closed { break Ok(()); }
break Err(ProxyError::Proxy("ME connection lost".into()));
}
}
// Both sides closed
else => {
break Ok(());
}
}
};
// Cleanup
debug!(user = %user, conn_id, "ME relay cleanup");
me_pool.registry().unregister(conn_id).await;
stats.decrement_user_curr_connects(&user);
match &result {
Ok(()) => debug!(user = %user, conn_id, "ME relay completed"),
Err(e) => debug!(user = %user, conn_id, error = %e, "ME relay error"),
}
result
}
// =====================================================================
// Helpers
// =====================================================================
fn check_user_limits_static(user: &str, config: &ProxyConfig, stats: &Stats) -> Result<()> {
};
// Cleanup
debug!(user = %user, conn_id, "ME relay cleanup");
me_pool.registry().unregister(conn_id).await;
stats.decrement_user_curr_connects(&user);
result
}
fn check_user_limits_static(user: &str, config: &ProxyConfig, stats: &Stats) -> Result<()> {
if let Some(expiration) = config.access.user_expirations.get(user) {
if chrono::Utc::now() > *expiration {
return Err(ProxyError::UserExpired { user: user.to_string() });
@@ -613,4 +658,4 @@
))
}
}