diff --git a/nex-socket/src/icmp/async_impl.rs b/nex-socket/src/icmp/async_impl.rs index 2c99567..e56bf06 100644 --- a/nex-socket/src/icmp/async_impl.rs +++ b/nex-socket/src/icmp/async_impl.rs @@ -1,4 +1,5 @@ -use crate::icmp::{IcmpConfig, IcmpKind}; +use crate::icmp::{IcmpConfig, IcmpKind, IcmpSocketType}; +use crate::SocketFamily; use socket2::{Domain, Protocol, Socket, Type as SockType}; use std::io; use std::net::{SocketAddr, UdpSocket as StdUdpSocket}; @@ -8,23 +9,23 @@ use tokio::net::UdpSocket; #[derive(Debug)] pub struct AsyncIcmpSocket { inner: UdpSocket, - sock_type: SockType, - kind: IcmpKind, + socket_type: IcmpSocketType, + socket_family: SocketFamily, } impl AsyncIcmpSocket { /// Create a new asynchronous ICMP socket. pub async fn new(config: &IcmpConfig) -> io::Result { - let (domain, proto) = match config.kind { - IcmpKind::V4 => (Domain::IPV4, Some(Protocol::ICMPV4)), - IcmpKind::V6 => (Domain::IPV6, Some(Protocol::ICMPV6)), + let (domain, proto) = match config.socket_family { + SocketFamily::IPV4 => (Domain::IPV4, Some(Protocol::ICMPV4)), + SocketFamily::IPV6 => (Domain::IPV6, Some(Protocol::ICMPV6)), }; // Build the socket with DGRAM preferred and RAW as a fallback - let socket = match Socket::new(domain, config.sock_type_hint, proto) { + let socket = match Socket::new(domain, config.sock_type_hint.to_sock_type(), proto) { Ok(s) => s, Err(_) => { - let alt_type = if config.sock_type_hint == SockType::DGRAM { + let alt_type = if config.sock_type_hint.is_dgram() { SockType::RAW } else { SockType::DGRAM @@ -35,29 +36,36 @@ impl AsyncIcmpSocket { socket.set_nonblocking(true)?; - // bind - if let Some(addr) = &config.bind { - socket.bind(&(*addr).into())?; - } - - // Linux: optional interface name - #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))] - if let Some(interface) = &config.interface { - socket.bind_device(Some(interface.as_bytes()))?; - } - - // TTL + // Set socket options based on configuration if let Some(ttl) = config.ttl { socket.set_ttl(ttl)?; } - + if let Some(hoplimit) = config.hoplimit { + socket.set_unicast_hops_v6(hoplimit)?; + } + if let Some(timeout) = config.read_timeout { + socket.set_read_timeout(Some(timeout))?; + } + if let Some(timeout) = config.write_timeout { + socket.set_write_timeout(Some(timeout))?; + } // FreeBSD only: optional FIB support #[cfg(target_os = "freebsd")] if let Some(fib) = config.fib { socket.set_fib(fib)?; } + // Linux: optional interface name + #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))] + if let Some(interface) = &config.interface { + socket.bind_device(Some(interface.as_bytes()))?; + } - let socket_type = socket.r#type()?; + // bind to the specified address if provided + if let Some(addr) = &config.bind { + socket.bind(&(*addr).into())?; + } + + let sock_type = socket.r#type()?; // Convert socket2::Socket into std::net::UdpSocket #[cfg(windows)] @@ -73,13 +81,13 @@ impl AsyncIcmpSocket { StdUdpSocket::from_raw_fd(socket.into_raw_fd()) }; - // std → tokio::net::UdpSocket + // std -> tokio::net::UdpSocket let inner = UdpSocket::from_std(std_socket)?; Ok(Self { inner, - sock_type: socket_type, - kind: config.kind, + socket_type: IcmpSocketType::from_sock_type(sock_type), + socket_family: config.socket_family, }) } @@ -99,22 +107,31 @@ impl AsyncIcmpSocket { } /// Return the socket type (DGRAM or RAW). - pub fn sock_type(&self) -> SockType { - self.sock_type + pub fn socket_type(&self) -> IcmpSocketType { + self.socket_type } - /// Return the ICMP version. - pub fn kind(&self) -> IcmpKind { - self.kind + /// Return the socket family. + pub fn socket_family(&self) -> SocketFamily { + self.socket_family + } + + /// Return the ICMP kind. + pub fn icmp_kind(&self) -> IcmpKind { + match self.socket_family { + SocketFamily::IPV4 => IcmpKind::V4, + SocketFamily::IPV6 => IcmpKind::V6, + } } - /// Access the native socket for low level operations. + /// Extract the RAW file descriptor for Unix. #[cfg(unix)] pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd { use std::os::fd::AsRawFd; self.inner.as_raw_fd() } + /// Extract the RAW socket handle for Windows. #[cfg(windows)] pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket { use std::os::windows::io::AsRawSocket; diff --git a/nex-socket/src/icmp/config.rs b/nex-socket/src/icmp/config.rs index 9e7d873..3583f29 100644 --- a/nex-socket/src/icmp/config.rs +++ b/nex-socket/src/icmp/config.rs @@ -1,5 +1,7 @@ use socket2::Type as SockType; -use std::net::SocketAddr; +use std::{net::SocketAddr, time::Duration}; + +use crate::SocketFamily; /// ICMP protocol version. #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -8,49 +10,127 @@ pub enum IcmpKind { V6, } +/// ICMP socket type, either DGRAM or RAW. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum IcmpSocketType { + Dgram, + Raw, +} + +impl IcmpSocketType { + /// Returns true if the socket type is DGRAM. + pub fn is_dgram(&self) -> bool { + matches!(self, IcmpSocketType::Dgram) + } + + /// Returns true if the socket type is RAW. + pub fn is_raw(&self) -> bool { + matches!(self, IcmpSocketType::Raw) + } + + /// Converts the ICMP socket type from a `socket2::Type`. + pub(crate) fn from_sock_type(sock_type: SockType) -> Self { + match sock_type { + SockType::DGRAM => IcmpSocketType::Dgram, + SockType::RAW => IcmpSocketType::Raw, + _ => panic!("Invalid ICMP socket type"), + } + } + + /// Converts the ICMP socket type to a `socket2::Type`. + pub(crate) fn to_sock_type(&self) -> SockType { + match self { + IcmpSocketType::Dgram => SockType::DGRAM, + IcmpSocketType::Raw => SockType::RAW, + } + } +} + /// Configuration for an ICMP socket. #[derive(Debug, Clone)] pub struct IcmpConfig { - pub kind: IcmpKind, + /// The socket family. + pub socket_family: SocketFamily, + /// Optional bind address for the socket. pub bind: Option, + /// Time-to-live for IPv4 packets. pub ttl: Option, + /// Hop limit for IPv6 packets. + pub hoplimit: Option, + /// Read timeout for the socket. + pub read_timeout: Option, + /// Write timeout for the socket. + pub write_timeout: Option, + /// Network interface to use for the socket. pub interface: Option, - pub sock_type_hint: SockType, + /// Socket type hint, DGRAM preferred on Linux, RAW fallback on macOS/Windows. + pub sock_type_hint: IcmpSocketType, + /// FreeBSD only: optional FIB (Forwarding Information Base) support. pub fib: Option, } impl IcmpConfig { + /// Creates a new ICMP configuration with the specified kind. pub fn new(kind: IcmpKind) -> Self { Self { - kind, + socket_family: match kind { + IcmpKind::V4 => SocketFamily::IPV4, + IcmpKind::V6 => SocketFamily::IPV6, + }, bind: None, ttl: None, + hoplimit: None, + read_timeout: None, + write_timeout: None, interface: None, - sock_type_hint: SockType::DGRAM, // DGRAM preferred on Linux, RAW fallback on macOS/Windows - fib: None, // FreeBSD only + sock_type_hint: IcmpSocketType::Dgram, + fib: None, } } + /// Set bind address for the socket. pub fn with_bind(mut self, addr: SocketAddr) -> Self { self.bind = Some(addr); self } + /// Set the time-to-live for IPv4 packets. pub fn with_ttl(mut self, ttl: u32) -> Self { self.ttl = Some(ttl); self } + /// Set the hop limit for IPv6 packets. + pub fn with_hoplimit(mut self, hops: u32) -> Self { + self.hoplimit = Some(hops); + self + } + + /// Set the read timeout for the socket. + pub fn with_read_timeout(mut self, timeout: Duration) -> Self { + self.read_timeout = Some(timeout); + self + } + + /// Set the write timeout for the socket. + pub fn with_write_timeout(mut self, timeout: Duration) -> Self { + self.write_timeout = Some(timeout); + self + } + + /// Set the network interface to use for the socket. pub fn with_interface(mut self, iface: impl Into) -> Self { self.interface = Some(iface.into()); self } - pub fn with_sock_type(mut self, ty: SockType) -> Self { + /// Set the socket type hint. (DGRAM or RAW) + pub fn with_sock_type(mut self, ty: IcmpSocketType) -> Self { self.sock_type_hint = ty; self } + /// Set the FIB (Forwarding Information Base) for FreeBSD. pub fn with_fib(mut self, fib: u32) -> Self { self.fib = Some(fib); self @@ -60,7 +140,6 @@ impl IcmpConfig { #[cfg(test)] mod tests { use super::*; - use socket2::Type; #[test] fn icmp_config_builders() { let addr: SocketAddr = "127.0.0.1:0".parse().unwrap(); @@ -68,11 +147,11 @@ mod tests { .with_bind(addr) .with_ttl(4) .with_interface("eth0") - .with_sock_type(Type::RAW); - assert_eq!(cfg.kind, IcmpKind::V4); + .with_sock_type(IcmpSocketType::Raw); + assert_eq!(cfg.socket_family, SocketFamily::IPV4); assert_eq!(cfg.bind, Some(addr)); assert_eq!(cfg.ttl, Some(4)); assert_eq!(cfg.interface.as_deref(), Some("eth0")); - assert_eq!(cfg.sock_type_hint, Type::RAW); + assert_eq!(cfg.sock_type_hint, IcmpSocketType::Raw); } } diff --git a/nex-socket/src/icmp/sync_impl.rs b/nex-socket/src/icmp/sync_impl.rs index a9c28ba..0c028c0 100644 --- a/nex-socket/src/icmp/sync_impl.rs +++ b/nex-socket/src/icmp/sync_impl.rs @@ -1,4 +1,5 @@ -use crate::icmp::{IcmpConfig, IcmpKind}; +use crate::icmp::{IcmpConfig, IcmpKind, IcmpSocketType}; +use crate::SocketFamily; use socket2::{Domain, Protocol, Socket, Type as SockType}; use std::io; use std::net::{SocketAddr, UdpSocket}; @@ -7,22 +8,22 @@ use std::net::{SocketAddr, UdpSocket}; #[derive(Debug)] pub struct IcmpSocket { inner: UdpSocket, - sock_type: SockType, - kind: IcmpKind, + socket_type: IcmpSocketType, + socket_family: SocketFamily, } impl IcmpSocket { /// Create a new synchronous ICMP socket. pub fn new(config: &IcmpConfig) -> io::Result { - let (domain, proto) = match config.kind { - IcmpKind::V4 => (Domain::IPV4, Some(Protocol::ICMPV4)), - IcmpKind::V6 => (Domain::IPV6, Some(Protocol::ICMPV6)), + let (domain, proto) = match config.socket_family { + SocketFamily::IPV4 => (Domain::IPV4, Some(Protocol::ICMPV4)), + SocketFamily::IPV6 => (Domain::IPV6, Some(Protocol::ICMPV6)), }; - let socket = match Socket::new(domain, config.sock_type_hint, proto) { + let socket = match Socket::new(domain, config.sock_type_hint.to_sock_type(), proto) { Ok(s) => s, Err(_) => { - let alt_type = if config.sock_type_hint == SockType::DGRAM { + let alt_type = if config.sock_type_hint.is_dgram() { SockType::RAW } else { SockType::DGRAM @@ -31,33 +32,46 @@ impl IcmpSocket { } }; - socket.set_nonblocking(false)?; // blocking mode for sync usage + socket.set_nonblocking(false)?; - if let Some(addr) = &config.bind { - socket.bind(&(*addr).into())?; + // Set socket options based on configuration + if let Some(ttl) = config.ttl { + socket.set_ttl(ttl)?; } - + if let Some(hoplimit) = config.hoplimit { + socket.set_unicast_hops_v6(hoplimit)?; + } + if let Some(timeout) = config.read_timeout { + socket.set_read_timeout(Some(timeout))?; + } + if let Some(timeout) = config.write_timeout { + socket.set_write_timeout(Some(timeout))?; + } + // FreeBSD only: optional FIB support + #[cfg(target_os = "freebsd")] + if let Some(fib) = config.fib { + socket.set_fib(fib)?; + } + // Linux: optional interface name #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))] if let Some(interface) = &config.interface { socket.bind_device(Some(interface.as_bytes()))?; } - if let Some(ttl) = config.ttl { - socket.set_ttl(ttl)?; + // bind to the specified address if provided + if let Some(addr) = &config.bind { + socket.bind(&(*addr).into())?; } - #[cfg(target_os = "freebsd")] - if let Some(fib) = config.fib { - socket.set_fib(fib)?; - } + let sock_type = socket.r#type()?; // Convert socket2::Socket into std::net::UdpSocket let std_socket: UdpSocket = socket.into(); Ok(Self { inner: std_socket, - sock_type: config.sock_type_hint, - kind: config.kind, + socket_type: IcmpSocketType::from_sock_type(sock_type), + socket_family: config.socket_family, }) } @@ -77,22 +91,31 @@ impl IcmpSocket { } /// Return the socket type. - pub fn sock_type(&self) -> SockType { - self.sock_type + pub fn socket_type(&self) -> IcmpSocketType { + self.socket_type + } + + /// Return the socket family. + pub fn socket_family(&self) -> SocketFamily { + self.socket_family } /// Return the ICMP variant. - pub fn kind(&self) -> IcmpKind { - self.kind + pub fn icmp_kind(&self) -> IcmpKind { + match self.socket_family { + SocketFamily::IPV4 => IcmpKind::V4, + SocketFamily::IPV6 => IcmpKind::V6, + } } - /// Access the underlying socket. + /// Extract the RAW file descriptor for Unix. #[cfg(unix)] pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd { use std::os::fd::AsRawFd; self.inner.as_raw_fd() } + /// Extract the RAW socket handle for Windows. #[cfg(windows)] pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket { use std::os::windows::io::AsRawSocket; diff --git a/nex-socket/src/lib.rs b/nex-socket/src/lib.rs index bad1362..4849a8b 100644 --- a/nex-socket/src/lib.rs +++ b/nex-socket/src/lib.rs @@ -7,3 +7,48 @@ pub mod icmp; pub mod tcp; pub mod udp; + +use std::net::{IpAddr, SocketAddr}; + +/// Represents the socket address family (IPv4 or IPv6) +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SocketFamily { + IPV4, + IPV6, +} + +impl SocketFamily { + /// Returns the socket family of the IP address. + pub fn from_ip(ip: &IpAddr) -> Self { + match ip { + IpAddr::V4(_) => SocketFamily::IPV4, + IpAddr::V6(_) => SocketFamily::IPV6, + } + } + + /// Returns the socket family of the socket address. + pub fn from_socket_addr(addr: &SocketAddr) -> Self { + match addr { + SocketAddr::V4(_) => SocketFamily::IPV4, + SocketAddr::V6(_) => SocketFamily::IPV6, + } + } + + /// Returns true if the socket family is IPv4. + pub fn is_v4(&self) -> bool { + matches!(self, SocketFamily::IPV4) + } + + /// Returns true if the socket family is IPv6. + pub fn is_v6(&self) -> bool { + matches!(self, SocketFamily::IPV6) + } + + /// Converts the socket family to a `socket2::Domain`. + pub(crate) fn to_domain(&self) -> socket2::Domain { + match self { + SocketFamily::IPV4 => socket2::Domain::IPV4, + SocketFamily::IPV6 => socket2::Domain::IPV6, + } + } +} diff --git a/nex-socket/src/tcp/async_impl.rs b/nex-socket/src/tcp/async_impl.rs index 42a073a..41c489a 100644 --- a/nex-socket/src/tcp/async_impl.rs +++ b/nex-socket/src/tcp/async_impl.rs @@ -14,8 +14,11 @@ pub struct AsyncTcpSocket { impl AsyncTcpSocket { /// Create a socket from the given configuration without connecting. pub fn from_config(config: &TcpConfig) -> io::Result { - let socket = Socket::new(config.domain, config.sock_type, Some(Protocol::TCP))?; + let socket = Socket::new(config.socket_family.to_domain(), config.socket_type.to_sock_type(), Some(Protocol::TCP))?; + socket.set_nonblocking(true)?; + + // Set socket options based on configuration if let Some(flag) = config.reuseaddr { socket.set_reuse_address(flag)?; } @@ -25,18 +28,30 @@ impl AsyncTcpSocket { if let Some(ttl) = config.ttl { socket.set_ttl(ttl)?; } + if let Some(hoplimit) = config.hoplimit { + socket.set_unicast_hops_v6(hoplimit)?; + } + if let Some(keepalive) = config.keepalive { + socket.set_keepalive(keepalive)?; + } + if let Some(timeout) = config.read_timeout { + socket.set_read_timeout(Some(timeout))?; + } + if let Some(timeout) = config.write_timeout { + socket.set_write_timeout(Some(timeout))?; + } + // Linux: optional interface name #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))] if let Some(iface) = &config.bind_device { socket.bind_device(Some(iface.as_bytes()))?; } + // bind to the specified address if provided if let Some(addr) = config.bind_addr { socket.bind(&addr.into())?; } - socket.set_nonblocking(true)?; - Ok(Self { socket }) } @@ -145,24 +160,42 @@ impl AsyncTcpSocket { Ok((n, addr)) } - // --- option helpers --- + /// Shutdown the socket. + pub fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> { + self.socket.shutdown(how) + } + /// Set reuse address option. pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> { self.socket.set_reuse_address(on) } + /// Set no delay option for TCP. pub fn set_nodelay(&self, on: bool) -> io::Result<()> { self.socket.set_nodelay(on) } + /// Set linger option for the socket. pub fn set_linger(&self, dur: Option) -> io::Result<()> { self.socket.set_linger(dur) } + /// Set the time-to-live for IPv4 packets. pub fn set_ttl(&self, ttl: u32) -> io::Result<()> { self.socket.set_ttl(ttl) } + /// Set the hop limit for IPv6 packets. + pub fn set_hoplimit(&self, hops: u32) -> io::Result<()> { + self.socket.set_unicast_hops_v6(hops) + } + + /// Set the keepalive option for the socket. + pub fn set_keepalive(&self, on: bool) -> io::Result<()> { + self.socket.set_keepalive(on) + } + + /// Set the bind device for the socket (Linux specific). pub fn set_bind_device(&self, iface: &str) -> io::Result<()> { #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))] return self.socket.bind_device(Some(iface.as_bytes())); @@ -191,12 +224,14 @@ impl AsyncTcpSocket { TcpStream::from_std(std_stream) } + /// Extract the RAW file descriptor for Unix. #[cfg(unix)] pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd { use std::os::fd::AsRawFd; self.socket.as_raw_fd() } + /// Extract the RAW socket handle for Windows. #[cfg(windows)] pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket { use std::os::windows::io::AsRawSocket; diff --git a/nex-socket/src/tcp/config.rs b/nex-socket/src/tcp/config.rs index 8946994..a836709 100644 --- a/nex-socket/src/tcp/config.rs +++ b/nex-socket/src/tcp/config.rs @@ -1,42 +1,92 @@ -use socket2::{Domain, Type as SockType}; +use socket2::Type as SockType; use std::net::SocketAddr; use std::time::Duration; +use crate::SocketFamily; + +/// TCP socket type, either STREAM or RAW. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TcpSocketType { + Stream, + Raw, +} + +impl TcpSocketType { + /// Returns true if the socket type is STREAM. + pub fn is_stream(&self) -> bool { + matches!(self, TcpSocketType::Stream) + } + + /// Returns true if the socket type is RAW. + pub fn is_raw(&self) -> bool { + matches!(self, TcpSocketType::Raw) + } + + /// Converts the TCP socket type to a `socket2::Type`. + pub(crate) fn to_sock_type(&self) -> SockType { + match self { + TcpSocketType::Stream => SockType::STREAM, + TcpSocketType::Raw => SockType::RAW, + } + } +} + /// Configuration options for a TCP socket. #[derive(Debug, Clone)] pub struct TcpConfig { - pub domain: Domain, - pub sock_type: SockType, + /// The socket family, either IPv4 or IPv6. + pub socket_family: SocketFamily, + /// The type of TCP socket, either STREAM or RAW. + pub socket_type: TcpSocketType, + /// Optional address to bind the socket to. pub bind_addr: Option, + /// Whether the socket should be non-blocking. pub nonblocking: bool, + /// Whether to allow address reuse. pub reuseaddr: Option, + /// Whether to disable Nagle's algorithm (TCP_NODELAY). pub nodelay: Option, + /// Optional linger duration for the socket. pub linger: Option, + /// Optional Time-To-Live (TTL) for the socket. pub ttl: Option, + /// Optional Hop Limit for the socket (IPv6). + pub hoplimit: Option, + /// Optional read timeout for the socket. + pub read_timeout: Option, + /// Optional write timeout for the socket. + pub write_timeout: Option, + /// Optional device to bind the socket to. pub bind_device: Option, + /// Whether to enable TCP keepalive. + pub keepalive: Option, } impl TcpConfig { /// Create a STREAM socket for IPv4. pub fn v4_stream() -> Self { Self { - domain: Domain::IPV4, - sock_type: SockType::STREAM, + socket_family: SocketFamily::IPV4, + socket_type: TcpSocketType::Stream, bind_addr: None, nonblocking: false, reuseaddr: None, nodelay: None, linger: None, ttl: None, + hoplimit: None, + read_timeout: None, + write_timeout: None, bind_device: None, + keepalive: None, } } /// Create a RAW socket. Requires administrator privileges. pub fn raw_v4() -> Self { Self { - domain: Domain::IPV4, - sock_type: SockType::RAW, + socket_family: SocketFamily::IPV4, + socket_type: TcpSocketType::Raw, ..Self::v4_stream() } } @@ -44,8 +94,8 @@ impl TcpConfig { /// Create a STREAM socket for IPv6. pub fn v6_stream() -> Self { Self { - domain: Domain::IPV6, - sock_type: SockType::STREAM, + socket_family: SocketFamily::IPV6, + socket_type: TcpSocketType::Stream, ..Self::v4_stream() } } @@ -53,8 +103,8 @@ impl TcpConfig { /// Create a RAW socket for IPv6. Requires administrator privileges. pub fn raw_v6() -> Self { Self { - domain: Domain::IPV6, - sock_type: SockType::RAW, + socket_family: SocketFamily::IPV6, + socket_type: TcpSocketType::Raw, ..Self::v4_stream() } } @@ -91,6 +141,26 @@ impl TcpConfig { self } + pub fn with_hoplimit(mut self, hops: u32) -> Self { + self.hoplimit = Some(hops); + self + } + + pub fn with_keepalive(mut self, on: bool) -> Self { + self.keepalive = Some(on); + self + } + + pub fn with_read_timeout(mut self, timeout: Duration) -> Self { + self.read_timeout = Some(timeout); + self + } + + pub fn with_write_timeout(mut self, timeout: Duration) -> Self { + self.write_timeout = Some(timeout); + self + } + pub fn with_bind_device(mut self, iface: impl Into) -> Self { self.bind_device = Some(iface.into()); self @@ -111,8 +181,8 @@ mod tests { .with_nodelay(true) .with_ttl(10); - assert_eq!(cfg.domain, Domain::IPV4); - assert_eq!(cfg.sock_type, SockType::STREAM); + assert_eq!(cfg.socket_family, SocketFamily::IPV4); + assert_eq!(cfg.socket_type, TcpSocketType::Stream); assert_eq!(cfg.bind_addr, Some(addr)); assert!(cfg.nonblocking); assert_eq!(cfg.reuseaddr, Some(true)); diff --git a/nex-socket/src/tcp/sync_impl.rs b/nex-socket/src/tcp/sync_impl.rs index 725c8b1..7a89263 100644 --- a/nex-socket/src/tcp/sync_impl.rs +++ b/nex-socket/src/tcp/sync_impl.rs @@ -20,9 +20,11 @@ pub struct TcpSocket { impl TcpSocket { /// Build a socket according to `TcpSocketConfig`. pub fn from_config(config: &TcpConfig) -> io::Result { - let socket = Socket::new(config.domain, config.sock_type, Some(Protocol::TCP))?; + let socket = Socket::new(config.socket_family.to_domain(), config.socket_type.to_sock_type(), Some(Protocol::TCP))?; - // Apply all configuration options + socket.set_nonblocking(config.nonblocking)?; + + // Set socket options based on configuration if let Some(flag) = config.reuseaddr { socket.set_reuse_address(flag)?; } @@ -35,20 +37,30 @@ impl TcpSocket { if let Some(ttl) = config.ttl { socket.set_ttl(ttl)?; } + if let Some(hoplimit) = config.hoplimit { + socket.set_unicast_hops_v6(hoplimit)?; + } + if let Some(keepalive) = config.keepalive { + socket.set_keepalive(keepalive)?; + } + if let Some(timeout) = config.read_timeout { + socket.set_read_timeout(Some(timeout))?; + } + if let Some(timeout) = config.write_timeout { + socket.set_write_timeout(Some(timeout))?; + } + // Linux: optional interface name #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))] if let Some(iface) = &config.bind_device { socket.bind_device(Some(iface.as_bytes()))?; } - // Bind to the specified address if provided + // bind to the specified address if provided if let Some(addr) = config.bind_addr { socket.bind(&addr.into())?; } - // Set non blocking mode - socket.set_nonblocking(config.nonblocking)?; - Ok(Self { socket }) } @@ -79,16 +91,17 @@ impl TcpSocket { Self::new(Domain::IPV6, SockType::RAW) } - // --- socket operations --- - + /// Bind the socket to a specific address. pub fn bind(&self, addr: SocketAddr) -> io::Result<()> { self.socket.bind(&addr.into()) } + /// Connect to a remote address. pub fn connect(&self, addr: SocketAddr) -> io::Result<()> { self.socket.connect(&addr.into()) } + /// Connect to the target address with a timeout. #[cfg(unix)] pub fn connect_timeout(&self, target: SocketAddr, timeout: Duration) -> io::Result { let raw_fd = self.socket.as_raw_fd(); @@ -198,19 +211,23 @@ impl TcpSocket { Ok(std_stream) } + /// Start listening for incoming connections. pub fn listen(&self, backlog: i32) -> io::Result<()> { self.socket.listen(backlog) } + /// Accept an incoming connection. pub fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> { let (stream, addr) = self.socket.accept()?; Ok((stream.into(), addr.as_socket().unwrap())) } + /// Convert the socket into a `TcpStream`. pub fn to_tcp_stream(self) -> io::Result { Ok(self.socket.into()) } + /// Convert the socket into a `TcpListener`. pub fn to_tcp_listener(self) -> io::Result { Ok(self.socket.into()) } @@ -238,24 +255,42 @@ impl TcpSocket { Ok((n, addr)) } - // --- option helpers --- + /// Shutdown the socket. + pub fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> { + self.socket.shutdown(how) + } + /// Set the socket to reuse the address. pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> { self.socket.set_reuse_address(on) } + /// Set the socket to not delay packets. pub fn set_nodelay(&self, on: bool) -> io::Result<()> { self.socket.set_nodelay(on) } + /// Set the linger option for the socket. pub fn set_linger(&self, dur: Option) -> io::Result<()> { self.socket.set_linger(dur) } + /// Set the time-to-live for IPv4 packets. pub fn set_ttl(&self, ttl: u32) -> io::Result<()> { self.socket.set_ttl(ttl) } + /// Set the hop limit for IPv6 packets. + pub fn set_hoplimit(&self, hops: u32) -> io::Result<()> { + self.socket.set_unicast_hops_v6(hops) + } + + /// Set the keepalive option for the socket. + pub fn set_keepalive(&self, on: bool) -> io::Result<()> { + self.socket.set_keepalive(on) + } + + /// Set the bind device for the socket (Linux specific). pub fn set_bind_device(&self, iface: &str) -> io::Result<()> { #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))] return self.socket.bind_device(Some(iface.as_bytes())); @@ -270,8 +305,7 @@ impl TcpSocket { } } - // --- information helpers --- - + /// Retrieve the local address of the socket. pub fn local_addr(&self) -> io::Result { self.socket .local_addr()? @@ -279,12 +313,14 @@ impl TcpSocket { .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Failed to retrieve local address")) } + /// Extract the RAW file descriptor for Unix. #[cfg(unix)] pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd { use std::os::fd::AsRawFd; self.socket.as_raw_fd() } + /// Extract the RAW socket handle for Windows. #[cfg(windows)] pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket { use std::os::windows::io::AsRawSocket; diff --git a/nex-socket/src/udp/async_impl.rs b/nex-socket/src/udp/async_impl.rs index 29543a8..f9e9c51 100644 --- a/nex-socket/src/udp/async_impl.rs +++ b/nex-socket/src/udp/async_impl.rs @@ -13,38 +13,41 @@ pub struct AsyncUdpSocket { impl AsyncUdpSocket { /// Create an asynchronous UDP socket from the given configuration. pub fn from_config(config: &UdpConfig) -> io::Result { - // Determine address family from the bind address - let domain = match config.bind_addr { - Some(SocketAddr::V4(_)) => Domain::IPV4, - Some(SocketAddr::V6(_)) => Domain::IPV6, - None => Domain::IPV4, // default - }; - - let socket = Socket::new(domain, SockType::DGRAM, Some(Protocol::UDP))?; + let socket = Socket::new(config.socket_family.to_domain(), config.socket_type.to_sock_type(), Some(Protocol::UDP))?; + socket.set_nonblocking(true)?; + + // Set socket options based on configuration if let Some(flag) = config.reuseaddr { socket.set_reuse_address(flag)?; } - if let Some(flag) = config.broadcast { socket.set_broadcast(flag)?; } - if let Some(ttl) = config.ttl { socket.set_ttl(ttl)?; } + if let Some(hoplimit) = config.hoplimit { + socket.set_unicast_hops_v6(hoplimit)?; + } + if let Some(timeout) = config.read_timeout { + socket.set_read_timeout(Some(timeout))?; + } + if let Some(timeout) = config.write_timeout { + socket.set_write_timeout(Some(timeout))?; + } + // Linux: optional interface name #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))] if let Some(iface) = &config.bind_device { socket.bind_device(Some(iface.as_bytes()))?; } + // bind to the specified address if provided if let Some(addr) = config.bind_addr { socket.bind(&addr.into())?; } - socket.set_nonblocking(true)?; - #[cfg(windows)] let std_socket = unsafe { use std::os::windows::io::{FromRawSocket, IntoRawSocket}; diff --git a/nex-socket/src/udp/config.rs b/nex-socket/src/udp/config.rs index ce1a008..f1e4146 100644 --- a/nex-socket/src/udp/config.rs +++ b/nex-socket/src/udp/config.rs @@ -1,20 +1,57 @@ -use std::net::SocketAddr; +use std::{net::SocketAddr, time::Duration}; + +use socket2::Type as SockType; + +use crate::SocketFamily; + +/// UDP socket type, either DGRAM or RAW. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum UdpSocketType { + Dgram, + Raw, +} + +impl UdpSocketType { + /// Returns true if the socket type is DGRAM. + pub fn is_dgram(&self) -> bool { + matches!(self, UdpSocketType::Dgram) + } + + /// Returns true if the socket type is RAW. + pub fn is_raw(&self) -> bool { + matches!(self, UdpSocketType::Raw) + } + + /// Converts the UDP socket type to a `socket2::Type`. + pub(crate) fn to_sock_type(&self) -> SockType { + match self { + UdpSocketType::Dgram => SockType::DGRAM, + UdpSocketType::Raw => SockType::RAW, + } + } +} /// Configuration options for a UDP socket. #[derive(Debug, Clone)] pub struct UdpConfig { + /// The socket family. + pub socket_family: SocketFamily, + /// The socket type (DGRAM or RAW). + pub socket_type: UdpSocketType, /// Address to bind. If `None`, the operating system chooses the address. pub bind_addr: Option, - /// Enable address reuse (`SO_REUSEADDR`). pub reuseaddr: Option, - /// Allow broadcast (`SO_BROADCAST`). pub broadcast: Option, - /// Time to live value. pub ttl: Option, - + /// Hop limit value. + pub hoplimit: Option, + /// Read timeout for the socket. + pub read_timeout: Option, + /// Write timeout for the socket. + pub write_timeout: Option, /// Bind to a specific interface (Linux only). pub bind_device: Option, } @@ -22,15 +59,75 @@ pub struct UdpConfig { impl Default for UdpConfig { fn default() -> Self { Self { + socket_family: SocketFamily::IPV4, + socket_type: UdpSocketType::Dgram, bind_addr: None, reuseaddr: None, broadcast: None, ttl: None, + hoplimit: None, + read_timeout: None, + write_timeout: None, bind_device: None, } } } +impl UdpConfig { + /// Create a new UDP configuration with default values. + pub fn new() -> Self { + Self::default() + } + + /// Set the bind address. + pub fn with_bind_addr(mut self, addr: SocketAddr) -> Self { + self.bind_addr = Some(addr); + self + } + + /// Enable address reuse. + pub fn with_reuseaddr(mut self, on: bool) -> Self { + self.reuseaddr = Some(on); + self + } + + /// Allow broadcast. + pub fn with_broadcast(mut self, on: bool) -> Self { + self.broadcast = Some(on); + self + } + + /// Set the time to live value. + pub fn with_ttl(mut self, ttl: u32) -> Self { + self.ttl = Some(ttl); + self + } + + /// Set the hop limit value. + pub fn with_hoplimit(mut self, hops: u32) -> Self { + self.hoplimit = Some(hops); + self + } + + /// Set the read timeout. + pub fn with_read_timeout(mut self, timeout: Duration) -> Self { + self.read_timeout = Some(timeout); + self + } + + /// Set the write timeout. + pub fn with_write_timeout(mut self, timeout: Duration) -> Self { + self.write_timeout = Some(timeout); + self + } + + /// Bind to a specific interface (Linux only). + pub fn with_bind_device(mut self, iface: impl Into) -> Self { + self.bind_device = Some(iface.into()); + self + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/nex-socket/src/udp/sync_impl.rs b/nex-socket/src/udp/sync_impl.rs index aa6ea0a..911b5c2 100644 --- a/nex-socket/src/udp/sync_impl.rs +++ b/nex-socket/src/udp/sync_impl.rs @@ -12,37 +12,41 @@ pub struct UdpSocket { impl UdpSocket { /// Create a socket from the provided configuration. pub fn from_config(config: &UdpConfig) -> io::Result { - // Determine address family from the bind address - let domain = match config.bind_addr { - Some(SocketAddr::V4(_)) => Domain::IPV4, - Some(SocketAddr::V6(_)) => Domain::IPV6, - None => Domain::IPV4, // default - }; + let socket = Socket::new(config.socket_family.to_domain(), config.socket_type.to_sock_type(), Some(Protocol::UDP))?; - let socket = Socket::new(domain, SockType::DGRAM, Some(Protocol::UDP))?; + socket.set_nonblocking(false)?; + // Set socket options based on configuration if let Some(flag) = config.reuseaddr { socket.set_reuse_address(flag)?; } - if let Some(flag) = config.broadcast { socket.set_broadcast(flag)?; } - if let Some(ttl) = config.ttl { socket.set_ttl(ttl)?; } + if let Some(hoplimit) = config.hoplimit { + socket.set_unicast_hops_v6(hoplimit)?; + } + if let Some(timeout) = config.read_timeout { + socket.set_read_timeout(Some(timeout))?; + } + if let Some(timeout) = config.write_timeout { + socket.set_write_timeout(Some(timeout))?; + } + // Linux: optional interface name #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))] if let Some(iface) = &config.bind_device { socket.bind_device(Some(iface.as_bytes()))?; } + // bind to the specified address if provided if let Some(addr) = config.bind_addr { socket.bind(&addr.into())?; } - socket.set_nonblocking(false)?; // blocking mode for sync usage Ok(Self { socket }) } @@ -96,6 +100,18 @@ impl UdpSocket { Ok((n, addr)) } + pub fn set_ttl(&self, ttl: u32) -> io::Result<()> { + self.socket.set_ttl(ttl) + } + + pub fn set_hoplimit(&self, hops: u32) -> io::Result<()> { + self.socket.set_unicast_hops_v6(hops) + } + + pub fn set_keepalive(&self, on: bool) -> io::Result<()> { + self.socket.set_keepalive(on) + } + /// Retrieve the local socket address. pub fn local_addr(&self) -> io::Result { self.socket