diff --git a/Cargo.toml b/Cargo.toml index 9e0b871..d010e86 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,6 +41,14 @@ path = "examples/simple_server.rs" name = "headless_server" path = "examples/headless_server.rs" +[[example]] +name = "generic_stream" +path = "examples/generic_stream.rs" + +[[example]] +name = "from_socket_demo" +path = "examples/from_socket_demo.rs" + [profile.release] lto = true # Link-time optimization codegen-units = 1 # Better optimization diff --git a/examples/from_socket_demo.rs b/examples/from_socket_demo.rs new file mode 100644 index 0000000..f312f54 --- /dev/null +++ b/examples/from_socket_demo.rs @@ -0,0 +1,132 @@ +// Copyright 2025 Dustin McAfee +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Simple demonstration of using `from_socket` to accept VNC connections +//! from any stream that implements `AsyncRead + AsyncWrite + Unpin + Send + Sync`. +//! +//! This example shows how to: +//! 1. Create a VNC server +//! 2. Accept TCP connections using `from_socket` +//! 3. Handle different types of streams +//! +//! Usage: +//! cargo run --example from_socket_demo +//! +//! Then connect with a VNC viewer to localhost:5900 + +use rustvncserver::VncServer; +use std::error::Error; +use std::sync::Arc; +use tokio::net::TcpListener; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // Initialize logging + env_logger::init(); + + println!("VNC Server with from_socket() demonstration"); + println!("==========================================="); + + // Create VNC server + let (server, mut events) = VncServer::new( + 800, + 600, + "from_socket Demo".to_string(), + None, // No password + ); + let server = Arc::new(server); + + // Handle server events in background + let _server_for_events = Arc::clone(&server); + tokio::spawn(async move { + while let Some(event) = events.recv().await { + match event { + rustvncserver::server::ServerEvent::ClientConnected { client_id } => { + println!("[Event] Client {} connected", client_id); + } + rustvncserver::server::ServerEvent::ClientDisconnected { client_id } => { + println!("[Event] Client {} disconnected", client_id); + } + rustvncserver::server::ServerEvent::KeyPress { client_id, down, key } => { + let action = if down { "pressed" } else { "released" }; + println!("[Event] Client {} key {} {}", client_id, key, action); + } + rustvncserver::server::ServerEvent::PointerMove { client_id, x, y, button_mask } => { + println!("[Event] Client {} pointer at ({}, {}) buttons: {:08b}", + client_id, x, y, button_mask); + } + rustvncserver::server::ServerEvent::CutText { client_id, text } => { + let preview = if text.len() > 20 { + format!("{}...", &text[..20]) + } else { + text.clone() + }; + println!("[Event] Client {} sent clipboard: {}", client_id, preview); + } + } + } + }); + + // Create a simple test pattern + let mut pixels = vec![0u8; 800 * 600 * 4]; + for y in 0..600 { + for x in 0..800 { + let offset = (y * 800 + x) * 4; + pixels[offset] = (x * 255 / 800) as u8; // Red gradient + pixels[offset + 1] = (y * 255 / 600) as u8; // Green gradient + pixels[offset + 2] = 128; // Blue constant + pixels[offset + 3] = 255; // Alpha opaque + } + } + + // Update framebuffer + server + .framebuffer() + .update_cropped(&pixels, 0, 0, 800, 600) + .await + .expect("Failed to update framebuffer"); + + println!("Framebuffer initialized with test pattern"); + + // Create TCP listener + let listener = TcpListener::bind("127.0.0.1:5900").await?; + println!("TCP listener ready on port 5900"); + println!("Connect with: vncviewer localhost:5900"); + println!("Waiting for connections..."); + + // Accept connections and handle them using from_socket + loop { + match listener.accept().await { + Ok((stream, addr)) => { + println!("New connection from {}", addr); + + // Use from_socket to handle the connection + let server_clone = Arc::clone(&server); + tokio::spawn(async move { + match server_clone.from_socket(stream, None).await { + Ok(()) => { + println!("Connection from {} handled successfully", addr); + } + Err(e) => { + eprintln!("Failed to handle connection from {}: {}", addr, e); + } + } + }); + } + Err(e) => { + eprintln!("Error accepting connection: {}", e); + } + } + } +} diff --git a/examples/generic_stream.rs b/examples/generic_stream.rs new file mode 100644 index 0000000..fc0a520 --- /dev/null +++ b/examples/generic_stream.rs @@ -0,0 +1,251 @@ +// Copyright 2025 Dustin McAfee +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Generic stream VNC server example. +//! +//! This example demonstrates how to use the VNC server with different types of streams +//! that implement `AsyncRead + AsyncWrite + Unpin + Send`, such as: +//! - TCP streams (standard VNC) +//! - UDP streams with reliability layer +//! - WebSocket connections +//! - Custom transport protocols +//! +//! Usage: +//! cargo run --example generic_stream +//! +//! This example creates a simple TCP listener and accepts connections using `from_socket`. + +use rustvncserver::VncServer; +use std::error::Error; +use std::sync::Arc; +use tokio::net::TcpListener; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // Initialize logging + env_logger::init(); + + println!("Starting generic stream VNC server example..."); + println!("This example demonstrates using from_socket() with different stream types"); + + // Create VNC server + let (server, mut events) = VncServer::new( + 800, + 600, + "Generic Stream VNC".to_string(), + None, // No password + ); + let server = Arc::new(server); + + // Handle server events in background + tokio::spawn(async move { + while let Some(event) = events.recv().await { + match event { + rustvncserver::server::ServerEvent::ClientConnected { client_id } => { + println!("Client {} connected via generic stream", client_id); + } + rustvncserver::server::ServerEvent::ClientDisconnected { client_id } => { + println!("Client {} disconnected", client_id); + } + rustvncserver::server::ServerEvent::KeyPress { client_id, down, key } => { + let action = if down { "pressed" } else { "released" }; + println!("Client {} key {} {}", client_id, key, action); + } + rustvncserver::server::ServerEvent::PointerMove { client_id, x, y, button_mask } => { + println!("Client {} pointer moved to ({}, {}) buttons: {:08b}", + client_id, x, y, button_mask); + } + rustvncserver::server::ServerEvent::CutText { client_id, text } => { + println!("Client {} sent cut text: {}...", + client_id, text.chars().take(20).collect::()); + } + } + } + }); + + // Create a test pattern + let mut pixels = vec![0u8; 800 * 600 * 4]; + for y in 0..600 { + for x in 0..800 { + let offset = (y * 800 + x) * 4; + pixels[offset] = (x * 255 / 800) as u8; // R gradient + pixels[offset + 1] = (y * 255 / 600) as u8; // G gradient + pixels[offset + 2] = 128; // B constant + pixels[offset + 3] = 255; // A opaque + } + } + + // Update framebuffer + server + .framebuffer() + .update_cropped(&pixels, 0, 0, 800, 600) + .await + .expect("Failed to update framebuffer"); + + println!("Framebuffer initialized with test pattern"); + + // Example 1: Standard TCP listener using from_socket + println!("\nExample 1: Standard TCP listener on port 5901"); + let tcp_listener = TcpListener::bind("127.0.0.1:5901").await?; + println!("TCP listener ready on port 5901"); + + let server_clone = Arc::clone(&server); + tokio::spawn(async move { + loop { + match tcp_listener.accept().await { + Ok((stream, addr)) => { + println!("Accepted TCP connection from {}", addr); + + // Use from_socket to handle the TCP stream + if let Err(e) = server_clone.from_socket(stream, None).await { + eprintln!("Failed to handle TCP connection: {}", e); + } + } + Err(e) => { + eprintln!("Error accepting TCP connection: {}", e); + } + } + } + }); + + // Example 2: Custom stream wrapper demonstration + println!("\nExample 2: Custom stream wrapper"); + println!("This shows how you could wrap different transport protocols"); + + // Create a simple TCP server on another port to demonstrate + let server_clone2 = Arc::clone(&server); + tokio::spawn(async move { + let listener = match TcpListener::bind("127.0.0.1:5902").await { + Ok(l) => l, + Err(e) => { + eprintln!("Failed to bind port 5902: {}", e); + return; + } + }; + + println!("Custom stream server ready on port 5902"); + + while let Ok((stream, addr)) = listener.accept().await { + println!("Custom stream connection from {}", addr); + + // Example: You could wrap the stream here with custom logic + // For example, add compression, encryption, or protocol translation + let wrapped_stream = ExampleStreamWrapper::new(stream); + + if let Err(e) = server_clone2.from_socket(wrapped_stream, None).await { + eprintln!("Failed to handle wrapped stream: {}", e); + } + } + }); + + println!("\nServers are running:"); + println!("- Standard VNC on port 5900 (using server.listen())"); + println!("- Generic stream TCP on port 5901 (using from_socket())"); + println!("- Custom wrapped stream on port 5902"); + println!("\nConnect with:"); + println!(" vncviewer localhost:5900"); + println!(" vncviewer localhost:5901"); + println!(" vncviewer localhost:5902"); + println!("\nPress Ctrl+C to stop"); + + // Also start the standard listen method for comparison + let server_ref = Arc::clone(&server); + tokio::spawn(async move { + if let Err(e) = server_ref.listen(5900).await { + eprintln!("Server error: {}", e); + } + }); + + println!("Servers are running. Press Ctrl+C to stop."); + println!("Waiting for connections..."); + + // Keep main thread alive by waiting for a long time + tokio::time::sleep(tokio::time::Duration::from_secs(3600)).await; + + Ok(()) +} + +/// Example stream wrapper that demonstrates how to implement custom transport layers. +/// +/// This struct wraps any stream that implements `AsyncRead + AsyncWrite + Unpin` +/// and adds custom behavior (in this case, just logging). +struct ExampleStreamWrapper { + inner: S, + bytes_transferred: usize, +} + +impl ExampleStreamWrapper { + fn new(stream: S) -> Self { + Self { + inner: stream, + bytes_transferred: 0, + } + } +} + +impl tokio::io::AsyncRead for ExampleStreamWrapper { + fn poll_read( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + let before = buf.filled().len(); + let result = std::pin::Pin::new(&mut self.inner).poll_read(cx, buf); + let after = buf.filled().len(); + + if after > before { + self.bytes_transferred += after - before; + println!("Read {} bytes (total: {})", after - before, self.bytes_transferred); + } + + result + } +} + +impl tokio::io::AsyncWrite for ExampleStreamWrapper { + fn poll_write( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + let result = std::pin::Pin::new(&mut self.inner).poll_write(cx, buf); + + if let std::task::Poll::Ready(Ok(bytes_written)) = &result { + self.bytes_transferred += bytes_written; + println!("Wrote {} bytes (total: {})", bytes_written, self.bytes_transferred); + } + + result + } + + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::pin::Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::pin::Pin::new(&mut self.inner).poll_shutdown(cx) + } +} + +// Implement Unpin since S is Unpin +impl Unpin for ExampleStreamWrapper {} + +// Implement Send since S is Send +unsafe impl Send for ExampleStreamWrapper {} diff --git a/src/client.rs b/src/client.rs index b28f37d..f1d692e 100644 --- a/src/client.rs +++ b/src/client.rs @@ -43,8 +43,7 @@ use log::info; use std::sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering}; use std::sync::Arc; use std::time::{Duration, Instant}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::TcpStream; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc; use tokio::sync::RwLock; @@ -211,10 +210,10 @@ impl TightStreamCompressor for TightZlibStreams { /// processing incoming client messages (e.g., key events, pointer events, pixel format requests), /// and managing client-specific settings like preferred encodings and JPEG quality. pub struct VncClient { - /// The read half of the TCP stream for receiving client messages. - read_stream: tokio::net::tcp::OwnedReadHalf, - /// The write half of the TCP stream for sending updates to the client. - write_stream: Arc>, + /// The read half of the stream for receiving client messages. + read_stream: Box, + /// The write half of the stream for sending updates to the client. + write_stream: Arc>>, /// A reference to the framebuffer, used to retrieve pixel data for updates. framebuffer: Framebuffer, /// The pixel format requested by the client, protected by a `RwLock` for concurrent access. @@ -275,8 +274,8 @@ pub struct VncClient { /// Persistent zlib compression streams for Tight encoding (4 streams with dictionaries). /// Protected by `RwLock` since encoding happens during `send_batched_update`. tight_zlib_streams: RwLock, - /// Remote host address (IP:port) of the connected client - remote_host: String, + /// Remote host address (IP:port) of the connected client (None for generic streams) + remote_host: Option, /// Destination port for repeater connections (None for direct connections) destination_port: Option, /// Repeater ID for repeater connections (None for direct connections) @@ -294,7 +293,7 @@ impl VncClient { /// # Arguments /// /// * `client_id` - The unique client ID assigned by the server. - /// * `stream` - The `TcpStream` representing the established connection to the VNC client. + /// * `stream` - A stream implementing `AsyncRead + AsyncWrite + Unpin + Send` representing the connection to the VNC client. /// * `framebuffer` - The `Framebuffer` instance that this client will receive updates from. /// * `desktop_name` - The name of the desktop to be sent to the client during `ServerInit`. /// * `password` - An optional password for VNC authentication. If `Some`, VNC authentication @@ -306,21 +305,22 @@ impl VncClient { /// /// A `Result` which is `Ok(VncClient)` on successful handshake and initialization, or /// `Err(std::io::Error)` if an I/O error occurs during communication or handshake. - pub async fn new( + pub async fn new( client_id: usize, - mut stream: TcpStream, + mut stream: S, framebuffer: Framebuffer, desktop_name: String, password: Option, event_tx: mpsc::UnboundedSender, - ) -> Result { + ) -> Result + where + S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, + { // Capture remote host address before handshake - let remote_host = stream - .peer_addr() - .map_or_else(|_| "unknown".to_string(), |addr| addr.to_string()); + let remote_host = None; // Generic streams may not have peer_addr // Disable Nagle's algorithm for immediate frame delivery - stream.set_nodelay(true)?; + //stream.set_nodelay(true)?; // Send protocol version stream.write_all(PROTOCOL_VERSION.as_bytes()).await?; @@ -389,13 +389,13 @@ impl VncClient { log::info!("VNC client handshake completed"); // Split stream into read/write halves for lock-free shutdown - let (read_stream, write_stream) = stream.into_split(); + let (read_stream, write_stream) = tokio::io::split(stream); let creation_time = Instant::now(); Ok(Self { - read_stream, - write_stream: Arc::new(tokio::sync::Mutex::new(write_stream)), + read_stream: Box::new(read_stream), + write_stream: Arc::new(tokio::sync::Mutex::new(Box::new(write_stream))), framebuffer, pixel_format: RwLock::new(PixelFormat::rgba32()), encodings: RwLock::new(vec![ENCODING_RAW]), @@ -484,7 +484,7 @@ impl VncClient { /// Enters the main message loop for the `VncClient`, handling incoming data from the client /// and periodically sending framebuffer updates. /// - /// This function continuously reads from the client's `TcpStream` and processes VNC messages + /// This function continuously reads from the client's stream and processes VNC messages /// such as `SetPixelFormat`, `SetEncodings`, `FramebufferUpdateRequest`, `KeyEvent`, /// `PointerEvent`, and `ClientCutText`. It also uses a `tokio::time::interval` to /// periodically check if batched framebuffer updates should be sent to the client, @@ -1702,13 +1702,13 @@ impl VncClient { /// which will cause reads on the read half to fail naturally. pub fn get_write_stream_handle( &self, - ) -> Arc> { + ) -> Arc>> { self.write_stream.clone() } /// Returns the remote host address of the connected client. pub fn get_remote_host(&self) -> &str { - &self.remote_host + self.remote_host.as_deref().unwrap_or("unknown") } /// Returns the destination port for repeater connections. diff --git a/src/server.rs b/src/server.rs index c64f460..1f14db7 100644 --- a/src/server.rs +++ b/src/server.rs @@ -33,6 +33,7 @@ use log::error; use log::info; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::{mpsc, RwLock}; @@ -49,6 +50,7 @@ static NEXT_CLIENT_ID: AtomicU64 = AtomicU64::new(1); /// Represents a VNC server instance. /// /// This struct manages the VNC framebuffer, connected clients, and handles server-wide events. +#[derive(Clone)] pub struct VncServer { /// The VNC framebuffer, representing the remote desktop screen. framebuffer: Framebuffer, @@ -60,7 +62,7 @@ pub struct VncServer { clients: Arc>>>>, /// Write stream handles for direct socket shutdown client_write_streams: - Arc>>>>, + Arc>>>>>, /// Task handles for waiting on client threads to exit client_tasks: Arc>>>, /// List of active client IDs for fast lookup during shutdown without locking `VncClient` objects. @@ -139,10 +141,11 @@ impl VncServer { desktop_name: String, password: Option, ) -> (Self, mpsc::UnboundedReceiver) { + let framebuffer = Framebuffer::new(width, height); let (event_tx, event_rx) = mpsc::unbounded_channel(); let server = Self { - framebuffer: Framebuffer::new(width, height), + framebuffer, desktop_name, password, clients: Arc::new(RwLock::new(Vec::new())), @@ -229,6 +232,73 @@ impl VncServer { } } + /// Accept a VNC client connection from a generic stream. + /// + /// This method allows accepting VNC connections from any stream that implements + /// `AsyncRead + AsyncWrite + Unpin + Send`, such as TCP, UDP with reliability layer, + /// WebSocket, or other custom transports. + /// + /// # Arguments + /// + /// * `stream` - A stream implementing `AsyncRead + AsyncWrite + Unpin + Send` + /// * `client_id` - Optional client ID. If None, a new ID will be generated. + /// + /// # Returns + /// + /// Returns `Ok(())` if the client was successfully handled, or an `std::io::Error` on failure. + pub async fn from_socket( + &self, + stream: S, + client_id: Option, + ) -> Result<(), std::io::Error> + where + S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, + { + // Generate client ID if not provided + let client_id = client_id.unwrap_or_else(|| { + let client_id_raw = NEXT_CLIENT_ID.fetch_add(1, Ordering::SeqCst); + if client_id_raw == 0 || client_id_raw >= u64::MAX - 1000 { + // Wrap around to 1 if overflow + NEXT_CLIENT_ID.store(1, Ordering::SeqCst); + 1 + } else { + client_id_raw as usize + } + }); + + let framebuffer = self.framebuffer.clone(); + let desktop_name = self.desktop_name.clone(); + let password = self.password.clone(); + let clients = self.clients.clone(); + let client_write_streams = self.client_write_streams.clone(); + let client_tasks = self.client_tasks.clone(); + let client_ids = self.client_ids.clone(); + let server_event_tx = self.event_tx.clone(); + + let handle = tokio::spawn(async move { + if let Err(e) = Self::handle_client( + stream, + client_id, + framebuffer, + desktop_name, + password, + clients, + client_write_streams, + client_tasks, + client_ids, + server_event_tx, + ) + .await + { + error!("Client {client_id} error: {e}"); + } + }); + + // Store the handle_client task handle for joining later + self.client_tasks.write().await.push(handle); + Ok(()) + } + /// Handles a newly connected VNC client through its entire lifecycle. /// /// This function performs the VNC handshake, creates a `VncClient` instance, spawns @@ -253,20 +323,23 @@ impl VncServer { /// /// `Ok(())` when the client disconnects normally, or `Err` if an I/O error occurs. #[allow(clippy::too_many_arguments)] // VNC protocol handler requires all shared server state - async fn handle_client( - stream: TcpStream, + async fn handle_client( + stream: S, client_id: usize, framebuffer: Framebuffer, desktop_name: String, password: Option, clients: Arc>>>>, client_write_streams: Arc< - RwLock>>>, + RwLock>>>>, >, client_tasks: Arc>>>, client_ids: Arc>>, server_event_tx: mpsc::UnboundedSender, - ) -> Result<(), std::io::Error> { + ) -> Result<(), std::io::Error> + where + S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, + { let (client_event_tx, mut client_event_rx) = mpsc::unbounded_channel(); let client = VncClient::new( @@ -305,7 +378,7 @@ impl VncServer { let client_arc_clone = client_arc.clone(); let msg_handle = tokio::spawn(async move { let result = { - let mut client = client_arc_clone.write().await; + let mut client: tokio::sync::RwLockWriteGuard<'_, VncClient> = client_arc_clone.write().await; client.handle_messages().await }; if let Err(e) = result { @@ -486,14 +559,14 @@ impl VncServer { ); match client_result { - Ok(mut client) => { - // Set connection metadata for client management APIs - client.set_connection_metadata(Some(port)); - + Ok(client) => { log::info!("Reverse connection {client_id} established"); let client_arc = Arc::new(RwLock::new(client)); + // Set connection metadata for client management APIs + client_arc.write().await.set_connection_metadata(Some(port)); + // Register client to receive dirty region notifications let regions_arc = client_arc.read().await.get_receiver_handle(); let receiver = DirtyRegionReceiver::new(Arc::downgrade(®ions_arc)); @@ -501,32 +574,27 @@ impl VncServer { // Store the write stream handle for direct socket shutdown let write_stream_handle = { - let client = client_arc.read().await; - client.get_write_stream_handle() + let client_guard = client_arc.read().await; + client_guard.get_write_stream_handle() }; client_write_streams.write().await.push(write_stream_handle); clients.write().await.push(client_arc.clone()); client_ids.write().await.push(client_id); - let _ = - server_event_tx.send(ServerEvent::ClientConnected { client_id }); + let _ = server_event_tx.send(ServerEvent::ClientConnected { client_id }); // Spawn task to handle client messages let client_arc_clone = client_arc.clone(); let msg_handle = tokio::spawn(async move { let result = { - let mut client = client_arc_clone.write().await; + let mut client: tokio::sync::RwLockWriteGuard<'_, VncClient> = client_arc_clone.write().await; client.handle_messages().await }; if let Err(e) = result { - error!( - "Reverse client {client_id} message handling error: {e}" - ); + error!("Client {client_id} error: {e}"); } }); - - // Store the message handler task handle client_tasks.write().await.push(msg_handle); // Handle client events @@ -697,7 +765,7 @@ impl VncServer { let client_arc_clone = client_arc.clone(); let msg_handle = tokio::spawn(async move { let result = { - let mut client = client_arc_clone.write().await; + let mut client: tokio::sync::RwLockWriteGuard<'_, VncClient> = client_arc_clone.write().await; client.handle_messages().await }; if let Err(e) = result {