diff --git a/Cargo.toml b/Cargo.toml index 7cbf1b2..3fff453 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,9 +1,11 @@ [package] name = "quad-net" -version = "0.1.1" +version = "0.2.0" authors = ["Fedor Logachev "] edition = "2018" license = "MIT/Apache-2.0" +homepage = "https://github.com/not-fl3/quad-net" +repository = "https://github.com/not-fl3/quad-net" description = "Miniquad friendly network abstractions" [features] diff --git a/examples/client.rs b/examples/client.rs index 3647f97..caec937 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -36,7 +36,7 @@ async fn main() { if is_mouse_button_down(MouseButton::Left) { let (x, y) = mouse_position(); - socket.send_bin(&(x, y)); + socket.send_bin(&(x, y)).unwrap(); } next_frame().await } diff --git a/js/quad-net.js b/js/quad-net.js index ed494a3..1e92fca 100644 --- a/js/quad-net.js +++ b/js/quad-net.js @@ -11,7 +11,7 @@ register_plugin = function (importObject) { importObject.env.http_try_recv = http_try_recv; } -miniquad_add_plugin({ register_plugin, on_init, version: "0.1.1", name: "quad_net" }); +miniquad_add_plugin({ register_plugin, on_init, version: "0.2.0", name: "quad_net" }); var quad_socket; var connected = 0; diff --git a/src/quad_socket/client.rs b/src/quad_socket/client.rs index 44902e8..4ad7446 100644 --- a/src/quad_socket/client.rs +++ b/src/quad_socket/client.rs @@ -15,16 +15,18 @@ pub struct QuadSocket { } impl QuadSocket { - pub fn send(&mut self, data: &[u8]) { + pub fn send(&mut self, data: &[u8]) -> Result<(), Error> { #[cfg(not(target_arch = "wasm32"))] { - self.tcp_socket.send(data); + self.tcp_socket.send(data)?; } #[cfg(target_arch = "wasm32")] { self.web_socket.send_bytes(data); } + + Ok(()) } pub fn try_recv(&mut self) -> Option> { @@ -42,10 +44,10 @@ impl QuadSocket { #[cfg(feature = "nanoserde")] impl QuadSocket { - pub fn send_bin(&mut self, data: &T) { + pub fn send_bin(&mut self, data: &T) -> Result<(), Error> { use nanoserde::SerBin; - self.send(&SerBin::serialize_bin(data)); + self.send(&SerBin::serialize_bin(data)) } pub fn try_recv_bin(&mut self) -> Option { diff --git a/src/quad_socket/client/tcp.rs b/src/quad_socket/client/tcp.rs index 7e6c8de..778570f 100644 --- a/src/quad_socket/client/tcp.rs +++ b/src/quad_socket/client/tcp.rs @@ -1,7 +1,7 @@ use std::net::ToSocketAddrs; use std::net::TcpStream; -use std::sync::mpsc::{self, Receiver}; +use std::sync::mpsc::{self, Receiver, SendError}; use crate::{error::Error, quad_socket::protocol::MessageReader}; @@ -11,11 +11,11 @@ pub struct TcpSocket { } impl TcpSocket { - pub fn send(&mut self, data: &[u8]) { - use std::io::Write; + pub fn send(&mut self, data: &[u8]) -> Result<(), Error> { + write_until_done(&mut self.stream, &u32::to_be_bytes(data.len() as u32))?; + write_until_done(&mut self.stream, data)?; - self.stream.write(&[data.len() as u8]).unwrap(); - self.stream.write(data).unwrap(); + Ok(()) } pub fn try_recv(&mut self) -> Option> { @@ -23,10 +23,23 @@ impl TcpSocket { } } +fn write_until_done(stream: &mut TcpStream, message: &[u8]) -> Result<(), Error> { + use std::io::Write; + let mut sent = 0; + + while sent < message.len() { + sent += stream.write(&message[sent..]) + .map_err(Error::IOError)?; + } + + Ok(()) +} + impl TcpSocket { pub fn connect(addr: A) -> Result { let stream = TcpStream::connect(addr)?; stream.set_nodelay(true).unwrap(); + stream.set_nonblocking(true).unwrap(); let (tx, rx) = mpsc::channel(); @@ -35,8 +48,15 @@ impl TcpSocket { move || { let mut messages = MessageReader::new(); loop { - if let Ok(Some(message)) = messages.next(&mut stream) { - tx.send(message).unwrap(); + match messages.next(&mut stream) { + Ok(Some(message)) => { + match tx.send(message) { + Ok(()) => (), + Err(SendError(_message)) => break, + } + } + Ok(None) => { std::thread::yield_now() }, + Err(()) => break, } } } diff --git a/src/quad_socket/protocol.rs b/src/quad_socket/protocol.rs index cf59531..2583be9 100644 --- a/src/quad_socket/protocol.rs +++ b/src/quad_socket/protocol.rs @@ -1,37 +1,50 @@ use std::io::ErrorKind; #[derive(Debug)] -pub enum MessageReader { - Empty, - Amount(usize), +pub(crate) struct MessageReader { + buffer: Vec, } impl MessageReader { pub fn new() -> MessageReader { - MessageReader::Empty + MessageReader { + buffer: Vec::new() + } } pub fn next(&mut self, mut stream: impl std::io::Read) -> Result>, ()> { - let mut bytes = [0 as u8; 255]; - - match self { - MessageReader::Empty => match stream.read_exact(&mut bytes[0..1]) { - Ok(_) => { - *self = MessageReader::Amount(bytes[0] as usize); - Ok(None) - } - Err(err) if err.kind() == ErrorKind::WouldBlock => Ok(None), - Err(_err) => Err(()), - }, - MessageReader::Amount(len) => match stream.read_exact(&mut bytes[0..*len]) { - Ok(_) => { - let msg = bytes[0..*len].to_vec(); - *self = MessageReader::Empty; - Ok(Some(msg)) - } - Err(err) if err.kind() == ErrorKind::WouldBlock => Ok(None), - Err(_) => Err(()), - }, + let mut bytes = [0; 16 * 1024]; + + let bytes_read = match stream.read(&mut bytes) { + Ok(0) => return Err(()), // Disconnected + Ok(bytes_read) => bytes_read, + Err(err) if err.kind() == ErrorKind::WouldBlock => { + // No bytes received; still, check our buffer in case there's + // more stored messages in it from previous packets + 0 + } + Err(_err) => return Err(()), + }; + + // Read the first 4 bytes, which encode the message's length + self.buffer.extend_from_slice(&bytes[..bytes_read]); + + if self.buffer.len() < 4 { + return Ok(None); } + + use std::convert::TryInto; + let four_bytes: [u8; 4] = self.buffer[0..4].try_into().unwrap(); + let message_size = u32::from_be_bytes(four_bytes) as usize; + + // Keep receiving until the whole message is here + if self.buffer.len() < 4 + message_size { + return Ok(None); + } + + let message = self.buffer[4..4+message_size].to_vec(); + self.buffer.drain(..4+message_size); + + Ok(Some(message)) } } diff --git a/src/quad_socket/server.rs b/src/quad_socket/server.rs index b276847..77484d2 100644 --- a/src/quad_socket/server.rs +++ b/src/quad_socket/server.rs @@ -1,3 +1,4 @@ +use std::io::ErrorKind; use std::net::ToSocketAddrs; use std::net::{TcpListener, TcpStream}; use std::time::{Duration, Instant}; @@ -33,15 +34,13 @@ pub struct SocketHandle<'a> { impl<'a> Sender<'a> { fn send(&mut self, data: &[u8]) -> Option<()> { - use std::io::Write; - match self { Sender::WebSocket(out) => { out.send(data).ok()?; } Sender::Tcp(stream) => { - stream.write(&[data.len() as u8]).ok()?; - stream.write(data).ok()?; + write_until_done(stream, &u32::to_be_bytes(data.len() as u32)); + write_until_done(stream, data); } } @@ -49,6 +48,25 @@ impl<'a> Sender<'a> { } } +fn write_until_done(stream: &mut TcpStream, message: &[u8]) -> Option<()> { + use std::io::Write; + + let mut sent = 0; + + while sent < message.len() { + sent += match stream.write(&message[sent..]) { + Ok(bytes) => bytes, + Err(e) if e.kind() == ErrorKind::WouldBlock => { + std::thread::yield_now(); + 0 + }, + Err(_e) => return None, + }; + } + + Some(()) +} + impl<'a> SocketHandle<'a> { fn new(sender: Sender<'a>) -> SocketHandle<'a> { SocketHandle { @@ -199,8 +217,8 @@ where return; } } - Ok(None) => {} - Err(_err) => { + Ok(None) => std::thread::yield_now(), + Err(()) => { (on_disconnect.lock().unwrap())(&state); return; }