RPC Flags Fixes

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
This commit is contained in:
Alexey
2026-02-13 14:28:47 +03:00
parent f1c1f42de8
commit e62b41ae64
3 changed files with 236 additions and 106 deletions

View File

@@ -14,7 +14,7 @@
use crate::protocol::tls;
use crate::stats::{Stats, ReplayChecker};
use crate::transport::{configure_client_socket, UpstreamManager};
use crate::transport::middle_proxy::{MePool, MeResponse};
use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag};
use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter, BufferPool};
use crate::crypto::{AesCtr, SecureRandom};
@@ -373,42 +373,58 @@
stats.increment_user_connects(&user);
stats.increment_user_curr_connects(&user);
debug!(user = %user, conn_id, "ME relay started");
let proto_flags = proto_flags_for_tag(Some(success.proto_tag as u32));
let is_secure = matches!(success.proto_tag, ProtoTag::Secure);
debug!(user = %user, conn_id, proto_flags = format_args!("0x{:08x}", proto_flags), "ME relay started");
// Bidirectional relay loop: client ↔ ME pool
//
// C→S direction: read raw bytes from client_reader, wrap in RPC_PROXY_REQ, send via ME
// S→C direction: receive MeResponse::Data from registry channel, write to client_writer
// C→S direction: read intermediate frame from client, send payload via RPC_PROXY_REQ
// S→C direction: receive raw payload from ME, add client intermediate framing
//
// We use tokio::select! to handle both directions concurrently.
// Unlike direct mode (copy_bidirectional on two TCP streams),
// here one side is a channel (mpsc::Receiver), not a stream.
let mut client_buf = vec![0u8; 64 * 1024];
let mut client_closed = false;
let mut server_closed = false;
let result: Result<()> = loop {
tokio::select! {
// C→S: client sends data, we forward to ME
read_result = client_reader.read(&mut client_buf), if !client_closed => {
// C→S: client sends one intermediate frame, we forward payload to ME
read_result = async {
let mut len_buf = [0u8; 4];
client_reader.read_exact(&mut len_buf).await?;
let raw_len = u32::from_le_bytes(len_buf);
let payload_len = (raw_len & 0x7fff_ffff) as usize;
if payload_len > MAX_MSG_LEN {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("client frame too large: {}", payload_len),
));
}
let mut payload = vec![0u8; payload_len];
client_reader.read_exact(&mut payload).await?;
Ok::<Vec<u8>, std::io::Error>(payload)
}, if !client_closed => {
match read_result {
Ok(0) => {
Ok(payload) => {
trace!(conn_id, bytes = payload.len(), "C frame -> ME payload");
stats.add_user_octets_from(&user, payload.len() as u64);
if let Err(e) = me_pool.send_proxy_req(
conn_id, peer, our_addr, &payload, proto_flags
).await {
break Err(e);
}
}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
debug!(conn_id, "Client EOF");
client_closed = true;
if server_closed { break Ok(()); }
// Signal ME to close this connection
let _ = me_pool.send_close(conn_id).await;
}
Ok(n) => {
trace!(conn_id, bytes = n, "C→ME");
stats.add_user_octets_from(&user, n as u64);
if let Err(e) = me_pool.send_proxy_req(
conn_id, peer, our_addr, &client_buf[..n]
).await {
break Err(e);
}
}
Err(e) => {
debug!(conn_id, error = %e, "Client read error");
break Err(ProxyError::Io(e));
@@ -420,12 +436,30 @@
me_msg = me_rx.recv(), if !server_closed => {
match me_msg {
Some(MeResponse::Data(data)) => {
trace!(conn_id, bytes = data.len(), "ME→C");
stats.add_user_octets_to(&user, data.len() as u64);
let mut frame_len = data.len() as u32;
let mut secure_padding = [0u8; 1];
if is_secure && data.len() % 4 == 0 {
frame_len += 1;
secure_padding[0] = 0;
}
trace!(conn_id, bytes = data.len(), frame_len, "ME payload -> C frame");
stats.add_user_octets_to(&user, frame_len as u64 + 4);
if let Err(e) = client_writer.write_all(&frame_len.to_le_bytes()).await {
debug!(conn_id, error = %e, "Client write header error");
break Err(ProxyError::Io(e));
}
if let Err(e) = client_writer.write_all(&data).await {
debug!(conn_id, error = %e, "Client write error");
break Err(ProxyError::Io(e));
}
if frame_len as usize > data.len() {
if let Err(e) = client_writer.write_all(&secure_padding).await {
debug!(conn_id, error = %e, "Client write padding error");
break Err(ProxyError::Io(e));
}
}
if let Err(e) = client_writer.flush().await {
break Err(ProxyError::Io(e));
}
@@ -443,7 +477,6 @@
None => {
// Channel closed — ME connection died
debug!(conn_id, "ME channel closed");
server_closed = true;
if client_closed { break Ok(()); }
break Err(ProxyError::Proxy("ME connection lost".into()));
}
@@ -580,4 +613,4 @@
))
}
}