//! tokio-util codec integration for MTProto frames //! //! This module provides Encoder/Decoder implementations compatible //! with tokio-util's Framed wrapper for easy async frame I/O. use bytes::{Bytes, BytesMut, BufMut}; use std::io::{self, Error, ErrorKind}; use std::sync::Arc; use tokio_util::codec::{Decoder, Encoder}; use crate::protocol::constants::ProtoTag; use crate::crypto::SecureRandom; use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait}; // ============= Unified Codec ============= /// Unified frame codec that wraps all protocol variants /// /// This codec implements tokio-util's Encoder and Decoder traits, /// allowing it to be used with `Framed` for async frame I/O. pub struct FrameCodec { /// Protocol variant proto_tag: ProtoTag, /// Maximum allowed frame size max_frame_size: usize, /// RNG for secure padding rng: Arc, } impl FrameCodec { /// Create a new codec for the given protocol pub fn new(proto_tag: ProtoTag, rng: Arc) -> Self { Self { proto_tag, max_frame_size: 16 * 1024 * 1024, // 16MB default rng, } } /// Set maximum frame size pub fn with_max_frame_size(mut self, size: usize) -> Self { self.max_frame_size = size; self } /// Get protocol tag pub fn proto_tag(&self) -> ProtoTag { self.proto_tag } } impl Decoder for FrameCodec { type Item = Frame; type Error = io::Error; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { match self.proto_tag { ProtoTag::Abridged => decode_abridged(src, self.max_frame_size), ProtoTag::Intermediate => decode_intermediate(src, self.max_frame_size), ProtoTag::Secure => decode_secure(src, self.max_frame_size), } } } impl Encoder for FrameCodec { type Error = io::Error; fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> { match self.proto_tag { ProtoTag::Abridged => encode_abridged(&frame, dst), ProtoTag::Intermediate => encode_intermediate(&frame, dst), ProtoTag::Secure => encode_secure(&frame, dst, &self.rng), } } } // ============= Abridged Protocol ============= fn decode_abridged(src: &mut BytesMut, max_size: usize) -> io::Result> { if src.is_empty() { return Ok(None); } let mut meta = FrameMeta::new(); let first_byte = src[0]; // Extract length and quickack flag let mut len_words = (first_byte & 0x7f) as usize; if first_byte >= 0x80 { meta.quickack = true; } let header_len; if len_words == 0x7f { // Extended length (3 more bytes needed) if src.len() < 4 { return Ok(None); } len_words = u32::from_le_bytes([src[1], src[2], src[3], 0]) as usize; header_len = 4; } else { header_len = 1; } // Length is in 4-byte words let byte_len = len_words.checked_mul(4).ok_or_else(|| { Error::new(ErrorKind::InvalidData, "frame length overflow") })?; // Validate size if byte_len > max_size { return Err(Error::new( ErrorKind::InvalidData, format!("frame too large: {} bytes (max {})", byte_len, max_size) )); } let total_len = header_len + byte_len; if src.len() < total_len { // Reserve space for the rest of the frame src.reserve(total_len - src.len()); return Ok(None); } // Extract data let _ = src.split_to(header_len); let data = src.split_to(byte_len).freeze(); Ok(Some(Frame::with_meta(data, meta))) } fn encode_abridged(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> { let data = &frame.data; // Validate alignment if data.len() % 4 != 0 { return Err(Error::new( ErrorKind::InvalidInput, format!("abridged frame must be 4-byte aligned, got {} bytes", data.len()) )); } // Simple ACK: send reversed data without header if frame.meta.simple_ack { dst.reserve(data.len()); for byte in data.iter().rev() { dst.put_u8(*byte); } return Ok(()); } let len_words = data.len() / 4; if len_words < 0x7f { // Short header dst.reserve(1 + data.len()); let mut len_byte = len_words as u8; if frame.meta.quickack { len_byte |= 0x80; } dst.put_u8(len_byte); } else if len_words < (1 << 24) { // Extended header dst.reserve(4 + data.len()); let mut first = 0x7fu8; if frame.meta.quickack { first |= 0x80; } dst.put_u8(first); let len_bytes = (len_words as u32).to_le_bytes(); dst.extend_from_slice(&len_bytes[..3]); } else { return Err(Error::new( ErrorKind::InvalidInput, format!("frame too large: {} bytes", data.len()) )); } dst.extend_from_slice(data); Ok(()) } // ============= Intermediate Protocol ============= fn decode_intermediate(src: &mut BytesMut, max_size: usize) -> io::Result> { if src.len() < 4 { return Ok(None); } let mut meta = FrameMeta::new(); let mut len = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize; // Check QuickACK flag if len >= 0x80000000 { meta.quickack = true; len -= 0x80000000; } // Validate size if len > max_size { return Err(Error::new( ErrorKind::InvalidData, format!("frame too large: {} bytes (max {})", len, max_size) )); } let total_len = 4 + len; if src.len() < total_len { src.reserve(total_len - src.len()); return Ok(None); } // Extract data let _ = src.split_to(4); let data = src.split_to(len).freeze(); Ok(Some(Frame::with_meta(data, meta))) } fn encode_intermediate(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> { let data = &frame.data; // Simple ACK: just send data if frame.meta.simple_ack { dst.reserve(data.len()); dst.extend_from_slice(data); return Ok(()); } dst.reserve(4 + data.len()); let mut len = data.len() as u32; if frame.meta.quickack { len |= 0x80000000; } dst.extend_from_slice(&len.to_le_bytes()); dst.extend_from_slice(data); Ok(()) } // ============= Secure Intermediate Protocol ============= fn decode_secure(src: &mut BytesMut, max_size: usize) -> io::Result> { if src.len() < 4 { return Ok(None); } let mut meta = FrameMeta::new(); let mut len = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize; // Check QuickACK flag if len >= 0x80000000 { meta.quickack = true; len -= 0x80000000; } // Validate size if len > max_size { return Err(Error::new( ErrorKind::InvalidData, format!("frame too large: {} bytes (max {})", len, max_size) )); } let total_len = 4 + len; if src.len() < total_len { src.reserve(total_len - src.len()); return Ok(None); } // Calculate padding (indicated by length not divisible by 4) let padding_len = len % 4; let data_len = if padding_len != 0 { len - padding_len } else { len }; meta.padding_len = padding_len as u8; // Extract data (excluding padding) let _ = src.split_to(4); let all_data = src.split_to(len); // Copy only the data portion, excluding padding let data = Bytes::copy_from_slice(&all_data[..data_len]); Ok(Some(Frame::with_meta(data, meta))) } fn encode_secure(frame: &Frame, dst: &mut BytesMut, rng: &SecureRandom) -> io::Result<()> { let data = &frame.data; // Simple ACK: just send data if frame.meta.simple_ack { dst.reserve(data.len()); dst.extend_from_slice(data); return Ok(()); } // Generate padding to make length not divisible by 4 let padding_len = if data.len() % 4 == 0 { // Add 1-3 bytes to make it non-aligned (rng.range(3) + 1) as usize } else { // Already non-aligned, can add 0-3 rng.range(4) as usize }; let total_len = data.len() + padding_len; dst.reserve(4 + total_len); let mut len = total_len as u32; if frame.meta.quickack { len |= 0x80000000; } dst.extend_from_slice(&len.to_le_bytes()); dst.extend_from_slice(data); if padding_len > 0 { let padding = rng.bytes(padding_len); dst.extend_from_slice(&padding); } Ok(()) } // ============= Typed Codecs ============= /// Abridged protocol codec pub struct AbridgedCodec { max_frame_size: usize, } impl AbridgedCodec { pub fn new() -> Self { Self { max_frame_size: 16 * 1024 * 1024, } } } impl Default for AbridgedCodec { fn default() -> Self { Self::new() } } impl Decoder for AbridgedCodec { type Item = Frame; type Error = io::Error; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { decode_abridged(src, self.max_frame_size) } } impl Encoder for AbridgedCodec { type Error = io::Error; fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> { encode_abridged(&frame, dst) } } impl FrameCodecTrait for AbridgedCodec { fn proto_tag(&self) -> ProtoTag { ProtoTag::Abridged } fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result { let before = dst.len(); encode_abridged(frame, dst)?; Ok(dst.len() - before) } fn decode(&self, src: &mut BytesMut) -> io::Result> { decode_abridged(src, self.max_frame_size) } fn min_header_size(&self) -> usize { 1 } } /// Intermediate protocol codec pub struct IntermediateCodec { max_frame_size: usize, } impl IntermediateCodec { pub fn new() -> Self { Self { max_frame_size: 16 * 1024 * 1024, } } } impl Default for IntermediateCodec { fn default() -> Self { Self::new() } } impl Decoder for IntermediateCodec { type Item = Frame; type Error = io::Error; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { decode_intermediate(src, self.max_frame_size) } } impl Encoder for IntermediateCodec { type Error = io::Error; fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> { encode_intermediate(&frame, dst) } } impl FrameCodecTrait for IntermediateCodec { fn proto_tag(&self) -> ProtoTag { ProtoTag::Intermediate } fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result { let before = dst.len(); encode_intermediate(frame, dst)?; Ok(dst.len() - before) } fn decode(&self, src: &mut BytesMut) -> io::Result> { decode_intermediate(src, self.max_frame_size) } fn min_header_size(&self) -> usize { 4 } } /// Secure Intermediate protocol codec pub struct SecureCodec { max_frame_size: usize, rng: Arc, } impl SecureCodec { pub fn new(rng: Arc) -> Self { Self { max_frame_size: 16 * 1024 * 1024, rng, } } } impl Default for SecureCodec { fn default() -> Self { Self::new(Arc::new(SecureRandom::new())) } } impl Decoder for SecureCodec { type Item = Frame; type Error = io::Error; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { decode_secure(src, self.max_frame_size) } } impl Encoder for SecureCodec { type Error = io::Error; fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> { encode_secure(&frame, dst, &self.rng) } } impl FrameCodecTrait for SecureCodec { fn proto_tag(&self) -> ProtoTag { ProtoTag::Secure } fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result { let before = dst.len(); encode_secure(frame, dst, &self.rng)?; Ok(dst.len() - before) } fn decode(&self, src: &mut BytesMut) -> io::Result> { decode_secure(src, self.max_frame_size) } fn min_header_size(&self) -> usize { 4 } } // ============= Tests ============= #[cfg(test)] mod tests { use super::*; use tokio_util::codec::{FramedRead, FramedWrite}; use tokio::io::duplex; use futures::{SinkExt, StreamExt}; use crate::crypto::SecureRandom; use std::sync::Arc; #[tokio::test] async fn test_framed_abridged() { let (client, server) = duplex(4096); let mut writer = FramedWrite::new(client, AbridgedCodec::new()); let mut reader = FramedRead::new(server, AbridgedCodec::new()); // Write a frame let frame = Frame::new(Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8])); writer.send(frame).await.unwrap(); // Read it back let received = reader.next().await.unwrap().unwrap(); assert_eq!(&received.data[..], &[1, 2, 3, 4, 5, 6, 7, 8]); } #[tokio::test] async fn test_framed_intermediate() { let (client, server) = duplex(4096); let mut writer = FramedWrite::new(client, IntermediateCodec::new()); let mut reader = FramedRead::new(server, IntermediateCodec::new()); let frame = Frame::new(Bytes::from_static(b"hello world")); writer.send(frame).await.unwrap(); let received = reader.next().await.unwrap().unwrap(); assert_eq!(&received.data[..], b"hello world"); } #[tokio::test] async fn test_framed_secure() { let (client, server) = duplex(4096); let mut writer = FramedWrite::new(client, SecureCodec::new(Arc::new(SecureRandom::new()))); let mut reader = FramedRead::new(server, SecureCodec::new(Arc::new(SecureRandom::new()))); let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]); let frame = Frame::new(original.clone()); writer.send(frame).await.unwrap(); let received = reader.next().await.unwrap().unwrap(); assert_eq!(&received.data[..], &original[..]); } #[tokio::test] async fn test_unified_codec() { for proto_tag in [ProtoTag::Abridged, ProtoTag::Intermediate, ProtoTag::Secure] { let (client, server) = duplex(4096); let mut writer = FramedWrite::new(client, FrameCodec::new(proto_tag, Arc::new(SecureRandom::new()))); let mut reader = FramedRead::new(server, FrameCodec::new(proto_tag, Arc::new(SecureRandom::new()))); // Use 4-byte aligned data for abridged compatibility let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]); let frame = Frame::new(original.clone()); writer.send(frame).await.unwrap(); let received = reader.next().await.unwrap().unwrap(); assert_eq!(received.data.len(), 8); } } #[tokio::test] async fn test_multiple_frames() { let (client, server) = duplex(4096); let mut writer = FramedWrite::new(client, IntermediateCodec::new()); let mut reader = FramedRead::new(server, IntermediateCodec::new()); // Send multiple frames for i in 0..10 { let data: Vec = (0..((i + 1) * 10)).map(|j| (j % 256) as u8).collect(); let frame = Frame::new(Bytes::from(data)); writer.send(frame).await.unwrap(); } // Receive them for i in 0..10 { let received = reader.next().await.unwrap().unwrap(); assert_eq!(received.data.len(), (i + 1) * 10); } } #[tokio::test] async fn test_quickack_flag() { let (client, server) = duplex(4096); let mut writer = FramedWrite::new(client, IntermediateCodec::new()); let mut reader = FramedRead::new(server, IntermediateCodec::new()); let frame = Frame::quickack(Bytes::from_static(b"urgent")); writer.send(frame).await.unwrap(); let received = reader.next().await.unwrap().unwrap(); assert!(received.meta.quickack); } #[test] fn test_frame_too_large() { let mut codec = FrameCodec::new(ProtoTag::Intermediate, Arc::new(SecureRandom::new())) .with_max_frame_size(100); // Create a "frame" that claims to be very large let mut buf = BytesMut::new(); buf.extend_from_slice(&1000u32.to_le_bytes()); // length = 1000 buf.extend_from_slice(&[0u8; 10]); // partial data let result = codec.decode(&mut buf); assert!(result.is_err()); } }