diff --git a/Cargo.toml b/Cargo.toml index e28ca79..c355c0d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,19 +10,19 @@ members = [ ] [workspace.package] -version = "0.24.1" +version = "0.25.0" edition = "2024" authors = ["shellrow "] [workspace.dependencies] -nex-core = { version = "0.24.1", path = "nex-core" } -nex-datalink = { version = "0.24.1", path = "nex-datalink" } -nex-packet = { version = "0.24.1", path = "nex-packet" } -nex-sys = { version = "0.24.1", path = "nex-sys" } -nex-socket = { version = "0.24.1", path = "nex-socket" } +nex-core = { version = "0.25.0", path = "nex-core" } +nex-datalink = { version = "0.25.0", path = "nex-datalink" } +nex-packet = { version = "0.25.0", path = "nex-packet" } +nex-sys = { version = "0.25.0", path = "nex-sys" } +nex-socket = { version = "0.25.0", path = "nex-socket" } serde = { version = "1" } libc = "0.2" -netdev = { version = "0.39" } +netdev = { version = "0.40" } bytes = "1" tokio = { version = "1" } rand = "0.8" diff --git a/README.md b/README.md index c9689bd..a48f62d 100644 --- a/README.md +++ b/README.md @@ -9,16 +9,13 @@ Cross-platform low-level networking library in Rust ## Overview `nex` is a Rust library that provides cross-platform low-level networking capabilities. -It includes a set of modules, each with a specific focus: +It includes sub-crates with responsibilities: -- `datalink`: Datalink layer networking. -- `packet`: Low-level packet parsing and building. -- `socket`: Socket-related functionality. +- `nex-packet`: Low-level packet parsing and serialization. +- `nex-datalink`: Raw datalink send/receive backends across platforms. +- `nex-socket`: Low-level socket operations with cross-platform option handling. -## Upcoming Features -The project has plans to enhance nex with the following features: -- More Protocol Support: Expanding protocol support to include additional network protocols and standards. -- Performance Improvements: Continuously working on performance enhancements for faster network operations. +The project aims to expose portable low-level primitives. ## Usage @@ -26,7 +23,7 @@ To use `nex`, add it as a dependency in your `Cargo.toml`: ```toml [dependencies] -nex = "0.24" +nex = "0.25" ``` ## Using Specific Sub-crates diff --git a/nex-core/src/ip.rs b/nex-core/src/ip.rs index 42893ff..d63ee0b 100644 --- a/nex-core/src/ip.rs +++ b/nex-core/src/ip.rs @@ -1,6 +1,6 @@ //! IP address utilities. -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; /// Returns [`true`] if the address appears to be globally routable. pub fn is_global_ip(ip_addr: &IpAddr) -> bool { @@ -10,6 +10,19 @@ pub fn is_global_ip(ip_addr: &IpAddr) -> bool { } } +/// Returns an unspecified IP (`0.0.0.0` / `::`) with the same family as `ip_addr`. +pub fn unspecified_ip_for(ip_addr: &IpAddr) -> IpAddr { + match ip_addr { + IpAddr::V4(_) => IpAddr::V4(Ipv4Addr::UNSPECIFIED), + IpAddr::V6(_) => IpAddr::V6(Ipv6Addr::UNSPECIFIED), + } +} + +/// Returns an unspecified socket address with the same family as `ip_addr`. +pub fn unspecified_socket_addr_for(ip_addr: &IpAddr, port: u16) -> SocketAddr { + SocketAddr::new(unspecified_ip_for(ip_addr), port) +} + /// Returns [`true`] if the address appears to be globally reachable /// as specified by the [IANA IPv4 Special-Purpose Address Registry]. pub fn is_global_ipv4(ipv4_addr: &Ipv4Addr) -> bool { @@ -139,4 +152,22 @@ mod tests { assert!(!is_global_ip(&ip_private)); assert!(!is_global_ip(&ip_ula)); } + + #[test] + fn test_unspecified_helpers() { + let v4 = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)); + let v6 = IpAddr::V6(Ipv6Addr::LOCALHOST); + + assert_eq!(unspecified_ip_for(&v4), IpAddr::V4(Ipv4Addr::UNSPECIFIED)); + assert_eq!(unspecified_ip_for(&v6), IpAddr::V6(Ipv6Addr::UNSPECIFIED)); + + assert_eq!( + unspecified_socket_addr_for(&v4, 1234), + SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 1234) + ); + assert_eq!( + unspecified_socket_addr_for(&v6, 4321), + SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 4321) + ); + } } diff --git a/nex-core/src/lib.rs b/nex-core/src/lib.rs index 9024c90..852b5b3 100644 --- a/nex-core/src/lib.rs +++ b/nex-core/src/lib.rs @@ -1,5 +1,5 @@ -//! Provides core network types and functionality. -//! Primarily designed for use with nex, it also includes extensions to the standard net module. +//! Core network types and helpers shared across the `nex` crates. +//! Includes interface, MAC/IP, and bitfield utilities used by low-level networking code. pub use netdev; diff --git a/nex-datalink/src/async_io/bpf.rs b/nex-datalink/src/async_io/bpf.rs index c06039b..a8b39b6 100644 --- a/nex-datalink/src/async_io/bpf.rs +++ b/nex-datalink/src/async_io/bpf.rs @@ -151,16 +151,18 @@ impl Stream for AsyncBpfSocketReceiver { /// Create a new asynchronous BPF socket channel. pub fn channel(network_interface: &Interface, config: Config) -> io::Result { #[cfg(any(target_os = "macos", target_os = "ios", target_os = "openbsd"))] - fn get_fd(attempts: usize) -> RawFd { + fn get_fd(attempts: usize) -> io::Result { for i in 0..attempts { let file_name = format!("/dev/bpf{}", i); - let c_file_name = CString::new(file_name).unwrap(); + let c_file_name = CString::new(file_name).map_err(|_| { + io::Error::new(io::ErrorKind::InvalidInput, "invalid bpf device path") + })?; let fd = unsafe { libc::open(c_file_name.as_ptr(), libc::O_RDWR, 0) }; if fd != -1 { - return fd; + return Ok(fd); } } - -1 + Err(io::Error::last_os_error()) } #[cfg(any( target_os = "freebsd", @@ -168,15 +170,18 @@ pub fn channel(network_interface: &Interface, config: Config) -> io::Result RawFd { - let c_file_name = CString::new("/dev/bpf").unwrap(); - unsafe { libc::open(c_file_name.as_ptr(), libc::O_RDWR, 0) } + fn get_fd(_attempts: usize) -> io::Result { + let c_file_name = CString::new("/dev/bpf") + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid bpf device path"))?; + let fd = unsafe { libc::open(c_file_name.as_ptr(), libc::O_RDWR, 0) }; + if fd == -1 { + Err(io::Error::last_os_error()) + } else { + Ok(fd) + } } - let fd = get_fd(config.bpf_fd_attempts); - if fd == -1 { - return Err(io::Error::last_os_error()); - } + let fd = get_fd(config.bpf_fd_attempts)?; let mut iface: bpf::ifreq = unsafe { mem::zeroed() }; for (i, c) in network_interface.name.bytes().enumerate() { diff --git a/nex-datalink/src/async_io/wpcap.rs b/nex-datalink/src/async_io/wpcap.rs index 465e5c5..f52ba29 100644 --- a/nex-datalink/src/async_io/wpcap.rs +++ b/nex-datalink/src/async_io/wpcap.rs @@ -92,11 +92,29 @@ impl Stream for AsyncWpcapSocketReceiver { type Item = io::Result>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut queue = self.inner.packets.lock().unwrap(); + let mut queue = match self.inner.packets.lock() { + Ok(queue) => queue, + Err(_) => { + return Poll::Ready(Some(Err(io::Error::new( + io::ErrorKind::Other, + "wpcap packet queue mutex poisoned", + )))); + } + }; if let Some(pkt) = queue.pop_front() { Poll::Ready(Some(Ok(pkt))) } else { - *self.inner.waker.lock().unwrap() = Some(cx.waker().clone()); + match self.inner.waker.lock() { + Ok(mut waker) => { + *waker = Some(cx.waker().clone()); + } + Err(_) => { + return Poll::Ready(Some(Err(io::Error::new( + io::ErrorKind::Other, + "wpcap waker mutex poisoned", + )))); + } + } Poll::Pending } } @@ -108,7 +126,9 @@ pub fn channel(network_interface: &Interface, config: Config) -> io::Result io::Result queue, + Err(poisoned) => poisoned.into_inner(), + }; queue.push_back(data); } let offset = (*hdr).bh_hdrlen as isize + (*hdr).bh_caplen as isize; ptr = ptr.offset(bpf::BPF_WORDALIGN(offset)); } } - if let Some(w) = waker.lock().unwrap().take() { + let mut waker = match waker.lock() { + Ok(waker) => waker, + Err(poisoned) => poisoned.into_inner(), + }; + if let Some(w) = waker.take() { w.wake(); } } diff --git a/nex-datalink/src/bpf.rs b/nex-datalink/src/bpf.rs index 703ae6e..b4c88f3 100644 --- a/nex-datalink/src/bpf.rs +++ b/nex-datalink/src/bpf.rs @@ -74,25 +74,30 @@ pub fn channel(network_interface: &Interface, config: Config) -> io::Result libc::c_int { - let c_file_name = CString::new(&b"/dev/bpf"[..]).unwrap(); - unsafe { libc::open(c_file_name.as_ptr(), libc::O_RDWR, 0) } + fn get_fd(_attempts: usize) -> io::Result { + let c_file_name = CString::new(&b"/dev/bpf"[..]) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid bpf device path"))?; + let fd = unsafe { libc::open(c_file_name.as_ptr(), libc::O_RDWR, 0) }; + if fd == -1 { + Err(io::Error::last_os_error()) + } else { + Ok(fd) + } } #[cfg(any(target_os = "openbsd", target_os = "macos", target_os = "ios"))] - fn get_fd(attempts: usize) -> libc::c_int { + fn get_fd(attempts: usize) -> io::Result { for i in 0..attempts { - let fd = unsafe { - let file_name = format!("/dev/bpf{}", i); - let c_file_name = CString::new(file_name.as_bytes()).unwrap(); - libc::open(c_file_name.as_ptr(), libc::O_RDWR, 0) - }; + let file_name = format!("/dev/bpf{}", i); + let c_file_name = CString::new(file_name.as_bytes()).map_err(|_| { + io::Error::new(io::ErrorKind::InvalidInput, "invalid bpf device path") + })?; + let fd = unsafe { libc::open(c_file_name.as_ptr(), libc::O_RDWR, 0) }; if fd != -1 { - return fd; + return Ok(fd); } } - - -1 + Err(io::Error::last_os_error()) } #[cfg(any( @@ -117,10 +122,7 @@ pub fn channel(network_interface: &Interface, config: Config) -> io::Result io::Result<()> { + if self.write_buffer_size == 0 { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "write_buffer_size must be greater than 0", + )); + } + if self.read_buffer_size == 0 { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "read_buffer_size must be greater than 0", + )); + } + if self.bpf_fd_attempts == 0 { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "bpf_fd_attempts must be greater than 0", + )); + } + Ok(()) + } + + pub fn with_write_buffer_size(mut self, write_buffer_size: usize) -> Self { + self.write_buffer_size = write_buffer_size; + self + } + + pub fn with_read_buffer_size(mut self, read_buffer_size: usize) -> Self { + self.read_buffer_size = read_buffer_size; + self + } + + pub fn with_read_timeout(mut self, read_timeout: Option) -> Self { + self.read_timeout = read_timeout; + self + } + + pub fn with_write_timeout(mut self, write_timeout: Option) -> Self { + self.write_timeout = write_timeout; + self + } + + pub fn with_channel_type(mut self, channel_type: ChannelType) -> Self { + self.channel_type = channel_type; + self + } + + pub fn with_bpf_fd_attempts(mut self, bpf_fd_attempts: usize) -> Self { + self.bpf_fd_attempts = bpf_fd_attempts; + self + } + + pub fn with_linux_fanout(mut self, linux_fanout: Option) -> Self { + self.linux_fanout = linux_fanout; + self + } + + pub fn with_promiscuous(mut self, promiscuous: bool) -> Self { + self.promiscuous = promiscuous; + self + } +} + /// Creates a new datalink channel for sending and receiving raw packets. /// /// This function sets up a channel to send and receive raw packets directly from a data link layer @@ -150,6 +215,7 @@ pub fn channel( network_interface: &nex_core::interface::Interface, configuration: Config, ) -> io::Result { + configuration.validate()?; backend::channel(network_interface, (&configuration).into()) } @@ -198,4 +264,25 @@ mod tests { assert!(cfg.linux_fanout.is_none()); assert!(cfg.promiscuous); } + + #[test] + fn config_validate_rejects_zero_buffer_size() { + let cfg = Config::default().with_read_buffer_size(0); + assert!(cfg.validate().is_err()); + + let cfg = Config::default().with_write_buffer_size(0); + assert!(cfg.validate().is_err()); + } + + #[test] + fn config_builder_updates_fields() { + let cfg = Config::default() + .with_channel_type(ChannelType::Layer3(0x0800)) + .with_promiscuous(false) + .with_bpf_fd_attempts(42); + + assert_eq!(cfg.channel_type, ChannelType::Layer3(0x0800)); + assert!(!cfg.promiscuous); + assert_eq!(cfg.bpf_fd_attempts, 42); + } } diff --git a/nex-datalink/src/pcap.rs b/nex-datalink/src/pcap.rs index 971ecac..37b84c6 100644 --- a/nex-datalink/src/pcap.rs +++ b/nex-datalink/src/pcap.rs @@ -4,7 +4,7 @@ use std::io; use std::marker::{Send, Sync}; use std::path::Path; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, MutexGuard}; use std::time::Duration; use pcap::{Activated, Active}; @@ -110,6 +110,14 @@ struct RawSenderImpl { capture: Arc>>, } +fn lock_capture( + capture: &Mutex>, +) -> io::Result>> { + capture + .lock() + .map_err(|_| io::Error::new(io::ErrorKind::Other, "pcap capture mutex poisoned")) +} + impl RawSender for RawSenderImpl { #[inline] fn build_and_send( @@ -121,7 +129,10 @@ impl RawSender for RawSenderImpl { for _ in 0..num_packets { let mut data = vec![0; packet_size]; func(&mut data); - let mut cap = self.capture.lock().unwrap(); + let mut cap = match lock_capture(&self.capture) { + Ok(cap) => cap, + Err(err) => return Some(Err(err)), + }; if let Err(e) = cap.sendpacket(data) { return Some(Err(io::Error::new(io::ErrorKind::Other, e))); } @@ -131,7 +142,10 @@ impl RawSender for RawSenderImpl { #[inline] fn send(&mut self, packet: &[u8]) -> Option> { - let mut cap = self.capture.lock().unwrap(); + let mut cap = match lock_capture(&self.capture) { + Ok(cap) => cap, + Err(err) => return Some(Err(err)), + }; Some(match cap.sendpacket(packet) { Ok(()) => Ok(()), Err(e) => Err(io::Error::new(io::ErrorKind::Other, e)), @@ -165,7 +179,7 @@ struct RawReceiverImpl { impl RawReceiver for RawReceiverImpl { fn next(&mut self) -> io::Result<&[u8]> { - let mut cap = self.capture.lock().unwrap(); + let mut cap = lock_capture(&self.capture)?; match cap.next_packet() { Ok(pkt) => { self.read_buffer.truncate(0); diff --git a/nex-datalink/src/wpcap.rs b/nex-datalink/src/wpcap.rs index 393b052..b46a650 100644 --- a/nex-datalink/src/wpcap.rs +++ b/nex-datalink/src/wpcap.rs @@ -76,7 +76,9 @@ pub fn channel(network_interface: &Interface, config: Config) -> io::Result MutableArpPacket<'a> { } pub fn get_sender_hw_addr(&self) -> MacAddr { - MacAddr::from_octets(self.raw()[8..14].try_into().unwrap()) + let raw = self.raw(); + MacAddr::from_octets([raw[8], raw[9], raw[10], raw[11], raw[12], raw[13]]) } pub fn set_sender_hw_addr(&mut self, addr: MacAddr) { @@ -567,7 +572,8 @@ impl<'a> MutableArpPacket<'a> { } pub fn get_target_hw_addr(&self) -> MacAddr { - MacAddr::from_octets(self.raw()[18..24].try_into().unwrap()) + let raw = self.raw(); + MacAddr::from_octets([raw[18], raw[19], raw[20], raw[21], raw[22], raw[23]]) } pub fn set_target_hw_addr(&mut self, addr: MacAddr) { diff --git a/nex-packet/src/ethernet.rs b/nex-packet/src/ethernet.rs index 08a157f..caecaf2 100644 --- a/nex-packet/src/ethernet.rs +++ b/nex-packet/src/ethernet.rs @@ -197,9 +197,10 @@ impl Packet for EthernetPacket { if bytes.len() < ETHERNET_HEADER_LEN { return None; } - let destination = MacAddr::from_octets(bytes[0..MAC_ADDR_LEN].try_into().unwrap()); + let destination = + MacAddr::from_octets([bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5]]); let source = - MacAddr::from_octets(bytes[MAC_ADDR_LEN..2 * MAC_ADDR_LEN].try_into().unwrap()); + MacAddr::from_octets([bytes[6], bytes[7], bytes[8], bytes[9], bytes[10], bytes[11]]); let ethertype = EtherType::new(u16::from_be_bytes([bytes[12], bytes[13]])); let payload = Bytes::copy_from_slice(&bytes[ETHERNET_HEADER_LEN..]); @@ -335,7 +336,8 @@ impl<'a> MutableEthernetPacket<'a> { /// Retrieve the destination MAC address. pub fn get_destination(&self) -> MacAddr { - MacAddr::from_octets(self.header()[0..MAC_ADDR_LEN].try_into().unwrap()) + let h = self.header(); + MacAddr::from_octets([h[0], h[1], h[2], h[3], h[4], h[5]]) } /// Update the destination MAC address. @@ -345,11 +347,8 @@ impl<'a> MutableEthernetPacket<'a> { /// Retrieve the source MAC address. pub fn get_source(&self) -> MacAddr { - MacAddr::from_octets( - self.header()[MAC_ADDR_LEN..2 * MAC_ADDR_LEN] - .try_into() - .unwrap(), - ) + let h = self.header(); + MacAddr::from_octets([h[6], h[7], h[8], h[9], h[10], h[11]]) } /// Update the source MAC address. diff --git a/nex-packet/src/gre.rs b/nex-packet/src/gre.rs index 2100a7b..de3f902 100644 --- a/nex-packet/src/gre.rs +++ b/nex-packet/src/gre.rs @@ -76,8 +76,8 @@ impl Packet for GrePacket { } if routing_present != 0 { - // Not implemented for this crate - panic!("Source routed GRE packets not supported"); + // Source-routed GRE parsing is not yet supported. + return None; } let payload = Bytes::copy_from_slice(bytes); @@ -145,9 +145,8 @@ impl Packet for GrePacket { } } - // Panic if routing_present is set (not supported by this implementation) if self.routing_present != 0 { - panic!("to_bytes does not support source routed GRE packets"); + buf.put_slice(&self.routing); } buf.put_slice(&self.payload); @@ -194,9 +193,8 @@ impl Packet for GrePacket { } } - // Panic if routing_present is set (not supported by this implementation) if self.routing_present != 0 { - panic!("header does not support source routed GRE packets"); + buf.put_slice(&self.routing); } buf.freeze() @@ -212,6 +210,7 @@ impl Packet for GrePacket { + self.offset_length() + self.key_length() + self.sequence_length() + + self.routing_length() } fn payload_len(&self) -> usize { @@ -248,7 +247,7 @@ impl GrePacket { if 0 == self.routing_present { 0 } else { - panic!("Source routed GRE packets not supported") + self.routing.len() } } } @@ -308,4 +307,51 @@ mod tests { assert_eq!(frozen.protocol_type, 0x86dd); assert_eq!(frozen.payload[0], 0xff); } + + #[test] + fn gre_with_routing_present_is_not_parsed() { + let packet = Bytes::from_static(&[ + 0x40, 0x00, // routing flag on + 0x08, 0x00, // protocol type + 0x00, 0x00, // checksum + 0x00, 0x00, // offset + 0xaa, 0xbb, // routing data (unsupported) + 0xcc, 0xdd, // payload + ]); + + assert!(GrePacket::from_buf(&packet).is_none()); + } + + #[test] + fn gre_to_bytes_with_routing_present_does_not_panic() { + let packet = GrePacket { + checksum_present: 0, + routing_present: 1, + key_present: 0, + sequence_present: 0, + strict_source_route: 0, + recursion_control: 0, + zero_flags: 0, + version: 0, + protocol_type: 0x0800, + checksum: vec![0x1111], + offset: vec![0x2222], + key: vec![], + sequence: vec![], + routing: vec![0xaa, 0xbb], + payload: Bytes::from_static(&[0xcc]), + }; + + let bytes = packet.to_bytes(); + assert_eq!( + bytes.as_ref(), + &[ + 0x40, 0x00, 0x08, 0x00, 0x11, 0x11, 0x22, 0x22, 0xaa, 0xbb, 0xcc + ] + ); + assert_eq!( + packet.header().as_ref(), + &[0x40, 0x00, 0x08, 0x00, 0x11, 0x11, 0x22, 0x22, 0xaa, 0xbb] + ); + } } diff --git a/nex-packet/src/ipv4.rs b/nex-packet/src/ipv4.rs index 81ce672..c2c6c54 100644 --- a/nex-packet/src/ipv4.rs +++ b/nex-packet/src/ipv4.rs @@ -350,16 +350,6 @@ impl Packet for Ipv4Packet { let header_len = IPV4_HEADER_LEN + tmp_buf.len(); let total_len_expected = header_len + self.payload.len(); - // Check if the total length exceeds the header's total_length field - if total_len_expected > self.header.total_length as usize { - panic!( - "Payload too long: header {} + payload {} = {} > total_length {}", - header_len, - self.payload.len(), - total_len_expected, - self.header.total_length - ); - } let header_len_words = (header_len / 4) as u8; @@ -369,7 +359,9 @@ impl Packet for Ipv4Packet { buf.put_u8((self.header.dscp << 2 | self.header.ecn) as u8); // 2. Fixed header fields - buf.put_u16(self.header.total_length); + // Keep header total length consistent with the actual serialized packet length. + let total_length = total_len_expected.min(u16::MAX as usize) as u16; + buf.put_u16(total_length); buf.put_u16(self.header.identification); buf.put_u16(((self.header.flags as u16) << 13) | self.header.fragment_offset); buf.put_u8(self.header.ttl); @@ -864,8 +856,7 @@ mod tests { } #[test] - #[should_panic(expected = "Payload too long")] - fn ipv4_payload_too_long_should_panic() { + fn ipv4_payload_too_long_updates_total_length() { let packet = Ipv4Packet { header: Ipv4Header { version: 4, @@ -886,8 +877,9 @@ mod tests { payload: Bytes::from_static(&[0, 1, 2, 3, 4, 5]), // 6 bytes payload }; - // This should panic because the payload length exceeds the total_length specified in the header - let _ = packet.to_bytes(); + let bytes = packet.to_bytes(); + let reparsed = Ipv4Packet::from_bytes(bytes).expect("reparse"); + assert_eq!(reparsed.header.total_length as usize, reparsed.total_len()); } #[test] diff --git a/nex-packet/src/ipv6.rs b/nex-packet/src/ipv6.rs index 568c600..ff4ce05 100644 --- a/nex-packet/src/ipv6.rs +++ b/nex-packet/src/ipv6.rs @@ -419,7 +419,11 @@ impl<'a> MutableIpv6Packet<'a> { } pub fn get_source(&self) -> Ipv6Addr { - Ipv6Addr::from(<[u8; 16]>::try_from(&self.raw()[8..24]).unwrap()) + let raw = self.raw(); + Ipv6Addr::from([ + raw[8], raw[9], raw[10], raw[11], raw[12], raw[13], raw[14], raw[15], raw[16], raw[17], + raw[18], raw[19], raw[20], raw[21], raw[22], raw[23], + ]) } pub fn set_source(&mut self, addr: Ipv6Addr) { @@ -427,7 +431,11 @@ impl<'a> MutableIpv6Packet<'a> { } pub fn get_destination(&self) -> Ipv6Addr { - Ipv6Addr::from(<[u8; 16]>::try_from(&self.raw()[24..40]).unwrap()) + let raw = self.raw(); + Ipv6Addr::from([ + raw[24], raw[25], raw[26], raw[27], raw[28], raw[29], raw[30], raw[31], raw[32], + raw[33], raw[34], raw[35], raw[36], raw[37], raw[38], raw[39], + ]) } pub fn set_destination(&mut self, addr: Ipv6Addr) { diff --git a/nex-packet/src/lib.rs b/nex-packet/src/lib.rs index 7903d1c..36be4df 100644 --- a/nex-packet/src/lib.rs +++ b/nex-packet/src/lib.rs @@ -1,4 +1,4 @@ -//! Packet parsing and construction utilities for common network protocols. +//! Low-level packet parsing and serialization primitives for common network protocols. pub mod arp; pub mod builder; diff --git a/nex-packet/src/packet.rs b/nex-packet/src/packet.rs index 9bf1296..bf511ef 100644 --- a/nex-packet/src/packet.rs +++ b/nex-packet/src/packet.rs @@ -27,6 +27,11 @@ pub trait Packet: Sized { fn payload_len(&self) -> usize; /// Get the total length of the packet (header + payload). fn total_len(&self) -> usize; + + /// Returns true when the serialized packet is empty. + fn is_empty(&self) -> bool { + self.total_len() == 0 + } /// Convert the packet to a mutable byte buffer. fn to_bytes_mut(&self) -> BytesMut { let mut buf = BytesMut::with_capacity(self.total_len()); @@ -78,6 +83,11 @@ pub trait MutablePacket<'a>: Sized { /// Get a mutable view over the payload bytes of the packet. fn payload_mut(&mut self) -> &mut [u8]; + /// Returns true when the packet buffer is empty. + fn is_empty(&self) -> bool { + self.packet().is_empty() + } + /// Convert the mutable packet into its immutable counterpart. fn freeze(&self) -> Option { Self::Packet::from_buf(self.packet()) diff --git a/nex-packet/src/util.rs b/nex-packet/src/util.rs index 32ed338..9079ee1 100644 --- a/nex-packet/src/util.rs +++ b/nex-packet/src/util.rs @@ -3,7 +3,6 @@ use crate::ip::IpNextProtocol; use nex_core::bitfield::u16be; -use core::convert::TryInto; use core::u8; use core::u16; use std::net::{Ipv4Addr, Ipv6Addr}; @@ -153,8 +152,7 @@ fn sum_be_words(data: &[u8], skipword: usize) -> u32 { let mut i = 0; while cur_data.len() >= 2 { if i != skipword { - // It's safe to unwrap because we verified there are at least 2 bytes - sum += u16::from_be_bytes(cur_data[0..2].try_into().unwrap()) as u32; + sum += ((cur_data[0] as u32) << 8) + cur_data[1] as u32; } cur_data = &cur_data[2..]; i += 1; diff --git a/nex-socket/Cargo.toml b/nex-socket/Cargo.toml index b12a51e..4121053 100644 --- a/nex-socket/Cargo.toml +++ b/nex-socket/Cargo.toml @@ -18,7 +18,7 @@ tokio = { version = "1", features = ["time", "sync", "net", "rt"] } libc = { workspace = true } [target.'cfg(unix)'.dependencies] -nix = { version = "0.30", features = ["poll"] } +nix = { version = "0.30", features = ["poll", "net", "uio"] } [target.'cfg(windows)'.dependencies.windows-sys] version = "0.59.0" diff --git a/nex-socket/src/icmp/async_impl.rs b/nex-socket/src/icmp/async_impl.rs index a0d121e..9a34863 100644 --- a/nex-socket/src/icmp/async_impl.rs +++ b/nex-socket/src/icmp/async_impl.rs @@ -86,7 +86,7 @@ impl AsyncIcmpSocket { Ok(Self { inner, - socket_type: IcmpSocketType::from_sock_type(sock_type), + socket_type: IcmpSocketType::try_from_sock_type(sock_type)?, socket_family: config.socket_family, }) } diff --git a/nex-socket/src/icmp/config.rs b/nex-socket/src/icmp/config.rs index 3583f29..788ed46 100644 --- a/nex-socket/src/icmp/config.rs +++ b/nex-socket/src/icmp/config.rs @@ -1,5 +1,5 @@ use socket2::Type as SockType; -use std::{net::SocketAddr, time::Duration}; +use std::{io, net::SocketAddr, time::Duration}; use crate::SocketFamily; @@ -29,11 +29,14 @@ impl IcmpSocketType { } /// Converts the ICMP socket type from a `socket2::Type`. - pub(crate) fn from_sock_type(sock_type: SockType) -> Self { + pub(crate) fn try_from_sock_type(sock_type: SockType) -> io::Result { match sock_type { - SockType::DGRAM => IcmpSocketType::Dgram, - SockType::RAW => IcmpSocketType::Raw, - _ => panic!("Invalid ICMP socket type"), + SockType::DGRAM => Ok(IcmpSocketType::Dgram), + SockType::RAW => Ok(IcmpSocketType::Raw), + _ => Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid ICMP socket type", + )), } } @@ -88,6 +91,17 @@ impl IcmpConfig { } } + /// Creates a new ICMP configuration from a socket family. + pub fn from_family(socket_family: SocketFamily) -> Self { + Self { + socket_family, + ..Self::new(match socket_family { + SocketFamily::IPV4 => IcmpKind::V4, + SocketFamily::IPV6 => IcmpKind::V6, + }) + } + } + /// Set bind address for the socket. pub fn with_bind(mut self, addr: SocketAddr) -> Self { self.bind = Some(addr); @@ -106,6 +120,11 @@ impl IcmpConfig { self } + /// Set the hop limit for IPv6 packets. + pub fn with_hop_limit(self, hops: u32) -> Self { + self.with_hoplimit(hops) + } + /// Set the read timeout for the socket. pub fn with_read_timeout(mut self, timeout: Duration) -> Self { self.read_timeout = Some(timeout); @@ -140,6 +159,7 @@ impl IcmpConfig { #[cfg(test)] mod tests { use super::*; + #[test] fn icmp_config_builders() { let addr: SocketAddr = "127.0.0.1:0".parse().unwrap(); @@ -154,4 +174,12 @@ mod tests { assert_eq!(cfg.interface.as_deref(), Some("eth0")); assert_eq!(cfg.sock_type_hint, IcmpSocketType::Raw); } + + #[test] + fn from_family_sets_expected_kind() { + let v4 = IcmpConfig::from_family(SocketFamily::IPV4); + let v6 = IcmpConfig::from_family(SocketFamily::IPV6); + assert_eq!(v4.socket_family, SocketFamily::IPV4); + assert_eq!(v6.socket_family, SocketFamily::IPV6); + } } diff --git a/nex-socket/src/icmp/sync_impl.rs b/nex-socket/src/icmp/sync_impl.rs index 46e4105..fcd2eaf 100644 --- a/nex-socket/src/icmp/sync_impl.rs +++ b/nex-socket/src/icmp/sync_impl.rs @@ -70,7 +70,7 @@ impl IcmpSocket { Ok(Self { inner: std_socket, - socket_type: IcmpSocketType::from_sock_type(sock_type), + socket_type: IcmpSocketType::try_from_sock_type(sock_type)?, socket_family: config.socket_family, }) } diff --git a/nex-socket/src/lib.rs b/nex-socket/src/lib.rs index 4849a8b..a033a53 100644 --- a/nex-socket/src/lib.rs +++ b/nex-socket/src/lib.rs @@ -1,8 +1,7 @@ -//! Convenience sockets built on top of `socket2` and `tokio`. +//! Cross-platform low-level socket APIs for TCP, UDP and ICMP. //! -//! This crate provides synchronous and asynchronous helpers for TCP, UDP and -//! ICMP. The goal is to simplify lower level socket configuration across -//! platforms while still allowing direct access when needed. +//! `nex-socket` focuses on predictable, low-level behavior and platform-aware +//! socket option control. pub mod icmp; pub mod tcp; diff --git a/nex-socket/src/tcp/async_impl.rs b/nex-socket/src/tcp/async_impl.rs index e160644..06178a2 100644 --- a/nex-socket/src/tcp/async_impl.rs +++ b/nex-socket/src/tcp/async_impl.rs @@ -26,6 +26,23 @@ impl AsyncTcpSocket { if let Some(flag) = config.reuseaddr { socket.set_reuse_address(flag)?; } + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos" + ))] + if let Some(flag) = config.reuseport { + socket.set_reuse_port(flag)?; + } if let Some(flag) = config.nodelay { socket.set_nodelay(flag)?; } @@ -44,6 +61,35 @@ impl AsyncTcpSocket { if let Some(timeout) = config.write_timeout { socket.set_write_timeout(Some(timeout))?; } + if let Some(size) = config.recv_buffer_size { + socket.set_recv_buffer_size(size)?; + } + if let Some(size) = config.send_buffer_size { + socket.set_send_buffer_size(size)?; + } + if let Some(tos) = config.tos { + socket.set_tos(tos)?; + } + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos" + ))] + if let Some(tclass) = config.tclass_v6 { + socket.set_tclass_v6(tclass)?; + } + if let Some(only_v6) = config.only_v6 { + socket.set_only_v6(only_v6)?; + } // Linux: optional interface name #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))] @@ -173,11 +219,59 @@ impl AsyncTcpSocket { self.socket.set_reuse_address(on) } + /// Get reuse address option. + pub fn reuseaddr(&self) -> io::Result { + self.socket.reuse_address() + } + + /// Set port reuse option where supported. + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos" + ))] + pub fn set_reuseport(&self, on: bool) -> io::Result<()> { + self.socket.set_reuse_port(on) + } + + /// Get port reuse option where supported. + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos" + ))] + pub fn reuseport(&self) -> io::Result { + self.socket.reuse_port() + } + /// Set no delay option for TCP. pub fn set_nodelay(&self, on: bool) -> io::Result<()> { self.socket.set_nodelay(on) } + /// Get no delay option for TCP. + pub fn nodelay(&self) -> io::Result { + self.socket.nodelay() + } + /// Set linger option for the socket. pub fn set_linger(&self, dur: Option) -> io::Result<()> { self.socket.set_linger(dur) @@ -188,16 +282,109 @@ impl AsyncTcpSocket { self.socket.set_ttl(ttl) } + /// Get the time-to-live for IPv4 packets. + pub fn ttl(&self) -> io::Result { + self.socket.ttl() + } + /// Set the hop limit for IPv6 packets. pub fn set_hoplimit(&self, hops: u32) -> io::Result<()> { self.socket.set_unicast_hops_v6(hops) } + /// Get the hop limit for IPv6 packets. + pub fn hoplimit(&self) -> io::Result { + self.socket.unicast_hops_v6() + } + /// Set the keepalive option for the socket. pub fn set_keepalive(&self, on: bool) -> io::Result<()> { self.socket.set_keepalive(on) } + /// Get the keepalive option for the socket. + pub fn keepalive(&self) -> io::Result { + self.socket.keepalive() + } + + /// Set the receive buffer size. + pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> { + self.socket.set_recv_buffer_size(size) + } + + /// Get the receive buffer size. + pub fn recv_buffer_size(&self) -> io::Result { + self.socket.recv_buffer_size() + } + + /// Set the send buffer size. + pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> { + self.socket.set_send_buffer_size(size) + } + + /// Get the send buffer size. + pub fn send_buffer_size(&self) -> io::Result { + self.socket.send_buffer_size() + } + + /// Set IPv4 TOS / DSCP. + pub fn set_tos(&self, tos: u32) -> io::Result<()> { + self.socket.set_tos(tos) + } + + /// Get IPv4 TOS / DSCP. + pub fn tos(&self) -> io::Result { + self.socket.tos() + } + + /// Set IPv6 traffic class where supported. + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos" + ))] + pub fn set_tclass_v6(&self, tclass: u32) -> io::Result<()> { + self.socket.set_tclass_v6(tclass) + } + + /// Get IPv6 traffic class where supported. + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos" + ))] + pub fn tclass_v6(&self) -> io::Result { + self.socket.tclass_v6() + } + + /// Set whether this socket is IPv6 only. + pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> { + self.socket.set_only_v6(only_v6) + } + + /// Get whether this socket is IPv6 only. + pub fn only_v6(&self) -> io::Result { + self.socket.only_v6() + } + /// Set the bind device for the socket (Linux specific). pub fn set_bind_device(&self, iface: &str) -> io::Result<()> { #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))] @@ -208,7 +395,7 @@ impl AsyncTcpSocket { let _ = iface; Err(io::Error::new( io::ErrorKind::Unsupported, - "bind_device not supported on this OS", + "bind_device is not supported on this platform", )) } } @@ -218,7 +405,7 @@ impl AsyncTcpSocket { self.socket .local_addr()? .as_socket() - .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Failed to get socket address")) + .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "failed to retrieve local address")) } /// Convert the internal socket into a Tokio `TcpStream`. @@ -227,6 +414,21 @@ impl AsyncTcpSocket { TcpStream::from_std(std_stream) } + /// Construct from a raw `socket2::Socket`. + pub fn from_socket(socket: Socket) -> Self { + Self { socket } + } + + /// Borrow the inner `socket2::Socket`. + pub fn socket(&self) -> &Socket { + &self.socket + } + + /// Consume and return the inner `socket2::Socket`. + pub fn into_socket(self) -> Socket { + self.socket + } + /// Extract the RAW file descriptor for Unix. #[cfg(unix)] pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd { diff --git a/nex-socket/src/tcp/config.rs b/nex-socket/src/tcp/config.rs index a836709..2170c7e 100644 --- a/nex-socket/src/tcp/config.rs +++ b/nex-socket/src/tcp/config.rs @@ -44,6 +44,8 @@ pub struct TcpConfig { pub nonblocking: bool, /// Whether to allow address reuse. pub reuseaddr: Option, + /// Whether to allow port reuse (`SO_REUSEPORT`) where supported. + pub reuseport: Option, /// Whether to disable Nagle's algorithm (TCP_NODELAY). pub nodelay: Option, /// Optional linger duration for the socket. @@ -56,6 +58,16 @@ pub struct TcpConfig { pub read_timeout: Option, /// Optional write timeout for the socket. pub write_timeout: Option, + /// Optional receive buffer size in bytes. + pub recv_buffer_size: Option, + /// Optional send buffer size in bytes. + pub send_buffer_size: Option, + /// Optional IPv4 TOS / DSCP field value. + pub tos: Option, + /// Optional IPv6 traffic class value (`IPV6_TCLASS`) where supported. + pub tclass_v6: Option, + /// Whether to force IPv6-only behavior on dual-stack sockets. + pub only_v6: Option, /// Optional device to bind the socket to. pub bind_device: Option, /// Whether to enable TCP keepalive. @@ -63,6 +75,14 @@ pub struct TcpConfig { } impl TcpConfig { + /// Create a STREAM socket for the specified family. + pub fn new(socket_family: SocketFamily) -> Self { + match socket_family { + SocketFamily::IPV4 => Self::v4_stream(), + SocketFamily::IPV6 => Self::v6_stream(), + } + } + /// Create a STREAM socket for IPv4. pub fn v4_stream() -> Self { Self { @@ -71,12 +91,18 @@ impl TcpConfig { bind_addr: None, nonblocking: false, reuseaddr: None, + reuseport: None, nodelay: None, linger: None, ttl: None, hoplimit: None, read_timeout: None, write_timeout: None, + recv_buffer_size: None, + send_buffer_size: None, + tos: None, + tclass_v6: None, + only_v6: None, bind_device: None, keepalive: None, } @@ -116,6 +142,10 @@ impl TcpConfig { self } + pub fn with_bind_addr(self, addr: SocketAddr) -> Self { + self.with_bind(addr) + } + pub fn with_nonblocking(mut self, flag: bool) -> Self { self.nonblocking = flag; self @@ -126,6 +156,11 @@ impl TcpConfig { self } + pub fn with_reuseport(mut self, flag: bool) -> Self { + self.reuseport = Some(flag); + self + } + pub fn with_nodelay(mut self, flag: bool) -> Self { self.nodelay = Some(flag); self @@ -146,6 +181,10 @@ impl TcpConfig { self } + pub fn with_hop_limit(self, hops: u32) -> Self { + self.with_hoplimit(hops) + } + pub fn with_keepalive(mut self, on: bool) -> Self { self.keepalive = Some(on); self @@ -161,6 +200,31 @@ impl TcpConfig { self } + pub fn with_recv_buffer_size(mut self, size: usize) -> Self { + self.recv_buffer_size = Some(size); + self + } + + pub fn with_send_buffer_size(mut self, size: usize) -> Self { + self.send_buffer_size = Some(size); + self + } + + pub fn with_tos(mut self, tos: u32) -> Self { + self.tos = Some(tos); + self + } + + pub fn with_tclass_v6(mut self, tclass: u32) -> Self { + self.tclass_v6 = Some(tclass); + self + } + + pub fn with_only_v6(mut self, only_v6: bool) -> Self { + self.only_v6 = Some(only_v6); + self + } + pub fn with_bind_device(mut self, iface: impl Into) -> Self { self.bind_device = Some(iface.into()); self @@ -174,19 +238,36 @@ mod tests { #[test] fn tcp_config_builders() { let addr: SocketAddr = "127.0.0.1:80".parse().unwrap(); - let cfg = TcpConfig::v4_stream() - .with_bind(addr) + let cfg = TcpConfig::new(SocketFamily::IPV4) + .with_bind_addr(addr) .with_nonblocking(true) .with_reuseaddr(true) + .with_reuseport(true) .with_nodelay(true) - .with_ttl(10); + .with_ttl(10) + .with_recv_buffer_size(8192) + .with_send_buffer_size(8192) + .with_tos(0x10) + .with_tclass_v6(0x20); assert_eq!(cfg.socket_family, SocketFamily::IPV4); assert_eq!(cfg.socket_type, TcpSocketType::Stream); assert_eq!(cfg.bind_addr, Some(addr)); assert!(cfg.nonblocking); assert_eq!(cfg.reuseaddr, Some(true)); + assert_eq!(cfg.reuseport, Some(true)); assert_eq!(cfg.nodelay, Some(true)); assert_eq!(cfg.ttl, Some(10)); + assert_eq!(cfg.recv_buffer_size, Some(8192)); + assert_eq!(cfg.send_buffer_size, Some(8192)); + assert_eq!(cfg.tos, Some(0x10)); + assert_eq!(cfg.tclass_v6, Some(0x20)); + } + + #[test] + fn new_with_ipv6_family_creates_v6_stream() { + let cfg = TcpConfig::new(SocketFamily::IPV6); + assert_eq!(cfg.socket_family, SocketFamily::IPV6); + assert_eq!(cfg.socket_type, TcpSocketType::Stream); } } diff --git a/nex-socket/src/tcp/sync_impl.rs b/nex-socket/src/tcp/sync_impl.rs index 3d88526..0edd329 100644 --- a/nex-socket/src/tcp/sync_impl.rs +++ b/nex-socket/src/tcp/sync_impl.rs @@ -32,6 +32,23 @@ impl TcpSocket { if let Some(flag) = config.reuseaddr { socket.set_reuse_address(flag)?; } + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos" + ))] + if let Some(flag) = config.reuseport { + socket.set_reuse_port(flag)?; + } if let Some(flag) = config.nodelay { socket.set_nodelay(flag)?; } @@ -53,6 +70,35 @@ impl TcpSocket { if let Some(timeout) = config.write_timeout { socket.set_write_timeout(Some(timeout))?; } + if let Some(size) = config.recv_buffer_size { + socket.set_recv_buffer_size(size)?; + } + if let Some(size) = config.send_buffer_size { + socket.set_send_buffer_size(size)?; + } + if let Some(tos) = config.tos { + socket.set_tos(tos)?; + } + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos" + ))] + if let Some(tclass) = config.tclass_v6 { + socket.set_tclass_v6(tclass)?; + } + if let Some(only_v6) = config.only_v6 { + socket.set_only_v6(only_v6)?; + } // Linux: optional interface name #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))] @@ -269,11 +315,59 @@ impl TcpSocket { self.socket.set_reuse_address(on) } + /// Get the socket address reuse option. + pub fn reuseaddr(&self) -> io::Result { + self.socket.reuse_address() + } + + /// Set the socket port reuse option where supported. + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos" + ))] + pub fn set_reuseport(&self, on: bool) -> io::Result<()> { + self.socket.set_reuse_port(on) + } + + /// Get the socket port reuse option where supported. + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos" + ))] + pub fn reuseport(&self) -> io::Result { + self.socket.reuse_port() + } + /// Set the socket to not delay packets. pub fn set_nodelay(&self, on: bool) -> io::Result<()> { self.socket.set_nodelay(on) } + /// Get the no delay option. + pub fn nodelay(&self) -> io::Result { + self.socket.nodelay() + } + /// Set the linger option for the socket. pub fn set_linger(&self, dur: Option) -> io::Result<()> { self.socket.set_linger(dur) @@ -284,16 +378,109 @@ impl TcpSocket { self.socket.set_ttl(ttl) } + /// Get the time-to-live for IPv4 packets. + pub fn ttl(&self) -> io::Result { + self.socket.ttl() + } + /// Set the hop limit for IPv6 packets. pub fn set_hoplimit(&self, hops: u32) -> io::Result<()> { self.socket.set_unicast_hops_v6(hops) } + /// Get the hop limit for IPv6 packets. + pub fn hoplimit(&self) -> io::Result { + self.socket.unicast_hops_v6() + } + /// Set the keepalive option for the socket. pub fn set_keepalive(&self, on: bool) -> io::Result<()> { self.socket.set_keepalive(on) } + /// Get the keepalive option. + pub fn keepalive(&self) -> io::Result { + self.socket.keepalive() + } + + /// Set the receive buffer size. + pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> { + self.socket.set_recv_buffer_size(size) + } + + /// Get the receive buffer size. + pub fn recv_buffer_size(&self) -> io::Result { + self.socket.recv_buffer_size() + } + + /// Set the send buffer size. + pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> { + self.socket.set_send_buffer_size(size) + } + + /// Get the send buffer size. + pub fn send_buffer_size(&self) -> io::Result { + self.socket.send_buffer_size() + } + + /// Set IPv4 TOS / DSCP. + pub fn set_tos(&self, tos: u32) -> io::Result<()> { + self.socket.set_tos(tos) + } + + /// Get IPv4 TOS / DSCP. + pub fn tos(&self) -> io::Result { + self.socket.tos() + } + + /// Set IPv6 traffic class where supported. + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos" + ))] + pub fn set_tclass_v6(&self, tclass: u32) -> io::Result<()> { + self.socket.set_tclass_v6(tclass) + } + + /// Get IPv6 traffic class where supported. + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos" + ))] + pub fn tclass_v6(&self) -> io::Result { + self.socket.tclass_v6() + } + + /// Set whether this socket is IPv6 only. + pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> { + self.socket.set_only_v6(only_v6) + } + + /// Get whether this socket is IPv6 only. + pub fn only_v6(&self) -> io::Result { + self.socket.only_v6() + } + /// Set the bind device for the socket (Linux specific). pub fn set_bind_device(&self, iface: &str) -> io::Result<()> { #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))] @@ -304,7 +491,7 @@ impl TcpSocket { let _ = iface; Err(io::Error::new( io::ErrorKind::Unsupported, - "bind_device not supported on this OS", + "bind_device is not supported on this platform", )) } } @@ -314,7 +501,7 @@ impl TcpSocket { self.socket .local_addr()? .as_socket() - .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Failed to retrieve local address")) + .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "failed to retrieve local address")) } /// Extract the RAW file descriptor for Unix. @@ -330,4 +517,19 @@ impl TcpSocket { use std::os::windows::io::AsRawSocket; self.socket.as_raw_socket() } + + /// Construct from a raw `socket2::Socket`. + pub fn from_socket(socket: Socket) -> Self { + Self { socket } + } + + /// Borrow the inner `socket2::Socket`. + pub fn socket(&self) -> &Socket { + &self.socket + } + + /// Consume and return the inner `socket2::Socket`. + pub fn into_socket(self) -> Socket { + self.socket + } } diff --git a/nex-socket/src/udp/async_impl.rs b/nex-socket/src/udp/async_impl.rs index 5db2226..a889786 100644 --- a/nex-socket/src/udp/async_impl.rs +++ b/nex-socket/src/udp/async_impl.rs @@ -25,6 +25,23 @@ impl AsyncUdpSocket { if let Some(flag) = config.reuseaddr { socket.set_reuse_address(flag)?; } + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos" + ))] + if let Some(flag) = config.reuseport { + socket.set_reuse_port(flag)?; + } if let Some(flag) = config.broadcast { socket.set_broadcast(flag)?; } @@ -40,6 +57,38 @@ impl AsyncUdpSocket { if let Some(timeout) = config.write_timeout { socket.set_write_timeout(Some(timeout))?; } + if let Some(size) = config.recv_buffer_size { + socket.set_recv_buffer_size(size)?; + } + if let Some(size) = config.send_buffer_size { + socket.set_send_buffer_size(size)?; + } + if let Some(tos) = config.tos { + socket.set_tos(tos)?; + } + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos" + ))] + if let Some(tclass) = config.tclass_v6 { + socket.set_tclass_v6(tclass)?; + } + if let Some(only_v6) = config.only_v6 { + socket.set_only_v6(only_v6)?; + } + if let Some(on) = config.recv_pktinfo { + crate::udp::set_recv_pktinfo(&socket, config.socket_family, on)?; + } // Linux: optional interface name #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))] @@ -128,6 +177,38 @@ impl AsyncUdpSocket { Ok(self.inner) } + /// Construct from a standard UDP socket. + pub fn from_std_socket(socket: StdUdpSocket) -> io::Result { + Ok(Self { + inner: UdpSocket::from_std(socket)?, + }) + } + + /// Convert into a standard UDP socket. + pub fn into_std_socket(self) -> io::Result { + self.inner.into_std() + } + + /// Set IPv4 time-to-live. + pub fn set_ttl(&self, ttl: u32) -> io::Result<()> { + self.inner.set_ttl(ttl) + } + + /// Get IPv4 time-to-live. + pub fn ttl(&self) -> io::Result { + self.inner.ttl() + } + + /// Set broadcast mode. + pub fn set_broadcast(&self, on: bool) -> io::Result<()> { + self.inner.set_broadcast(on) + } + + /// Get broadcast mode. + pub fn broadcast(&self) -> io::Result { + self.inner.broadcast() + } + #[cfg(unix)] pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd { use std::os::fd::AsRawFd; diff --git a/nex-socket/src/udp/config.rs b/nex-socket/src/udp/config.rs index f1e4146..e369c3d 100644 --- a/nex-socket/src/udp/config.rs +++ b/nex-socket/src/udp/config.rs @@ -42,6 +42,8 @@ pub struct UdpConfig { pub bind_addr: Option, /// Enable address reuse (`SO_REUSEADDR`). pub reuseaddr: Option, + /// Whether to allow port reuse (`SO_REUSEPORT`) where supported. + pub reuseport: Option, /// Allow broadcast (`SO_BROADCAST`). pub broadcast: Option, /// Time to live value. @@ -52,6 +54,18 @@ pub struct UdpConfig { pub read_timeout: Option, /// Write timeout for the socket. pub write_timeout: Option, + /// Optional receive buffer size in bytes. + pub recv_buffer_size: Option, + /// Optional send buffer size in bytes. + pub send_buffer_size: Option, + /// Optional IPv4 TOS / DSCP field value. + pub tos: Option, + /// Optional IPv6 traffic class value (`IPV6_TCLASS`) where supported. + pub tclass_v6: Option, + /// Enable receiving packet info ancillary data (`IP_PKTINFO` / `IPV6_RECVPKTINFO`) where supported. + pub recv_pktinfo: Option, + /// Whether to force IPv6-only behavior on dual-stack sockets. + pub only_v6: Option, /// Bind to a specific interface (Linux only). pub bind_device: Option, } @@ -63,11 +77,18 @@ impl Default for UdpConfig { socket_type: UdpSocketType::Dgram, bind_addr: None, reuseaddr: None, + reuseport: None, broadcast: None, ttl: None, hoplimit: None, read_timeout: None, write_timeout: None, + recv_buffer_size: None, + send_buffer_size: None, + tos: None, + tclass_v6: None, + recv_pktinfo: None, + only_v6: None, bind_device: None, } } @@ -79,18 +100,43 @@ impl UdpConfig { Self::default() } + /// Create a new UDP configuration for a specific socket family. + pub fn new_with_family(socket_family: SocketFamily) -> Self { + Self { + socket_family, + ..Self::default() + } + } + + /// Set the socket family. + pub fn with_socket_family(mut self, socket_family: SocketFamily) -> Self { + self.socket_family = socket_family; + self + } + /// Set the bind address. pub fn with_bind_addr(mut self, addr: SocketAddr) -> Self { self.bind_addr = Some(addr); self } + /// Set the bind address. + pub fn with_bind(self, addr: SocketAddr) -> Self { + self.with_bind_addr(addr) + } + /// Enable address reuse. pub fn with_reuseaddr(mut self, on: bool) -> Self { self.reuseaddr = Some(on); self } + /// Enable port reuse. + pub fn with_reuseport(mut self, on: bool) -> Self { + self.reuseport = Some(on); + self + } + /// Allow broadcast. pub fn with_broadcast(mut self, on: bool) -> Self { self.broadcast = Some(on); @@ -109,6 +155,11 @@ impl UdpConfig { self } + /// Set the hop limit value. + pub fn with_hop_limit(self, hops: u32) -> Self { + self.with_hoplimit(hops) + } + /// Set the read timeout. pub fn with_read_timeout(mut self, timeout: Duration) -> Self { self.read_timeout = Some(timeout); @@ -121,6 +172,42 @@ impl UdpConfig { self } + /// Set the receive buffer size. + pub fn with_recv_buffer_size(mut self, size: usize) -> Self { + self.recv_buffer_size = Some(size); + self + } + + /// Set the send buffer size. + pub fn with_send_buffer_size(mut self, size: usize) -> Self { + self.send_buffer_size = Some(size); + self + } + + /// Set the IPv4 TOS / DSCP field value. + pub fn with_tos(mut self, tos: u32) -> Self { + self.tos = Some(tos); + self + } + + /// Set the IPv6 traffic class value. + pub fn with_tclass_v6(mut self, tclass: u32) -> Self { + self.tclass_v6 = Some(tclass); + self + } + + /// Enable packet-info ancillary data receiving. + pub fn with_recv_pktinfo(mut self, on: bool) -> Self { + self.recv_pktinfo = Some(on); + self + } + + /// Set whether the socket is IPv6 only. + pub fn with_only_v6(mut self, only_v6: bool) -> Self { + self.only_v6 = Some(only_v6); + self + } + /// Bind to a specific interface (Linux only). pub fn with_bind_device(mut self, iface: impl Into) -> Self { self.bind_device = Some(iface.into()); @@ -137,8 +224,23 @@ mod tests { let cfg = UdpConfig::default(); assert!(cfg.bind_addr.is_none()); assert!(cfg.reuseaddr.is_none()); + assert!(cfg.reuseport.is_none()); assert!(cfg.broadcast.is_none()); assert!(cfg.ttl.is_none()); + assert!(cfg.recv_buffer_size.is_none()); + assert!(cfg.send_buffer_size.is_none()); + assert!(cfg.tos.is_none()); + assert!(cfg.tclass_v6.is_none()); + assert!(cfg.recv_pktinfo.is_none()); + assert!(cfg.only_v6.is_none()); assert!(cfg.bind_device.is_none()); } + + #[test] + fn udp_config_with_family_builder() { + let cfg = + UdpConfig::new_with_family(SocketFamily::IPV6).with_bind("[::1]:0".parse().unwrap()); + assert_eq!(cfg.socket_family, SocketFamily::IPV6); + assert!(cfg.bind_addr.is_some()); + } } diff --git a/nex-socket/src/udp/mod.rs b/nex-socket/src/udp/mod.rs index ff56fc6..81c9a56 100644 --- a/nex-socket/src/udp/mod.rs +++ b/nex-socket/src/udp/mod.rs @@ -6,6 +6,117 @@ mod async_impl; mod config; mod sync_impl; +use std::io; + +use socket2::Socket; + +use crate::SocketFamily; + +#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] +fn set_bool_sockopt( + socket: &Socket, + level: libc::c_int, + optname: libc::c_int, + on: bool, +) -> io::Result<()> { + use std::os::fd::AsRawFd; + let value: libc::c_int = if on { 1 } else { 0 }; + let ret = unsafe { + libc::setsockopt( + socket.as_raw_fd(), + level, + optname, + (&value as *const libc::c_int).cast(), + std::mem::size_of::() as libc::socklen_t, + ) + }; + if ret == 0 { + Ok(()) + } else { + Err(io::Error::last_os_error()) + } +} + +#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] +fn get_bool_sockopt(socket: &Socket, level: libc::c_int, optname: libc::c_int) -> io::Result { + use std::os::fd::AsRawFd; + let mut value: libc::c_int = 0; + let mut len = std::mem::size_of::() as libc::socklen_t; + let ret = unsafe { + libc::getsockopt( + socket.as_raw_fd(), + level, + optname, + (&mut value as *mut libc::c_int).cast(), + &mut len, + ) + }; + if ret == 0 { + Ok(value != 0) + } else { + Err(io::Error::last_os_error()) + } +} + +pub(crate) fn set_recv_pktinfo(socket: &Socket, family: SocketFamily, on: bool) -> io::Result<()> { + match family { + SocketFamily::IPV4 => set_recv_pktinfo_v4(socket, on), + SocketFamily::IPV6 => set_recv_pktinfo_v6(socket, on), + } +} + +#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] +pub(crate) fn set_recv_pktinfo_v4(socket: &Socket, on: bool) -> io::Result<()> { + set_bool_sockopt(socket, libc::IPPROTO_IP, libc::IP_PKTINFO, on) +} + +#[cfg(not(any(target_os = "android", target_os = "fuchsia", target_os = "linux")))] +pub(crate) fn set_recv_pktinfo_v4(_socket: &Socket, _on: bool) -> io::Result<()> { + Err(io::Error::new( + io::ErrorKind::Unsupported, + "IP_PKTINFO is not supported on this platform", + )) +} + +#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] +pub(crate) fn set_recv_pktinfo_v6(socket: &Socket, on: bool) -> io::Result<()> { + set_bool_sockopt(socket, libc::IPPROTO_IPV6, libc::IPV6_RECVPKTINFO, on) +} + +#[cfg(not(any(target_os = "android", target_os = "fuchsia", target_os = "linux")))] +pub(crate) fn set_recv_pktinfo_v6(_socket: &Socket, _on: bool) -> io::Result<()> { + Err(io::Error::new( + io::ErrorKind::Unsupported, + "IPV6_RECVPKTINFO is not supported on this platform", + )) +} + +#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] +pub(crate) fn recv_pktinfo_v4(socket: &Socket) -> io::Result { + get_bool_sockopt(socket, libc::IPPROTO_IP, libc::IP_PKTINFO) +} + +#[cfg(not(any(target_os = "android", target_os = "fuchsia", target_os = "linux")))] +pub(crate) fn recv_pktinfo_v4(_socket: &Socket) -> io::Result { + Err(io::Error::new( + io::ErrorKind::Unsupported, + "IP_PKTINFO is not supported on this platform", + )) +} + +#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] +pub(crate) fn recv_pktinfo_v6(socket: &Socket) -> io::Result { + get_bool_sockopt(socket, libc::IPPROTO_IPV6, libc::IPV6_RECVPKTINFO) +} + +#[cfg(not(any(target_os = "android", target_os = "fuchsia", target_os = "linux")))] +pub(crate) fn recv_pktinfo_v6(_socket: &Socket) -> io::Result { + Err(io::Error::new( + io::ErrorKind::Unsupported, + "IPV6_RECVPKTINFO is not supported on this platform", + )) +} + pub use async_impl::*; pub use config::*; pub use sync_impl::*; diff --git a/nex-socket/src/udp/sync_impl.rs b/nex-socket/src/udp/sync_impl.rs index 2519c46..79ab3b0 100644 --- a/nex-socket/src/udp/sync_impl.rs +++ b/nex-socket/src/udp/sync_impl.rs @@ -1,6 +1,7 @@ use crate::udp::UdpConfig; use socket2::{Domain, Protocol, Socket, Type as SockType}; use std::io; +use std::net::IpAddr; use std::net::{SocketAddr, UdpSocket as StdUdpSocket}; /// Synchronous low level UDP socket. @@ -9,6 +10,28 @@ pub struct UdpSocket { socket: Socket, } +/// Metadata returned from `recv_msg`. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct UdpRecvMeta { + /// Number of bytes received into the data buffer. + pub bytes_read: usize, + /// Source address of the datagram. + pub source_addr: SocketAddr, + /// Destination address that received the datagram, if provided by ancillary data. + pub destination_addr: Option, + /// Interface index on which the datagram was received, if provided. + pub interface_index: Option, +} + +/// Optional metadata used by `send_msg`. +#[derive(Clone, Debug, Default, Eq, PartialEq)] +pub struct UdpSendMeta { + /// Explicit source IP address to use for transmission when supported. + pub source_addr: Option, + /// Explicit outgoing interface index when supported. + pub interface_index: Option, +} + impl UdpSocket { /// Create a socket from the provided configuration. pub fn from_config(config: &UdpConfig) -> io::Result { @@ -24,6 +47,23 @@ impl UdpSocket { if let Some(flag) = config.reuseaddr { socket.set_reuse_address(flag)?; } + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos" + ))] + if let Some(flag) = config.reuseport { + socket.set_reuse_port(flag)?; + } if let Some(flag) = config.broadcast { socket.set_broadcast(flag)?; } @@ -39,6 +79,38 @@ impl UdpSocket { if let Some(timeout) = config.write_timeout { socket.set_write_timeout(Some(timeout))?; } + if let Some(size) = config.recv_buffer_size { + socket.set_recv_buffer_size(size)?; + } + if let Some(size) = config.send_buffer_size { + socket.set_send_buffer_size(size)?; + } + if let Some(tos) = config.tos { + socket.set_tos(tos)?; + } + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos" + ))] + if let Some(tclass) = config.tclass_v6 { + socket.set_tclass_v6(tclass)?; + } + if let Some(only_v6) = config.only_v6 { + socket.set_only_v6(only_v6)?; + } + if let Some(on) = config.recv_pktinfo { + crate::udp::set_recv_pktinfo(&socket, config.socket_family, on)?; + } // Linux: optional interface name #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))] @@ -86,6 +158,157 @@ impl UdpSocket { self.socket.send_to(buf, &target.into()) } + /// Send data with ancillary metadata (`sendmsg` on Unix). + /// + /// When supported by the current OS, source address and interface index are + /// propagated using packet-info control messages. + #[cfg(unix)] + pub fn send_msg( + &self, + buf: &[u8], + target: SocketAddr, + meta: Option<&UdpSendMeta>, + ) -> io::Result { + use nix::sys::socket::{ControlMessage, MsgFlags, SockaddrIn, SockaddrIn6, sendmsg}; + use std::io::IoSlice; + use std::os::fd::AsRawFd; + + let iov = [IoSlice::new(buf)]; + let raw_fd = self.socket.as_raw_fd(); + + match target { + SocketAddr::V4(addr) => { + let sockaddr = SockaddrIn::from(addr); + #[cfg(any( + target_os = "android", + target_os = "linux", + target_os = "netbsd", + target_vendor = "apple" + ))] + { + if let Some(meta) = meta { + if meta.source_addr.is_some() || meta.interface_index.is_some() { + if let Some(src) = meta.source_addr { + if !src.is_ipv4() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "source_addr family does not match target", + )); + } + } + let mut pktinfo: libc::in_pktinfo = unsafe { std::mem::zeroed() }; + if let Some(src) = meta.source_addr.and_then(|ip| match ip { + IpAddr::V4(v4) => Some(v4), + IpAddr::V6(_) => None, + }) { + pktinfo.ipi_spec_dst.s_addr = u32::from_ne_bytes(src.octets()); + } + if let Some(ifindex) = meta.interface_index { + pktinfo.ipi_ifindex = ifindex.try_into().map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidInput, + "interface_index is out of range for this platform", + ) + })?; + } + let cmsgs = [ControlMessage::Ipv4PacketInfo(&pktinfo)]; + return sendmsg( + raw_fd, + &iov, + &cmsgs, + MsgFlags::empty(), + Some(&sockaddr), + ) + .map_err(|e| io::Error::from_raw_os_error(e as i32)); + } + } + } + if let Some(meta) = meta { + if meta.source_addr.is_some() || meta.interface_index.is_some() { + return Err(io::Error::new( + io::ErrorKind::Unsupported, + "send_msg packet-info metadata is not supported on this platform", + )); + } + } + sendmsg(raw_fd, &iov, &[], MsgFlags::empty(), Some(&sockaddr)) + .map_err(|e| io::Error::from_raw_os_error(e as i32)) + } + SocketAddr::V6(addr) => { + let sockaddr = SockaddrIn6::from(addr); + #[cfg(any( + target_os = "android", + target_os = "freebsd", + target_os = "linux", + target_os = "netbsd", + target_vendor = "apple" + ))] + { + if let Some(meta) = meta { + if meta.source_addr.is_some() || meta.interface_index.is_some() { + if let Some(src) = meta.source_addr { + if !src.is_ipv6() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "source_addr family does not match target", + )); + } + } + let mut pktinfo: libc::in6_pktinfo = unsafe { std::mem::zeroed() }; + if let Some(src) = meta.source_addr.and_then(|ip| match ip { + IpAddr::V4(_) => None, + IpAddr::V6(v6) => Some(v6), + }) { + pktinfo.ipi6_addr.s6_addr = src.octets(); + } + if let Some(ifindex) = meta.interface_index { + pktinfo.ipi6_ifindex = ifindex.try_into().map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidInput, + "interface_index is out of range for this platform", + ) + })?; + } + let cmsgs = [ControlMessage::Ipv6PacketInfo(&pktinfo)]; + return sendmsg( + raw_fd, + &iov, + &cmsgs, + MsgFlags::empty(), + Some(&sockaddr), + ) + .map_err(|e| io::Error::from_raw_os_error(e as i32)); + } + } + } + if let Some(meta) = meta { + if meta.source_addr.is_some() || meta.interface_index.is_some() { + return Err(io::Error::new( + io::ErrorKind::Unsupported, + "send_msg packet-info metadata is not supported on this platform", + )); + } + } + sendmsg(raw_fd, &iov, &[], MsgFlags::empty(), Some(&sockaddr)) + .map_err(|e| io::Error::from_raw_os_error(e as i32)) + } + } + } + + /// Send data with ancillary metadata (`sendmsg` is not available on this platform build). + #[cfg(not(unix))] + pub fn send_msg( + &self, + _buf: &[u8], + _target: SocketAddr, + _meta: Option<&UdpSendMeta>, + ) -> io::Result { + Err(io::Error::new( + io::ErrorKind::Unsupported, + "send_msg is only supported on Unix", + )) + } + /// Receive data. pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { // Safety: `MaybeUninit` has the same layout as `u8`. @@ -104,24 +327,307 @@ impl UdpSocket { Ok((n, addr)) } + /// Receive data with ancillary metadata (`recvmsg` on Unix). + /// + /// This allows extracting packet-info control messages such as destination + /// address and incoming interface index when enabled with + /// `set_recv_pktinfo_v4` / `set_recv_pktinfo_v6`. + #[cfg(unix)] + pub fn recv_msg(&self, buf: &mut [u8]) -> io::Result { + use nix::sys::socket::{ControlMessageOwned, MsgFlags, SockaddrStorage, recvmsg}; + use std::io::IoSliceMut; + use std::os::fd::AsRawFd; + + let mut iov = [IoSliceMut::new(buf)]; + #[cfg(any( + target_os = "android", + target_os = "fuchsia", + target_os = "linux", + target_vendor = "apple", + target_os = "netbsd" + ))] + let mut cmsgspace = nix::cmsg_space!(libc::in_pktinfo, libc::in6_pktinfo); + #[cfg(all( + not(any( + target_os = "android", + target_os = "fuchsia", + target_os = "linux", + target_vendor = "apple", + target_os = "netbsd" + )), + any(target_os = "freebsd", target_os = "openbsd") + ))] + let mut cmsgspace = nix::cmsg_space!(libc::in6_pktinfo); + #[cfg(all( + not(any( + target_os = "android", + target_os = "fuchsia", + target_os = "linux", + target_vendor = "apple", + target_os = "netbsd" + )), + not(any(target_os = "freebsd", target_os = "openbsd")) + ))] + let mut cmsgspace = nix::cmsg_space!(libc::c_int); + let msg = recvmsg::( + self.socket.as_raw_fd(), + &mut iov, + Some(&mut cmsgspace), + MsgFlags::empty(), + ) + .map_err(|e| io::Error::from_raw_os_error(e as i32))?; + + let source_addr = msg + .address + .and_then(|addr: SockaddrStorage| { + if let Some(v4) = addr.as_sockaddr_in() { + return Some(SocketAddr::from(*v4)); + } + if let Some(v6) = addr.as_sockaddr_in6() { + return Some(SocketAddr::from(*v6)); + } + None + }) + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid source address"))?; + + let mut destination_addr = None; + let mut interface_index = None; + + if let Ok(cmsgs) = msg.cmsgs() { + for cmsg in cmsgs { + match cmsg { + #[cfg(any( + target_os = "android", + target_os = "fuchsia", + target_os = "linux", + target_vendor = "apple", + target_os = "netbsd" + ))] + ControlMessageOwned::Ipv4PacketInfo(info) => { + destination_addr = Some(IpAddr::V4(std::net::Ipv4Addr::from( + info.ipi_addr.s_addr.to_ne_bytes(), + ))); + interface_index = Some(info.ipi_ifindex.try_into().map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidData, + "received invalid interface index", + ) + })?); + } + #[cfg(any( + target_os = "android", + target_os = "freebsd", + target_os = "linux", + target_os = "macos", + target_os = "ios", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos", + target_os = "netbsd", + target_os = "openbsd" + ))] + ControlMessageOwned::Ipv6PacketInfo(info) => { + destination_addr = + Some(IpAddr::V6(std::net::Ipv6Addr::from(info.ipi6_addr.s6_addr))); + interface_index = Some(info.ipi6_ifindex.try_into().map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidData, + "received invalid interface index", + ) + })?); + } + _ => {} + } + } + } + + Ok(UdpRecvMeta { + bytes_read: msg.bytes, + source_addr, + destination_addr, + interface_index, + }) + } + + /// Receive data with ancillary metadata (`recvmsg` is not available on this platform build). + #[cfg(not(unix))] + pub fn recv_msg(&self, _buf: &mut [u8]) -> io::Result { + Err(io::Error::new( + io::ErrorKind::Unsupported, + "recv_msg is only supported on Unix", + )) + } + pub fn set_ttl(&self, ttl: u32) -> io::Result<()> { self.socket.set_ttl(ttl) } + pub fn ttl(&self) -> io::Result { + self.socket.ttl() + } + pub fn set_hoplimit(&self, hops: u32) -> io::Result<()> { self.socket.set_unicast_hops_v6(hops) } + pub fn hoplimit(&self) -> io::Result { + self.socket.unicast_hops_v6() + } + + pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> { + self.socket.set_reuse_address(on) + } + + pub fn reuseaddr(&self) -> io::Result { + self.socket.reuse_address() + } + + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos" + ))] + pub fn set_reuseport(&self, on: bool) -> io::Result<()> { + self.socket.set_reuse_port(on) + } + + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos" + ))] + pub fn reuseport(&self) -> io::Result { + self.socket.reuse_port() + } + + pub fn set_broadcast(&self, on: bool) -> io::Result<()> { + self.socket.set_broadcast(on) + } + + pub fn broadcast(&self) -> io::Result { + self.socket.broadcast() + } + + pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> { + self.socket.set_recv_buffer_size(size) + } + + pub fn recv_buffer_size(&self) -> io::Result { + self.socket.recv_buffer_size() + } + + pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> { + self.socket.set_send_buffer_size(size) + } + + pub fn send_buffer_size(&self) -> io::Result { + self.socket.send_buffer_size() + } + + pub fn set_tos(&self, tos: u32) -> io::Result<()> { + self.socket.set_tos(tos) + } + + pub fn tos(&self) -> io::Result { + self.socket.tos() + } + + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos" + ))] + pub fn set_tclass_v6(&self, tclass: u32) -> io::Result<()> { + self.socket.set_tclass_v6(tclass) + } + + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos" + ))] + pub fn tclass_v6(&self) -> io::Result { + self.socket.tclass_v6() + } + + pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> { + self.socket.set_only_v6(only_v6) + } + + pub fn only_v6(&self) -> io::Result { + self.socket.only_v6() + } + pub fn set_keepalive(&self, on: bool) -> io::Result<()> { self.socket.set_keepalive(on) } + pub fn keepalive(&self) -> io::Result { + self.socket.keepalive() + } + + /// Enable IPv4 packet-info ancillary data receiving (`IP_PKTINFO`) where supported. + pub fn set_recv_pktinfo_v4(&self, on: bool) -> io::Result<()> { + crate::udp::set_recv_pktinfo_v4(&self.socket, on) + } + + /// Enable IPv6 packet-info ancillary data receiving (`IPV6_RECVPKTINFO`) where supported. + pub fn set_recv_pktinfo_v6(&self, on: bool) -> io::Result<()> { + crate::udp::set_recv_pktinfo_v6(&self.socket, on) + } + + /// Query whether IPv4 packet-info ancillary data receiving is enabled. + pub fn recv_pktinfo_v4(&self) -> io::Result { + crate::udp::recv_pktinfo_v4(&self.socket) + } + + /// Query whether IPv6 packet-info ancillary data receiving is enabled. + pub fn recv_pktinfo_v6(&self) -> io::Result { + crate::udp::recv_pktinfo_v6(&self.socket) + } + /// Retrieve the local socket address. pub fn local_addr(&self) -> io::Result { self.socket .local_addr()? .as_socket() - .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Failed to get socket address")) + .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "failed to retrieve local address")) } /// Convert into a raw `std::net::UdpSocket`. @@ -129,6 +635,21 @@ impl UdpSocket { Ok(self.socket.into()) } + /// Construct from a raw `socket2::Socket`. + pub fn from_socket(socket: Socket) -> Self { + Self { socket } + } + + /// Borrow the inner `socket2::Socket`. + pub fn socket(&self) -> &Socket { + &self.socket + } + + /// Consume and return the inner `socket2::Socket`. + pub fn into_socket(self) -> Socket { + self.socket + } + #[cfg(unix)] pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd { use std::os::fd::AsRawFd; diff --git a/nex-sys/src/unix.rs b/nex-sys/src/unix.rs index 4aaab23..686a410 100644 --- a/nex-sys/src/unix.rs +++ b/nex-sys/src/unix.rs @@ -146,7 +146,7 @@ where } fn errno() -> i32 { - io::Error::last_os_error().raw_os_error().unwrap() + io::Error::last_os_error().raw_os_error().unwrap_or(0) } #[cfg(test)] diff --git a/nex-sys/src/windows.rs b/nex-sys/src/windows.rs index 069e610..bd5540e 100644 --- a/nex-sys/src/windows.rs +++ b/nex-sys/src/windows.rs @@ -64,5 +64,5 @@ where } fn errno() -> i32 { - std::io::Error::last_os_error().raw_os_error().unwrap() + std::io::Error::last_os_error().raw_os_error().unwrap_or(0) }