770 lines
24 KiB
Rust
770 lines
24 KiB
Rust
//! Encrypted stream wrappers using AES-CTR
|
|
//!
|
|
//! This module provides stateful async stream wrappers that handle
|
|
//! encryption/decryption with proper partial read/write handling.
|
|
//!
|
|
//! Key design principles:
|
|
//! - Explicit state machines for all async operations
|
|
//! - Never lose data on partial reads/writes
|
|
//! - Honest reporting of bytes written (AsyncWrite contract)
|
|
//! - Bounded internal buffers with backpressure
|
|
//!
|
|
//! AES-CTR is a stream cipher: the keystream position must advance exactly by the
|
|
//! number of plaintext bytes that are *accepted* (written or buffered).
|
|
//!
|
|
//! This implementation guarantees:
|
|
//! - CTR state never "drifts"
|
|
//! - never accept plaintext unless we can guarantee that all corresponding ciphertext
|
|
//! is either written to upstream or stored in our pending buffer
|
|
//! - when upstream is pending -> ciphertext is buffered/bounded and backpressure is applied
|
|
//!
|
|
//! =======================
|
|
//! Writer state machine
|
|
//! =======================
|
|
//!
|
|
//! ┌──────────┐ write buf ┌──────────┐
|
|
//! │ Idle │ ---------------> │ Flushing │
|
|
//! │ │ <--------------- │ │
|
|
//! └──────────┘ drained └──────────┘
|
|
//! │ │
|
|
//! │ errors │
|
|
//! ▼ ▼
|
|
//! ┌────────────────────────────────────────┐
|
|
//! │ Poisoned │
|
|
//! └────────────────────────────────────────┘
|
|
//!
|
|
//! Backpressure
|
|
//! - pending ciphertext buffer is bounded (MAX_PENDING_WRITE)
|
|
//! - pending is full and upstream is pending
|
|
//! -> poll_write returns Poll::Pending
|
|
//! -> do not accept any plaintext
|
|
//!
|
|
//! Performance
|
|
//! - fast path when pending is empty: encrypt into scratch and try upstream
|
|
//! - if upstream Pending/partial => move remainder into pending without re-encrypting
|
|
//! - when upstream is Pending but pending still has room: accept `to_accept` bytes and
|
|
//! encrypt+append ciphertext directly into pending (in-place encryption of appended range)
|
|
|
|
//! Encrypted stream wrappers using AES-CTR
|
|
//!
|
|
//! This module provides stateful async stream wrappers that handle
|
|
//! encryption/decryption with proper partial read/write handling.
|
|
|
|
use bytes::{Bytes, BytesMut};
|
|
use std::io::{self, ErrorKind, Result};
|
|
use std::pin::Pin;
|
|
use std::task::{Context, Poll};
|
|
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
|
use tracing::{debug, trace, warn};
|
|
|
|
use crate::crypto::AesCtr;
|
|
use super::state::{StreamState, YieldBuffer};
|
|
|
|
// ============= Constants =============
|
|
|
|
/// Maximum size for pending ciphertext buffer (bounded backpressure).
|
|
/// Reduced to 64KB to prevent bufferbloat on mobile networks.
|
|
/// 512KB was causing high latency on 3G/LTE connections.
|
|
const MAX_PENDING_WRITE: usize = 64 * 1024;
|
|
|
|
/// Default read buffer capacity (reader mostly decrypts in-place into caller buffer).
|
|
const DEFAULT_READ_CAPACITY: usize = 16 * 1024;
|
|
|
|
// ============= CryptoReader State =============
|
|
|
|
#[derive(Debug)]
|
|
enum CryptoReaderState {
|
|
/// Ready to read new data
|
|
Idle,
|
|
|
|
/// Have decrypted data ready to yield to caller
|
|
Yielding { buffer: YieldBuffer },
|
|
|
|
/// Stream encountered an error and cannot be used
|
|
Poisoned { error: Option<io::Error> },
|
|
}
|
|
|
|
impl StreamState for CryptoReaderState {
|
|
fn is_terminal(&self) -> bool {
|
|
matches!(self, Self::Poisoned { .. })
|
|
}
|
|
|
|
fn is_poisoned(&self) -> bool {
|
|
matches!(self, Self::Poisoned { .. })
|
|
}
|
|
|
|
fn state_name(&self) -> &'static str {
|
|
match self {
|
|
Self::Idle => "Idle",
|
|
Self::Yielding { .. } => "Yielding",
|
|
Self::Poisoned { .. } => "Poisoned",
|
|
}
|
|
}
|
|
}
|
|
|
|
// ============= CryptoReader =============
|
|
|
|
/// Reader that decrypts data using AES-CTR with proper state machine.
|
|
pub struct CryptoReader<R> {
|
|
upstream: R,
|
|
decryptor: AesCtr,
|
|
state: CryptoReaderState,
|
|
|
|
/// Reserved for future coalescing optimizations.
|
|
#[allow(dead_code)]
|
|
read_buf: BytesMut,
|
|
}
|
|
|
|
impl<R> CryptoReader<R> {
|
|
pub fn new(upstream: R, decryptor: AesCtr) -> Self {
|
|
Self {
|
|
upstream,
|
|
decryptor,
|
|
state: CryptoReaderState::Idle,
|
|
read_buf: BytesMut::with_capacity(DEFAULT_READ_CAPACITY),
|
|
}
|
|
}
|
|
|
|
pub fn get_ref(&self) -> &R {
|
|
&self.upstream
|
|
}
|
|
|
|
pub fn get_mut(&mut self) -> &mut R {
|
|
&mut self.upstream
|
|
}
|
|
|
|
pub fn into_inner(self) -> R {
|
|
self.upstream
|
|
}
|
|
|
|
pub fn is_poisoned(&self) -> bool {
|
|
self.state.is_poisoned()
|
|
}
|
|
|
|
pub fn state_name(&self) -> &'static str {
|
|
self.state.state_name()
|
|
}
|
|
|
|
fn poison(&mut self, error: io::Error) {
|
|
self.state = CryptoReaderState::Poisoned { error: Some(error) };
|
|
}
|
|
|
|
fn take_poison_error(&mut self) -> io::Error {
|
|
match &mut self.state {
|
|
CryptoReaderState::Poisoned { error } => error.take().unwrap_or_else(|| {
|
|
io::Error::new(ErrorKind::Other, "stream previously poisoned")
|
|
}),
|
|
_ => io::Error::new(ErrorKind::Other, "stream not poisoned"),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<R: AsyncRead + Unpin> AsyncRead for CryptoReader<R> {
|
|
fn poll_read(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &mut ReadBuf<'_>,
|
|
) -> Poll<Result<()>> {
|
|
let this = self.get_mut();
|
|
|
|
loop {
|
|
match &mut this.state {
|
|
CryptoReaderState::Poisoned { .. } => {
|
|
let err = this.take_poison_error();
|
|
return Poll::Ready(Err(err));
|
|
}
|
|
|
|
CryptoReaderState::Yielding { buffer } => {
|
|
if buf.remaining() == 0 {
|
|
return Poll::Ready(Ok(()));
|
|
}
|
|
|
|
let to_copy = buffer.remaining().min(buf.remaining());
|
|
let dst = buf.initialize_unfilled_to(to_copy);
|
|
let copied = buffer.copy_to(dst);
|
|
buf.advance(copied);
|
|
|
|
if buffer.is_empty() {
|
|
this.state = CryptoReaderState::Idle;
|
|
}
|
|
|
|
return Poll::Ready(Ok(()));
|
|
}
|
|
|
|
CryptoReaderState::Idle => {
|
|
if buf.remaining() == 0 {
|
|
return Poll::Ready(Ok(()));
|
|
}
|
|
|
|
// Read directly into caller buffer, decrypt in-place for the bytes read.
|
|
let before = buf.filled().len();
|
|
|
|
match Pin::new(&mut this.upstream).poll_read(cx, buf) {
|
|
Poll::Pending => return Poll::Pending,
|
|
|
|
Poll::Ready(Err(e)) => {
|
|
this.poison(io::Error::new(e.kind(), e.to_string()));
|
|
return Poll::Ready(Err(e));
|
|
}
|
|
|
|
Poll::Ready(Ok(())) => {
|
|
let after = buf.filled().len();
|
|
let bytes_read = after - before;
|
|
|
|
if bytes_read == 0 {
|
|
// EOF
|
|
return Poll::Ready(Ok(()));
|
|
}
|
|
|
|
let filled = buf.filled_mut();
|
|
this.decryptor.apply(&mut filled[before..after]);
|
|
|
|
trace!(bytes_read, state = this.state_name(), "CryptoReader decrypted chunk");
|
|
|
|
return Poll::Ready(Ok(()));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<R: AsyncRead + Unpin> CryptoReader<R> {
|
|
/// Read and decrypt exactly n bytes.
|
|
pub async fn read_exact_decrypt(&mut self, n: usize) -> Result<Bytes> {
|
|
use tokio::io::AsyncReadExt;
|
|
|
|
if self.is_poisoned() {
|
|
return Err(self.take_poison_error());
|
|
}
|
|
|
|
let mut result = BytesMut::with_capacity(n);
|
|
|
|
// Drain Yielding buffer if present (rare, kept for completeness)
|
|
if let CryptoReaderState::Yielding { buffer } = &mut self.state {
|
|
let to_take = buffer.remaining().min(n);
|
|
let mut temp = vec![0u8; to_take];
|
|
buffer.copy_to(&mut temp);
|
|
result.extend_from_slice(&temp);
|
|
|
|
if buffer.is_empty() {
|
|
self.state = CryptoReaderState::Idle;
|
|
}
|
|
}
|
|
|
|
while result.len() < n {
|
|
let mut temp = vec![0u8; n - result.len()];
|
|
let read = self.read(&mut temp).await?;
|
|
|
|
if read == 0 {
|
|
return Err(io::Error::new(
|
|
ErrorKind::UnexpectedEof,
|
|
format!("expected {} bytes, got {}", n, result.len()),
|
|
));
|
|
}
|
|
|
|
result.extend_from_slice(&temp[..read]);
|
|
}
|
|
|
|
Ok(result.freeze())
|
|
}
|
|
|
|
/// Read up to max_size bytes, returning decrypted bytes as Bytes.
|
|
pub async fn read_decrypt(&mut self, max_size: usize) -> Result<Bytes> {
|
|
use tokio::io::AsyncReadExt;
|
|
|
|
if self.is_poisoned() {
|
|
return Err(self.take_poison_error());
|
|
}
|
|
|
|
if let CryptoReaderState::Yielding { buffer } = &mut self.state {
|
|
let to_take = buffer.remaining().min(max_size);
|
|
let mut temp = vec![0u8; to_take];
|
|
buffer.copy_to(&mut temp);
|
|
|
|
if buffer.is_empty() {
|
|
self.state = CryptoReaderState::Idle;
|
|
}
|
|
|
|
return Ok(Bytes::from(temp));
|
|
}
|
|
|
|
let mut temp = vec![0u8; max_size];
|
|
let read = self.read(&mut temp).await?;
|
|
|
|
if read == 0 {
|
|
return Ok(Bytes::new());
|
|
}
|
|
|
|
temp.truncate(read);
|
|
Ok(Bytes::from(temp))
|
|
}
|
|
}
|
|
|
|
// ============= Pending Ciphertext =============
|
|
|
|
/// Pending ciphertext buffer with explicit position and strict max size.
|
|
#[derive(Debug)]
|
|
struct PendingCiphertext {
|
|
buf: BytesMut,
|
|
pos: usize,
|
|
max_len: usize,
|
|
}
|
|
|
|
impl PendingCiphertext {
|
|
fn new(max_len: usize) -> Self {
|
|
Self {
|
|
buf: BytesMut::with_capacity(16 * 1024),
|
|
pos: 0,
|
|
max_len,
|
|
}
|
|
}
|
|
|
|
fn pending_len(&self) -> usize {
|
|
self.buf.len().saturating_sub(self.pos)
|
|
}
|
|
|
|
fn is_empty(&self) -> bool {
|
|
self.pending_len() == 0
|
|
}
|
|
|
|
fn pending_slice(&self) -> &[u8] {
|
|
&self.buf[self.pos..]
|
|
}
|
|
|
|
fn remaining_capacity(&self) -> usize {
|
|
self.max_len.saturating_sub(self.buf.len())
|
|
}
|
|
|
|
fn advance(&mut self, n: usize) {
|
|
self.pos = (self.pos + n).min(self.buf.len());
|
|
|
|
if self.pos == self.buf.len() {
|
|
self.buf.clear();
|
|
self.pos = 0;
|
|
return;
|
|
}
|
|
|
|
// Compact when a large prefix was consumed.
|
|
if self.pos >= 16 * 1024 {
|
|
let _ = self.buf.split_to(self.pos);
|
|
self.pos = 0;
|
|
}
|
|
}
|
|
|
|
/// Replace the entire pending ciphertext by moving `src` in (swap, no copy).
|
|
fn replace_with(&mut self, mut src: BytesMut) {
|
|
debug_assert!(src.len() <= self.max_len);
|
|
|
|
self.buf.clear();
|
|
self.pos = 0;
|
|
|
|
// Swap: keep allocations hot and avoid copying bytes.
|
|
std::mem::swap(&mut self.buf, &mut src);
|
|
}
|
|
|
|
/// Append plaintext and encrypt appended range in-place.
|
|
fn push_encrypted(&mut self, encryptor: &mut AesCtr, plaintext: &[u8]) -> Result<()> {
|
|
if plaintext.is_empty() {
|
|
return Ok(());
|
|
}
|
|
|
|
if plaintext.len() > self.remaining_capacity() {
|
|
return Err(io::Error::new(
|
|
ErrorKind::WouldBlock,
|
|
"pending ciphertext buffer is full",
|
|
));
|
|
}
|
|
|
|
let start = self.buf.len();
|
|
self.buf.reserve(plaintext.len());
|
|
self.buf.extend_from_slice(plaintext);
|
|
|
|
encryptor.apply(&mut self.buf[start..]);
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
// ============= CryptoWriter State =============
|
|
|
|
#[derive(Debug)]
|
|
enum CryptoWriterState {
|
|
/// No pending ciphertext buffered.
|
|
Idle,
|
|
|
|
/// There is pending ciphertext to flush.
|
|
Flushing { pending: PendingCiphertext },
|
|
|
|
/// Stream encountered an error and cannot be used
|
|
Poisoned { error: Option<io::Error> },
|
|
}
|
|
|
|
impl StreamState for CryptoWriterState {
|
|
fn is_terminal(&self) -> bool {
|
|
matches!(self, Self::Poisoned { .. })
|
|
}
|
|
|
|
fn is_poisoned(&self) -> bool {
|
|
matches!(self, Self::Poisoned { .. })
|
|
}
|
|
|
|
fn state_name(&self) -> &'static str {
|
|
match self {
|
|
Self::Idle => "Idle",
|
|
Self::Flushing { .. } => "Flushing",
|
|
Self::Poisoned { .. } => "Poisoned",
|
|
}
|
|
}
|
|
}
|
|
|
|
// ============= CryptoWriter =============
|
|
|
|
/// Writer that encrypts data using AES-CTR with correct async semantics.
|
|
pub struct CryptoWriter<W> {
|
|
upstream: W,
|
|
encryptor: AesCtr,
|
|
state: CryptoWriterState,
|
|
scratch: BytesMut,
|
|
}
|
|
|
|
impl<W> CryptoWriter<W> {
|
|
pub fn new(upstream: W, encryptor: AesCtr) -> Self {
|
|
Self {
|
|
upstream,
|
|
encryptor,
|
|
state: CryptoWriterState::Idle,
|
|
scratch: BytesMut::with_capacity(16 * 1024),
|
|
}
|
|
}
|
|
|
|
pub fn get_ref(&self) -> &W {
|
|
&self.upstream
|
|
}
|
|
|
|
pub fn get_mut(&mut self) -> &mut W {
|
|
&mut self.upstream
|
|
}
|
|
|
|
pub fn into_inner(self) -> W {
|
|
self.upstream
|
|
}
|
|
|
|
pub fn is_poisoned(&self) -> bool {
|
|
self.state.is_poisoned()
|
|
}
|
|
|
|
pub fn state_name(&self) -> &'static str {
|
|
self.state.state_name()
|
|
}
|
|
|
|
pub fn has_pending(&self) -> bool {
|
|
matches!(self.state, CryptoWriterState::Flushing { .. })
|
|
}
|
|
|
|
pub fn pending_len(&self) -> usize {
|
|
match &self.state {
|
|
CryptoWriterState::Flushing { pending } => pending.pending_len(),
|
|
_ => 0,
|
|
}
|
|
}
|
|
|
|
fn poison(&mut self, error: io::Error) {
|
|
self.state = CryptoWriterState::Poisoned { error: Some(error) };
|
|
}
|
|
|
|
fn take_poison_error(&mut self) -> io::Error {
|
|
match &mut self.state {
|
|
CryptoWriterState::Poisoned { error } => error.take().unwrap_or_else(|| {
|
|
io::Error::new(ErrorKind::Other, "stream previously poisoned")
|
|
}),
|
|
_ => io::Error::new(ErrorKind::Other, "stream not poisoned"),
|
|
}
|
|
}
|
|
|
|
/// Ensure we are in Flushing state and return mutable pending buffer.
|
|
fn ensure_pending<'a>(state: &'a mut CryptoWriterState) -> &'a mut PendingCiphertext {
|
|
if matches!(state, CryptoWriterState::Idle) {
|
|
*state = CryptoWriterState::Flushing {
|
|
pending: PendingCiphertext::new(MAX_PENDING_WRITE),
|
|
};
|
|
}
|
|
|
|
match state {
|
|
CryptoWriterState::Flushing { pending } => pending,
|
|
_ => unreachable!("ensure_pending guarantees Flushing state"),
|
|
}
|
|
}
|
|
|
|
/// Select how many plaintext bytes can be accepted in buffering path
|
|
fn select_to_accept_for_buffering(state: &CryptoWriterState, buf_len: usize) -> usize {
|
|
if buf_len == 0 {
|
|
return 0;
|
|
}
|
|
|
|
match state {
|
|
CryptoWriterState::Flushing { pending } => buf_len.min(pending.remaining_capacity()),
|
|
CryptoWriterState::Idle => buf_len.min(MAX_PENDING_WRITE),
|
|
CryptoWriterState::Poisoned { .. } => 0,
|
|
}
|
|
}
|
|
|
|
/// Encrypt plaintext into scratch (CTR advances by plaintext.len()).
|
|
fn encrypt_into_scratch(encryptor: &mut AesCtr, scratch: &mut BytesMut, plaintext: &[u8]) {
|
|
scratch.clear();
|
|
scratch.reserve(plaintext.len());
|
|
scratch.extend_from_slice(plaintext);
|
|
encryptor.apply(&mut scratch[..]);
|
|
}
|
|
}
|
|
|
|
impl<W: AsyncWrite + Unpin> CryptoWriter<W> {
|
|
/// Flush as much pending ciphertext as possible
|
|
fn poll_flush_pending(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
|
loop {
|
|
match &mut self.state {
|
|
CryptoWriterState::Poisoned { .. } => {
|
|
let err = self.take_poison_error();
|
|
return Poll::Ready(Err(err));
|
|
}
|
|
|
|
CryptoWriterState::Idle => return Poll::Ready(Ok(())),
|
|
|
|
CryptoWriterState::Flushing { pending } => {
|
|
if pending.is_empty() {
|
|
self.state = CryptoWriterState::Idle;
|
|
return Poll::Ready(Ok(()));
|
|
}
|
|
|
|
let data = pending.pending_slice();
|
|
|
|
match Pin::new(&mut self.upstream).poll_write(cx, data) {
|
|
Poll::Pending => {
|
|
trace!(
|
|
pending_len = pending.pending_len(),
|
|
pending_cap = pending.remaining_capacity(),
|
|
"CryptoWriter: upstream Pending while flushing pending ciphertext"
|
|
);
|
|
return Poll::Pending;
|
|
}
|
|
|
|
Poll::Ready(Err(e)) => {
|
|
self.poison(io::Error::new(e.kind(), e.to_string()));
|
|
return Poll::Ready(Err(e));
|
|
}
|
|
|
|
Poll::Ready(Ok(0)) => {
|
|
let err = io::Error::new(
|
|
ErrorKind::WriteZero,
|
|
"upstream returned 0 bytes written",
|
|
);
|
|
self.poison(io::Error::new(err.kind(), err.to_string()));
|
|
return Poll::Ready(Err(err));
|
|
}
|
|
|
|
Poll::Ready(Ok(n)) => {
|
|
pending.advance(n);
|
|
continue;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<W: AsyncWrite + Unpin> AsyncWrite for CryptoWriter<W> {
|
|
fn poll_write(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &[u8],
|
|
) -> Poll<Result<usize>> {
|
|
let this = self.get_mut();
|
|
|
|
// Poisoned?
|
|
if matches!(this.state, CryptoWriterState::Poisoned { .. }) {
|
|
let err = this.take_poison_error();
|
|
return Poll::Ready(Err(err));
|
|
}
|
|
|
|
// Empty write is always OK
|
|
if buf.is_empty() {
|
|
return Poll::Ready(Ok(0));
|
|
}
|
|
|
|
// 1) If we have pending ciphertext, prioritize flushing it
|
|
if matches!(this.state, CryptoWriterState::Flushing { .. }) {
|
|
match this.poll_flush_pending(cx) {
|
|
Poll::Ready(Ok(())) => {
|
|
// pending drained -> proceed
|
|
}
|
|
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
|
Poll::Pending => {
|
|
// Upstream blocked. Apply ideal backpressure
|
|
let to_accept =
|
|
Self::select_to_accept_for_buffering(&this.state, buf.len());
|
|
|
|
if to_accept == 0 {
|
|
trace!(
|
|
buf_len = buf.len(),
|
|
pending_len = this.pending_len(),
|
|
"CryptoWriter backpressure: pending full and upstream Pending -> Pending"
|
|
);
|
|
return Poll::Pending;
|
|
}
|
|
|
|
let plaintext = &buf[..to_accept];
|
|
|
|
// Disjoint borrows
|
|
let encryptor = &mut this.encryptor;
|
|
let pending = Self::ensure_pending(&mut this.state);
|
|
|
|
if let Err(e) = pending.push_encrypted(encryptor, plaintext) {
|
|
if e.kind() == ErrorKind::WouldBlock {
|
|
return Poll::Pending;
|
|
}
|
|
return Poll::Ready(Err(e));
|
|
}
|
|
|
|
return Poll::Ready(Ok(to_accept));
|
|
}
|
|
}
|
|
}
|
|
|
|
// 2) Fast path: pending empty -> write-through
|
|
debug_assert!(matches!(this.state, CryptoWriterState::Idle));
|
|
|
|
let to_accept = buf.len().min(MAX_PENDING_WRITE);
|
|
let plaintext = &buf[..to_accept];
|
|
|
|
Self::encrypt_into_scratch(&mut this.encryptor, &mut this.scratch, plaintext);
|
|
|
|
match Pin::new(&mut this.upstream).poll_write(cx, &this.scratch) {
|
|
Poll::Pending => {
|
|
// Upstream blocked: buffer FULL ciphertext for accepted bytes.
|
|
let ciphertext = std::mem::take(&mut this.scratch);
|
|
|
|
let pending = Self::ensure_pending(&mut this.state);
|
|
pending.replace_with(ciphertext);
|
|
|
|
Poll::Ready(Ok(to_accept))
|
|
}
|
|
|
|
Poll::Ready(Err(e)) => {
|
|
this.poison(io::Error::new(e.kind(), e.to_string()));
|
|
Poll::Ready(Err(e))
|
|
}
|
|
|
|
Poll::Ready(Ok(0)) => {
|
|
let err = io::Error::new(ErrorKind::WriteZero, "upstream returned 0 bytes written");
|
|
this.poison(io::Error::new(err.kind(), err.to_string()));
|
|
Poll::Ready(Err(err))
|
|
}
|
|
|
|
Poll::Ready(Ok(n)) => {
|
|
if n == this.scratch.len() {
|
|
this.scratch.clear();
|
|
return Poll::Ready(Ok(to_accept));
|
|
}
|
|
|
|
// Partial upstream write of ciphertext
|
|
let remainder = this.scratch.split_off(n);
|
|
this.scratch.clear();
|
|
|
|
let pending = Self::ensure_pending(&mut this.state);
|
|
pending.replace_with(remainder);
|
|
|
|
Poll::Ready(Ok(to_accept))
|
|
}
|
|
}
|
|
}
|
|
|
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
|
let this = self.get_mut();
|
|
|
|
if matches!(this.state, CryptoWriterState::Poisoned { .. }) {
|
|
let err = this.take_poison_error();
|
|
return Poll::Ready(Err(err));
|
|
}
|
|
|
|
match this.poll_flush_pending(cx) {
|
|
Poll::Pending => return Poll::Pending,
|
|
Poll::Ready(Ok(())) => {}
|
|
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
|
}
|
|
|
|
Pin::new(&mut this.upstream).poll_flush(cx)
|
|
}
|
|
|
|
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
|
let this = self.get_mut();
|
|
|
|
// Best-effort flush pending ciphertext before shutdown
|
|
match this.poll_flush_pending(cx) {
|
|
Poll::Pending => {
|
|
debug!(
|
|
pending_len = this.pending_len(),
|
|
"CryptoWriter: shutdown with pending ciphertext (upstream Pending)"
|
|
);
|
|
}
|
|
Poll::Ready(Err(_)) => {}
|
|
Poll::Ready(Ok(())) => {}
|
|
}
|
|
|
|
Pin::new(&mut this.upstream).poll_shutdown(cx)
|
|
}
|
|
}
|
|
|
|
// ============= PassthroughStream =============
|
|
|
|
/// Passthrough stream for fast mode - no encryption/decryption
|
|
pub struct PassthroughStream<S> {
|
|
inner: S,
|
|
}
|
|
|
|
impl<S> PassthroughStream<S> {
|
|
pub fn new(inner: S) -> Self {
|
|
Self { inner }
|
|
}
|
|
|
|
pub fn get_ref(&self) -> &S {
|
|
&self.inner
|
|
}
|
|
|
|
pub fn get_mut(&mut self) -> &mut S {
|
|
&mut self.inner
|
|
}
|
|
|
|
pub fn into_inner(self) -> S {
|
|
self.inner
|
|
}
|
|
}
|
|
|
|
impl<S: AsyncRead + Unpin> AsyncRead for PassthroughStream<S> {
|
|
fn poll_read(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &mut ReadBuf<'_>,
|
|
) -> Poll<Result<()>> {
|
|
Pin::new(&mut self.inner).poll_read(cx, buf)
|
|
}
|
|
}
|
|
|
|
impl<S: AsyncWrite + Unpin> AsyncWrite for PassthroughStream<S> {
|
|
fn poll_write(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &[u8],
|
|
) -> Poll<Result<usize>> {
|
|
Pin::new(&mut self.inner).poll_write(cx, buf)
|
|
}
|
|
|
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
|
Pin::new(&mut self.inner).poll_flush(cx)
|
|
}
|
|
|
|
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
|
Pin::new(&mut self.inner).poll_shutdown(cx)
|
|
}
|
|
} |