diff --git a/README.md b/README.md index deabde6..02c554e 100644 --- a/README.md +++ b/README.md @@ -32,3 +32,11 @@ See [examples][examples-url]. ### WebAssembly The `websock-wasm-demo` crate contains a small browser demo that connects to an echo server. + +## Benchmarking + +Criterion benchmarks are available for `websock-mux-proto`. + +```bash +cargo bench -p websock-mux-proto +``` diff --git a/devcert/generate.ps1 b/devcert/generate.ps1 new file mode 100644 index 0000000..4b0d289 --- /dev/null +++ b/devcert/generate.ps1 @@ -0,0 +1,76 @@ +Set-StrictMode -Version Latest +$ErrorActionPreference = "Stop" + +$scriptDir = Split-Path -Parent $MyInvocation.MyCommand.Path +Set-Location $scriptDir + +$certName = "localhost" +$days = 10 +$conf = "openssl.conf" + +Write-Host "==> Generating self-signed certificate ($certName)" +Write-Host " validity: $days days" +Write-Host " config: $conf" + +$openssl = if ($env:OPENSSL_BIN) { $env:OPENSSL_BIN } else { "openssl" } +if (-not (Get-Command $openssl -ErrorAction SilentlyContinue)) { + throw "openssl command not found. Set OPENSSL_BIN to openssl.exe path or add openssl to PATH." +} + +# Generate ECDSA P-256 private key +& $openssl ecparam ` + -genkey ` + -name prime256v1 ` + -out "$certName.key" + +if ($LASTEXITCODE -ne 0) { + throw "Failed to generate private key." +} + +# Generate self-signed certificate +& $openssl req ` + -x509 ` + -sha256 ` + -nodes ` + -days "$days" ` + -key "$certName.key" ` + -out "$certName.crt" ` + -config "$conf" ` + -extensions v3_req + +if ($LASTEXITCODE -ne 0) { + throw "Failed to generate certificate." +} + +# Generate raw SHA-256 hash (DER -> hex, no colons) +$derPath = "$certName.der" +& $openssl x509 ` + -in "$certName.crt" ` + -outform der ` + -out $derPath + +if ($LASTEXITCODE -ne 0) { + throw "Failed to export certificate in DER format." +} + +$hex = (Get-FileHash -Path $derPath -Algorithm SHA256).Hash.ToLowerInvariant() +Set-Content -Path "$certName.hex" -Value $hex -NoNewline +Remove-Item $derPath -Force + +# Also print human-readable fingerprint +& $openssl x509 ` + -in "$certName.crt" ` + -noout ` + -fingerprint ` + -sha256 ` + > "$certName.fingerprint" + +if ($LASTEXITCODE -ne 0) { + throw "Failed to generate fingerprint." +} + +Write-Host "==> Done" +Write-Host " - $certName.crt" +Write-Host " - $certName.key" +Write-Host " - $certName.hex (for serverCertificateHashes)" +Write-Host " - $certName.fingerprint" diff --git a/websock-mux-proto/Cargo.toml b/websock-mux-proto/Cargo.toml index cb83977..92b2cd8 100644 --- a/websock-mux-proto/Cargo.toml +++ b/websock-mux-proto/Cargo.toml @@ -17,3 +17,10 @@ websock-proto = { workspace = true } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] tokio = { version = "1", features = ["io-util"] } + +[dev-dependencies] +criterion = "0.8" + +[[bench]] +name = "proto_bench" +harness = false diff --git a/websock-mux-proto/benches/proto_bench.rs b/websock-mux-proto/benches/proto_bench.rs new file mode 100644 index 0000000..1039a2a --- /dev/null +++ b/websock-mux-proto/benches/proto_bench.rs @@ -0,0 +1,122 @@ +use bytes::{Bytes, BytesMut}; +use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main}; +use std::io::Cursor; +use std::hint::black_box; +use websock_mux_proto::{Frame, StreamDir, StreamId, VarInt}; + +fn bench_varint_encode(c: &mut Criterion) { + let values: [u64; 6] = [0, 63, 64, 16_383, 16_384, (1u64 << 62) - 1]; + let mut group = c.benchmark_group("varint_encode"); + + for &v in &values { + group.bench_with_input(BenchmarkId::from_parameter(v), &v, |b, &v| { + b.iter(|| { + let mut buf = BytesMut::with_capacity(8); + VarInt::from_u64(v) + .expect("valid varint value") + .encode(&mut buf); + black_box(buf); + }); + }); + } + + group.finish(); +} + +fn bench_varint_decode(c: &mut Criterion) { + let values: [u64; 6] = [0, 63, 64, 16_383, 16_384, (1u64 << 62) - 1]; + let mut group = c.benchmark_group("varint_decode"); + + for &v in &values { + let mut encoded = BytesMut::with_capacity(8); + VarInt::from_u64(v) + .expect("valid varint value") + .encode(&mut encoded); + let encoded = encoded.freeze(); + + group.bench_with_input(BenchmarkId::from_parameter(v), &encoded, |b, encoded| { + b.iter(|| { + let mut cur = Cursor::new(encoded.clone()); + black_box(VarInt::decode(&mut cur).expect("decode succeeds")); + }); + }); + } + + group.finish(); +} + +fn bench_frame_encode(c: &mut Criterion) { + let id = StreamId::new(7, false, StreamDir::Bi).expect("stream id"); + let small = Frame::Stream { + id, + data: Bytes::from_static(b"hello"), + fin: false, + }; + let large = Frame::Stream { + id, + data: Bytes::from(vec![7u8; 64 * 1024]), + fin: true, + }; + + let mut group = c.benchmark_group("frame_encode"); + group.throughput(Throughput::Bytes(small.encoded_len() as u64)); + group.bench_function("stream_small", |b| { + b.iter(|| { + black_box(small.encode()); + }); + }); + + group.throughput(Throughput::Bytes(large.encoded_len() as u64)); + group.bench_function("stream_large_64k", |b| { + b.iter(|| { + black_box(large.encode()); + }); + }); + group.finish(); +} + +fn bench_frame_decode(c: &mut Criterion) { + let id = StreamId::new(7, false, StreamDir::Bi).expect("stream id"); + let small = Frame::Stream { + id, + data: Bytes::from_static(b"hello"), + fin: false, + } + .encode() + .freeze(); + + let large = Frame::Stream { + id, + data: Bytes::from(vec![7u8; 64 * 1024]), + fin: true, + } + .encode() + .freeze(); + + let mut group = c.benchmark_group("frame_decode"); + group.throughput(Throughput::Bytes(small.len() as u64)); + group.bench_function("stream_small", |b| { + b.iter(|| { + let mut cur = Cursor::new(small.clone()); + black_box(Frame::decode(&mut cur).expect("decode succeeds")); + }); + }); + + group.throughput(Throughput::Bytes(large.len() as u64)); + group.bench_function("stream_large_64k", |b| { + b.iter(|| { + let mut cur = Cursor::new(large.clone()); + black_box(Frame::decode(&mut cur).expect("decode succeeds")); + }); + }); + group.finish(); +} + +criterion_group!( + benches, + bench_varint_encode, + bench_varint_decode, + bench_frame_encode, + bench_frame_decode, +); +criterion_main!(benches); diff --git a/websock-mux-proto/src/stream.rs b/websock-mux-proto/src/stream.rs index 74886c0..b53f555 100644 --- a/websock-mux-proto/src/stream.rs +++ b/websock-mux-proto/src/stream.rs @@ -70,8 +70,29 @@ pub enum Frame { } impl Frame { + /// Return the exact encoded frame length in bytes. + /// + /// This is primarily useful for pre-allocating encode buffers. + pub fn encoded_len(&self) -> usize { + match self { + Frame::OpenUni { id } | Frame::OpenBi { id } => 1 + VarInt(id.0).size(), + Frame::Stream { id, data, fin } => { + 1 + VarInt(id.0).size() + + VarInt(u64::from(*fin)).size() + + VarInt(data.len() as u64).size() + + data.len() + } + Frame::ResetStream { id, code } | Frame::StopSending { id, code } => { + 1 + VarInt(id.0).size() + VarInt(*code).size() + } + Frame::ConnectionClose { code, reason } => { + 1 + VarInt(*code).size() + VarInt(reason.len() as u64).size() + reason.len() + } + } + } + pub fn encode(&self) -> BytesMut { - let mut buf = BytesMut::new(); + let mut buf = BytesMut::with_capacity(self.encoded_len()); match self { Frame::OpenUni { id } => { VarInt(0).encode(&mut buf); @@ -124,13 +145,8 @@ impl Frame { if buf.remaining() < len { return Err(FrameDecodeError::UnexpectedEnd); } - let mut data = vec![0u8; len]; - buf.copy_to_slice(&mut data); - Ok(Frame::Stream { - id, - data: Bytes::from(data), - fin, - }) + let data = buf.copy_to_bytes(len); + Ok(Frame::Stream { id, data, fin }) } 3 => Ok(Frame::ResetStream { id: StreamId(VarInt::decode(buf)?.into_inner()), @@ -146,9 +162,10 @@ impl Frame { if buf.remaining() < len { return Err(FrameDecodeError::UnexpectedEnd); } - let mut data = vec![0u8; len]; - buf.copy_to_slice(&mut data); - let reason = String::from_utf8(data).map_err(|_| FrameDecodeError::InvalidUtf8)?; + let data = buf.copy_to_bytes(len); + let reason = std::str::from_utf8(data.as_ref()) + .map_err(|_| FrameDecodeError::InvalidUtf8)? + .to_owned(); Ok(Frame::ConnectionClose { code, reason }) } _ => Err(FrameDecodeError::UnknownTag(tag)), diff --git a/websock-mux-proto/src/varint.rs b/websock-mux-proto/src/varint.rs index cc508b7..39b763c 100644 --- a/websock-mux-proto/src/varint.rs +++ b/websock-mux-proto/src/varint.rs @@ -22,6 +22,10 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; pub struct VarInt(pub(crate) u64); impl VarInt { + const MAX_1BYTE: u64 = (1 << 6) - 1; + const MAX_2BYTE: u64 = (1 << 14) - 1; + const MAX_4BYTE: u64 = (1 << 30) - 1; + /// The largest representable value. pub const MAX: Self = Self((1 << 62) - 1); /// The largest encoded value length. @@ -34,7 +38,7 @@ impl VarInt { /// Succeeds if `x` < 2^62. pub fn from_u64(x: u64) -> Result { - if x < 2u64.pow(62) { + if x <= Self::MAX.0 { Ok(Self(x)) } else { Err(VarIntBoundsExceeded) @@ -58,13 +62,13 @@ impl VarInt { /// Compute the number of bytes needed to encode this value. pub fn size(self) -> usize { let x = self.0; - if x < 2u64.pow(6) { + if x <= Self::MAX_1BYTE { 1 - } else if x < 2u64.pow(14) { + } else if x <= Self::MAX_2BYTE { 2 - } else if x < 2u64.pow(30) { + } else if x <= Self::MAX_4BYTE { 4 - } else if x < 2u64.pow(62) { + } else if x <= Self::MAX.0 { 8 } else { unreachable!("malformed VarInt"); @@ -137,32 +141,38 @@ impl VarInt { if !r.has_remaining() { return Err(VarIntUnexpectedEnd); } - let mut buf = [0; 8]; - buf[0] = r.get_u8(); - let tag = buf[0] >> 6; - buf[0] &= 0b0011_1111; + let first = r.get_u8(); + let tag = first >> 6; + let body = (first & 0b0011_1111) as u64; let x = match tag { - 0b00 => u64::from(buf[0]), + 0b00 => body, 0b01 => { if r.remaining() < 1 { return Err(VarIntUnexpectedEnd); } - r.copy_to_slice(&mut buf[1..2]); - u64::from(u16::from_be_bytes(buf[..2].try_into().unwrap())) + (body << 8) | u64::from(r.get_u8()) } 0b10 => { if r.remaining() < 3 { return Err(VarIntUnexpectedEnd); } - r.copy_to_slice(&mut buf[1..4]); - u64::from(u32::from_be_bytes(buf[..4].try_into().unwrap())) + (body << 24) + | (u64::from(r.get_u8()) << 16) + | (u64::from(r.get_u8()) << 8) + | u64::from(r.get_u8()) } 0b11 => { if r.remaining() < 7 { return Err(VarIntUnexpectedEnd); } - r.copy_to_slice(&mut buf[1..8]); - u64::from_be_bytes(buf) + (body << 56) + | (u64::from(r.get_u8()) << 48) + | (u64::from(r.get_u8()) << 40) + | (u64::from(r.get_u8()) << 32) + | (u64::from(r.get_u8()) << 24) + | (u64::from(r.get_u8()) << 16) + | (u64::from(r.get_u8()) << 8) + | u64::from(r.get_u8()) } _ => unreachable!(), }; @@ -197,13 +207,13 @@ impl VarInt { pub fn encode(&self, w: &mut B) { let x = self.0; - if x < 2u64.pow(6) { + if x <= Self::MAX_1BYTE { w.put_u8(x as u8); - } else if x < 2u64.pow(14) { + } else if x <= Self::MAX_2BYTE { w.put_u16((0b01 << 14) | x as u16); - } else if x < 2u64.pow(30) { + } else if x <= Self::MAX_4BYTE { w.put_u32((0b10 << 30) | x as u32); - } else if x < 2u64.pow(62) { + } else if x <= Self::MAX.0 { w.put_u64((0b11 << 62) | x); } else { unreachable!("malformed VarInt") diff --git a/websock-mux-proto/tests/frame.rs b/websock-mux-proto/tests/frame.rs index 3b34056..3a6f1a5 100644 --- a/websock-mux-proto/tests/frame.rs +++ b/websock-mux-proto/tests/frame.rs @@ -56,6 +56,19 @@ fn frame_roundtrip_all_variants() { } } +#[test] +fn frame_encoded_len_matches_actual_size() { + let id = StreamId::new(42, false, StreamDir::Bi).unwrap(); + let frame = Frame::Stream { + id, + data: Bytes::from(vec![1u8; 4096]), + fin: true, + }; + + let encoded = frame.encode(); + assert_eq!(frame.encoded_len(), encoded.len()); +} + #[test] fn frame_decode_unknown_tag() { let mut buf = BytesMut::new(); diff --git a/websock-tungstenite-mux/src/server.rs b/websock-tungstenite-mux/src/server.rs index 744fc38..fdc66df 100644 --- a/websock-tungstenite-mux/src/server.rs +++ b/websock-tungstenite-mux/src/server.rs @@ -1,3 +1,4 @@ +use std::collections::HashSet; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::{TcpListener, ToSocketAddrs}; @@ -46,15 +47,15 @@ fn validate_protocols(opts: &websock_proto::ServerOptions) -> Result<()> { } /// Select the first requested subprotocol that appears in the allowed list. -fn select_protocol(req: &server::Request, allowed: &[String]) -> Option { +fn select_protocol<'a>(req: &'a server::Request, allowed: &HashSet) -> Option<&'a str> { if allowed.is_empty() { return None; } let header = req.headers().get(SEC_WEBSOCKET_PROTOCOL)?; let header = header.to_str().ok()?; for candidate in header.split(',').map(|s| s.trim()) { - if allowed.iter().any(|p| p == candidate) { - return Some(candidate.to_string()); + if allowed.contains(candidate) { + return Some(candidate); } } None @@ -82,12 +83,13 @@ where )); } + let allowed: HashSet = opts.protocols.into_iter().collect(); let acceptor = tls.map(|cfg| TlsAcceptor::from(Arc::new(cfg))); Ok(Server { listener, - opts, - headers, + allowed: Arc::new(allowed), + headers: Arc::new(headers), acceptor, limits, }) @@ -95,8 +97,8 @@ where pub struct Server { listener: TcpListener, - opts: websock_proto::ServerOptions, - headers: Vec<(HeaderName, HeaderValue)>, + allowed: Arc>, + headers: Arc>, acceptor: Option, limits: Limits, } @@ -119,19 +121,19 @@ impl Server { (Box::new(stream), false) }; - let headers = self.headers.clone(); - let allowed = self.opts.protocols.clone(); + let headers = Arc::clone(&self.headers); + let allowed = Arc::clone(&self.allowed); let ws = accept_hdr_async( stream, move |req: &server::Request, mut resp: server::Response| { // Additional headers from configuration - for (name, value) in &headers { + for (name, value) in headers.iter() { resp.headers_mut().append(name, value.clone()); } // websock-mux is required protocol - let Some(protocol) = select_protocol(req, &allowed) else { + let Some(protocol) = select_protocol(req, allowed.as_ref()) else { return Err(http::Response::builder() .status(http::StatusCode::BAD_REQUEST) .body(Some(format!("'{SUBPROTOCOL}' protocol required"))) @@ -148,7 +150,7 @@ impl Server { resp.headers_mut().insert( http::header::SEC_WEBSOCKET_PROTOCOL, - http::HeaderValue::from_str(&protocol).expect("validated"), + http::HeaderValue::from_str(protocol).expect("validated"), ); Ok(resp) diff --git a/websock-tungstenite-mux/src/session.rs b/websock-tungstenite-mux/src/session.rs index f4b5bdf..a9815d1 100644 --- a/websock-tungstenite-mux/src/session.rs +++ b/websock-tungstenite-mux/src/session.rs @@ -226,14 +226,24 @@ impl RecvStream { } pub async fn read_buf(&mut self, buf: &mut B) -> Result> { - let mut temp = vec![0u8; 4096]; - let size_opt = self.read(&mut temp).await?; - if let Some(size) = size_opt { - buf.put_slice(&temp[..size]); - Ok(Some(size)) - } else { - Ok(None) + if self.finished { + return Ok(None); + } + if buf.remaining_mut() == 0 { + return Ok(Some(0)); + } + if self.pending.is_empty() { + if let Some(chunk) = self.read_chunk_internal().await? { + self.pending = chunk; + } else { + return Ok(None); + } } + + let amt = self.pending.len().min(buf.remaining_mut()); + buf.put_slice(&self.pending[..amt]); + self.pending = self.pending.slice(amt..); + Ok(Some(amt)) } pub async fn read_chunk_internal(&mut self) -> Result> { @@ -361,7 +371,7 @@ impl SessionInner { Ok(f) => f, Err(_) => break, }; - if inbound.clone().handle_frame(frame).await.is_err() { + if inbound.handle_frame(frame).await.is_err() { break; } } @@ -387,7 +397,7 @@ impl SessionInner { }); } - pub(crate) async fn handle_frame(self: Arc, frame: Frame) -> Result<()> { + pub(crate) async fn handle_frame(self: &Arc, frame: Frame) -> Result<()> { match frame { Frame::OpenUni { id } => { if id.dir() != StreamDir::Uni { @@ -399,24 +409,10 @@ impl SessionInner { return self.protocol_error(1, "OpenUni with wrong initiator").await; } - { - let streams = self.streams.lock().await; - if streams.contains_key(&id) { - return self.protocol_error(1, "duplicate stream open").await; - } - } - - { - let streams = self.streams.lock().await; - if streams.len() >= self.limits.max_open_streams { - return self.protocol_error(3, "too many open streams").await; - } - if streams.contains_key(&id) { - return self.protocol_error(1, "duplicate stream open").await; - } - } - - let recv = self.register_recv_stream(id).await; + let recv = match self.try_register_inbound_recv_stream(id).await { + Ok(recv) => recv, + Err((code, reason)) => return self.protocol_error(code, reason).await, + }; let tx = { self.accept_uni_tx.lock().await.clone() }; if let Some(tx) = tx { @@ -444,24 +440,10 @@ impl SessionInner { return self.protocol_error(1, "OpenBi with wrong initiator").await; } - { - let streams = self.streams.lock().await; - if streams.contains_key(&id) { - return self.protocol_error(1, "duplicate stream open").await; - } - } - - { - let streams = self.streams.lock().await; - if streams.len() >= self.limits.max_open_streams { - return self.protocol_error(3, "too many open streams").await; - } - if streams.contains_key(&id) { - return self.protocol_error(1, "duplicate stream open").await; - } - } - - let recv = self.register_recv_stream(id).await; + let recv = match self.try_register_inbound_recv_stream(id).await { + Ok(recv) => recv, + Err((code, reason)) => return self.protocol_error(code, reason).await, + }; let send = SendStream::new(id, self.clone()); let tx = { self.accept_bi_tx.lock().await.clone() }; @@ -525,6 +507,22 @@ impl SessionInner { Ok(()) } + async fn try_register_inbound_recv_stream( + self: &Arc, + id: StreamId, + ) -> std::result::Result { + let (tx, rx) = mpsc::channel(self.limits.recv_event_queue_len); + let mut streams = self.streams.lock().await; + if streams.len() >= self.limits.max_open_streams { + return Err((3, "too many open streams")); + } + if streams.contains_key(&id) { + return Err((1, "duplicate stream open")); + } + streams.insert(id, tx); + Ok(RecvStream::new(id, self.clone(), rx)) + } + pub(crate) async fn register_recv_stream(self: &Arc, id: StreamId) -> RecvStream { let (tx, rx) = mpsc::channel(self.limits.recv_event_queue_len); let mut streams = self.streams.lock().await; diff --git a/websock-tungstenite/src/server.rs b/websock-tungstenite/src/server.rs index 36da7b3..0d04dbe 100644 --- a/websock-tungstenite/src/server.rs +++ b/websock-tungstenite/src/server.rs @@ -1,5 +1,6 @@ //! Server-side WebSocket acceptor for the Tokio Tungstenite transport. +use std::collections::HashSet; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::{TcpListener, ToSocketAddrs}; @@ -31,8 +32,8 @@ where Ok(Server { listener, - opts, - headers, + protocols: Arc::new(opts.protocols.into_iter().collect()), + headers: Arc::new(headers), acceptor, }) } @@ -48,8 +49,8 @@ pub type ServerStream = Box; /// WebSocket server listener. pub struct Server { listener: TcpListener, - opts: ServerOptions, - headers: Vec<(HeaderName, HeaderValue)>, + protocols: Arc>, + headers: Arc>, acceptor: Option, } @@ -81,19 +82,19 @@ impl Server { is_tls, }; - let headers = self.headers.to_vec(); - let protocols = self.opts.protocols.clone(); + let headers = Arc::clone(&self.headers); + let protocols = Arc::clone(&self.protocols); let ws = tokio_tungstenite::accept_hdr_async( stream, move |req: &Request, mut resp: Response| { - for (name, value) in &headers { + for (name, value) in headers.iter() { resp.headers_mut().append(name, value.clone()); } - if let Some(protocol) = select_protocol(req, &protocols) { + if let Some(protocol) = select_protocol(req, protocols.as_ref()) { let value = - HeaderValue::from_str(&protocol).expect("protocol value validated on bind"); + HeaderValue::from_str(protocol).expect("protocol value validated on bind"); resp.headers_mut().insert(SEC_WEBSOCKET_PROTOCOL, value); } @@ -124,8 +125,8 @@ impl Server { .await .map_err(|e| Error::Tls(e.to_string()))?; - let headers = self.headers.clone(); - let protocols = self.opts.protocols.clone(); + let headers = Arc::clone(&self.headers); + let protocols = Arc::clone(&self.protocols); let info = ConnectionInfo { peer: tls_stream @@ -144,13 +145,13 @@ impl Server { let ws = tokio_tungstenite::accept_hdr_async( tls_stream, move |req: &Request, mut resp: Response| { - for (name, value) in &headers { + for (name, value) in headers.iter() { resp.headers_mut().append(name, value.clone()); } - if let Some(protocol) = select_protocol(req, &protocols) { + if let Some(protocol) = select_protocol(req, protocols.as_ref()) { let value = - HeaderValue::from_str(&protocol).expect("protocol value validated on bind"); + HeaderValue::from_str(protocol).expect("protocol value validated on bind"); resp.headers_mut().insert(SEC_WEBSOCKET_PROTOCOL, value); } @@ -194,15 +195,15 @@ fn validate_protocols(opts: &ServerOptions) -> Result<()> { } /// Select the first requested subprotocol that appears in the allowed list. -fn select_protocol(req: &Request, allowed: &[String]) -> Option { +fn select_protocol<'a>(req: &'a Request, allowed: &HashSet) -> Option<&'a str> { if allowed.is_empty() { return None; } let header = req.headers().get(SEC_WEBSOCKET_PROTOCOL)?; let header = header.to_str().ok()?; for candidate in header.split(',').map(|s| s.trim()) { - if allowed.iter().any(|p| p == candidate) { - return Some(candidate.to_string()); + if allowed.contains(candidate) { + return Some(candidate); } } None diff --git a/websock-wasm-mux/src/session.rs b/websock-wasm-mux/src/session.rs index a14b469..d3bbee4 100644 --- a/websock-wasm-mux/src/session.rs +++ b/websock-wasm-mux/src/session.rs @@ -213,14 +213,24 @@ impl RecvStream { } pub async fn read_buf(&mut self, buf: &mut B) -> Result> { - let mut temp = vec![0u8; 4096]; - let size_opt = self.read(&mut temp).await?; - if let Some(size) = size_opt { - buf.put_slice(&temp[..size]); - Ok(Some(size)) - } else { - Ok(None) + if self.finished { + return Ok(None); } + if buf.remaining_mut() == 0 { + return Ok(Some(0)); + } + if self.pending.is_empty() { + if let Some(chunk) = self.read_chunk_internal().await? { + self.pending = chunk; + } else { + return Ok(None); + } + } + + let amt = self.pending.len().min(buf.remaining_mut()); + buf.put_slice(&self.pending[..amt]); + self.pending = self.pending.slice(amt..); + Ok(Some(amt)) } async fn read_chunk_internal(&mut self) -> Result> { @@ -340,7 +350,7 @@ impl SessionInner { break; } }; - if inner.clone().handle_frame(frame).await.is_err() { + if inner.handle_frame(frame).await.is_err() { break; } } @@ -379,7 +389,7 @@ impl SessionInner { .map_err(|e| Error::Protocol(format!("stream id overflow: {}", e))) } - async fn handle_frame(self: Rc, frame: Frame) -> Result<()> { + async fn handle_frame(self: &Rc, frame: Frame) -> Result<()> { match frame { Frame::OpenUni { id } => { if id.dir() != StreamDir::Uni {