From e323484221e08d59a9fc06e2dd6881a054a5f9a6 Mon Sep 17 00:00:00 2001 From: shellrow Date: Sun, 6 Jul 2025 14:40:35 +0900 Subject: [PATCH] v0.20.0: full refactor --- .gitattributes | 1 + .gitignore | 11 + Cargo.toml | 34 +- Cross.toml | 9 - LICENSE | 2 +- README.md | 16 +- examples/arp.rs | 165 +- examples/async_icmp_socket.rs | 91 + examples/async_tcp_connect.rs | 62 - examples/async_tcp_socket.rs | 32 + examples/async_tcp_stream.rs | 45 - examples/dump.rs | 270 ++- examples/icmp_ping.rs | 252 +-- examples/icmp_socket.rs | 59 + examples/list_interfaces.rs | 47 - examples/ndp.rs | 218 +- examples/parse_frame.rs | 64 +- examples/serialize.rs | 77 - examples/tcp_ping.rs | 194 +- examples/tcp_socket.rs | 29 + examples/tcp_stream.rs | 47 - examples/udp_ping.rs | 260 +-- examples/udp_socket.rs | 30 + nex-core/src/bitfield.rs | 387 ++++ nex-core/src/ip.rs | 118 +- nex-core/src/lib.rs | 1 + nex-datalink/Cargo.toml | 1 + nex-datalink/src/bindings/bpf.rs | 2 +- nex-datalink/src/lib.rs | 18 + nex-datalink/src/linux.rs | 37 +- nex-macro-helper/Cargo.toml | 14 - nex-macro-helper/src/lib.rs | 9 - nex-macro-helper/src/packet.rs | 231 -- nex-macro-helper/src/types.rs | 477 ---- nex-macro/Cargo.toml | 23 - nex-macro/src/decorator.rs | 1740 -------------- nex-macro/src/lib.rs | 46 - nex-macro/src/util.rs | 1062 --------- nex-packet-builder/Cargo.toml | 16 - nex-packet-builder/src/arp.rs | 67 - nex-packet-builder/src/builder.rs | 174 -- nex-packet-builder/src/dhcp.rs | 163 -- nex-packet-builder/src/ethernet.rs | 95 - nex-packet-builder/src/icmp.rs | 61 - nex-packet-builder/src/icmpv6.rs | 62 - nex-packet-builder/src/ipv4.rs | 119 - nex-packet-builder/src/ipv6.rs | 92 - nex-packet-builder/src/ndp.rs | 63 - nex-packet-builder/src/tcp.rs | 181 -- nex-packet-builder/src/udp.rs | 113 - nex-packet-builder/src/util.rs | 393 ---- nex-packet/Cargo.toml | 8 +- nex-packet/src/arp.rs | 461 ++-- nex-packet/src/builder/arp.rs | 110 + nex-packet/src/builder/dhcp.rs | 67 + nex-packet/src/builder/ethernet.rs | 64 + nex-packet/src/builder/icmp.rs | 83 + nex-packet/src/builder/icmpv6.rs | 82 + nex-packet/src/builder/ipv4.rs | 122 + nex-packet/src/builder/ipv6.rs | 119 + .../lib.rs => nex-packet/src/builder/mod.rs | 10 +- nex-packet/src/builder/ndp.rs | 84 + nex-packet/src/builder/tcp.rs | 126 ++ nex-packet/src/builder/udp.rs | 96 + nex-packet/src/dhcp.rs | 510 +++-- nex-packet/src/dns.rs | 1675 ++++++++------ nex-packet/src/ethernet.rs | 388 ++-- nex-packet/src/flowcontrol.rs | 133 ++ nex-packet/src/frame.rs | 315 +-- nex-packet/src/gre.rs | 345 ++- nex-packet/src/icmp.rs | 681 +++--- nex-packet/src/icmpv6.rs | 1991 +++++++++++++---- nex-packet/src/ip.rs | 313 ++- nex-packet/src/ipv4.rs | 764 ++++--- nex-packet/src/ipv6.rs | 776 ++++--- nex-packet/src/lib.rs | 40 +- nex-packet/src/packet.rs | 50 + nex-packet/src/sll.rs | 25 - nex-packet/src/sll2.rs | 37 - nex-packet/src/tcp.rs | 935 ++++---- nex-packet/src/udp.rs | 265 +-- nex-packet/src/usbpcap.rs | 172 -- nex-packet/src/util.rs | 9 +- nex-packet/src/vlan.rs | 247 +- nex-packet/src/vxlan.rs | 121 + nex-socket/Cargo.toml | 11 +- nex-socket/src/icmp/async_impl.rs | 123 + nex-socket/src/icmp/config.rs | 79 + nex-socket/src/icmp/mod.rs | 7 + nex-socket/src/icmp/sync_impl.rs | 101 + nex-socket/src/lib.rs | 17 +- nex-socket/src/socket/async_impl.rs | 743 ------ nex-socket/src/socket/mod.rs | 120 - nex-socket/src/socket/sync_impl.rs | 389 ---- nex-socket/src/sys/mod.rs | 9 - nex-socket/src/sys/unix.rs | 108 - nex-socket/src/sys/windows.rs | 285 --- nex-socket/src/tcp/async_impl.rs | 191 ++ nex-socket/src/tcp/config.rs | 122 + nex-socket/src/tcp/mod.rs | 7 + nex-socket/src/tcp/sync_impl.rs | 280 +++ nex-socket/src/udp/async_impl.rs | 115 + nex-socket/src/udp/config.rs | 47 + nex-socket/src/udp/mod.rs | 7 + nex-socket/src/udp/sync_impl.rs | 135 ++ nex-sys/src/lib.rs | 2 +- nex-sys/src/unix.rs | 26 + nex/Cargo.toml | 35 +- nex/src/lib.rs | 12 +- scripts/build-all.ps1 | 20 + scripts/build-all.sh | 25 + 111 files changed, 10126 insertions(+), 11894 deletions(-) create mode 100644 .gitattributes delete mode 100644 Cross.toml create mode 100644 examples/async_icmp_socket.rs delete mode 100644 examples/async_tcp_connect.rs create mode 100644 examples/async_tcp_socket.rs delete mode 100644 examples/async_tcp_stream.rs create mode 100644 examples/icmp_socket.rs delete mode 100644 examples/list_interfaces.rs delete mode 100644 examples/serialize.rs create mode 100644 examples/tcp_socket.rs delete mode 100644 examples/tcp_stream.rs create mode 100644 examples/udp_socket.rs create mode 100644 nex-core/src/bitfield.rs delete mode 100644 nex-macro-helper/Cargo.toml delete mode 100644 nex-macro-helper/src/lib.rs delete mode 100644 nex-macro-helper/src/packet.rs delete mode 100644 nex-macro-helper/src/types.rs delete mode 100644 nex-macro/Cargo.toml delete mode 100644 nex-macro/src/decorator.rs delete mode 100644 nex-macro/src/lib.rs delete mode 100644 nex-macro/src/util.rs delete mode 100644 nex-packet-builder/Cargo.toml delete mode 100644 nex-packet-builder/src/arp.rs delete mode 100644 nex-packet-builder/src/builder.rs delete mode 100644 nex-packet-builder/src/dhcp.rs delete mode 100644 nex-packet-builder/src/ethernet.rs delete mode 100644 nex-packet-builder/src/icmp.rs delete mode 100644 nex-packet-builder/src/icmpv6.rs delete mode 100644 nex-packet-builder/src/ipv4.rs delete mode 100644 nex-packet-builder/src/ipv6.rs delete mode 100644 nex-packet-builder/src/ndp.rs delete mode 100644 nex-packet-builder/src/tcp.rs delete mode 100644 nex-packet-builder/src/udp.rs delete mode 100644 nex-packet-builder/src/util.rs create mode 100644 nex-packet/src/builder/arp.rs create mode 100644 nex-packet/src/builder/dhcp.rs create mode 100644 nex-packet/src/builder/ethernet.rs create mode 100644 nex-packet/src/builder/icmp.rs create mode 100644 nex-packet/src/builder/icmpv6.rs create mode 100644 nex-packet/src/builder/ipv4.rs create mode 100644 nex-packet/src/builder/ipv6.rs rename nex-packet-builder/src/lib.rs => nex-packet/src/builder/mod.rs (58%) create mode 100644 nex-packet/src/builder/ndp.rs create mode 100644 nex-packet/src/builder/tcp.rs create mode 100644 nex-packet/src/builder/udp.rs create mode 100644 nex-packet/src/flowcontrol.rs create mode 100644 nex-packet/src/packet.rs delete mode 100644 nex-packet/src/sll.rs delete mode 100644 nex-packet/src/sll2.rs delete mode 100644 nex-packet/src/usbpcap.rs create mode 100644 nex-packet/src/vxlan.rs create mode 100644 nex-socket/src/icmp/async_impl.rs create mode 100644 nex-socket/src/icmp/config.rs create mode 100644 nex-socket/src/icmp/mod.rs create mode 100644 nex-socket/src/icmp/sync_impl.rs delete mode 100644 nex-socket/src/socket/async_impl.rs delete mode 100644 nex-socket/src/socket/mod.rs delete mode 100644 nex-socket/src/socket/sync_impl.rs delete mode 100644 nex-socket/src/sys/mod.rs delete mode 100644 nex-socket/src/sys/unix.rs delete mode 100644 nex-socket/src/sys/windows.rs create mode 100644 nex-socket/src/tcp/async_impl.rs create mode 100644 nex-socket/src/tcp/config.rs create mode 100644 nex-socket/src/tcp/mod.rs create mode 100644 nex-socket/src/tcp/sync_impl.rs create mode 100644 nex-socket/src/udp/async_impl.rs create mode 100644 nex-socket/src/udp/config.rs create mode 100644 nex-socket/src/udp/mod.rs create mode 100644 nex-socket/src/udp/sync_impl.rs create mode 100644 scripts/build-all.ps1 create mode 100755 scripts/build-all.sh diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..c194f0d --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +scripts/* linguist-vendored diff --git a/.gitignore b/.gitignore index e926af6..68b693c 100644 --- a/.gitignore +++ b/.gitignore @@ -13,5 +13,16 @@ Cargo.lock # MSVC Windows builds of rustc generate these, which store debugging information *.pdb +# Generated by cargo mutants +# Contains mutation testing data +**/mutants.out*/ + +# RustRover +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + # macOS *.DS_Store diff --git a/Cargo.toml b/Cargo.toml index df836ea..66623d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,32 +1,28 @@ [workspace] resolver = "2" members = [ - "nex", - "nex-core", - "nex-datalink", - "nex-macro", - "nex-macro-helper", - "nex-packet", - "nex-socket", - "nex-sys", - "nex-packet-builder" + "nex", + "nex-core", + "nex-datalink", + "nex-packet", + "nex-socket", + "nex-sys" ] [workspace.package] -version = "0.19.1" +version = "0.20.0" edition = "2021" authors = ["shellrow "] [workspace.dependencies] -nex-core = { version = "0.19.1", path = "nex-core" } -nex-datalink = { version = "0.19.1", path = "nex-datalink" } -nex-macro = { version = "0.19.1", path = "nex-macro" } -nex-macro-helper = { version = "0.19.1", path = "nex-macro-helper" } -nex-packet = { version = "0.19.1", path = "nex-packet" } -nex-packet-builder = { version = "0.19.1", path = "nex-packet-builder" } -nex-socket = { version = "0.19.1", path = "nex-socket" } -nex-sys = { version = "0.19.1", path = "nex-sys" } +nex-core = { version = "0.20.0", path = "nex-core" } +nex-datalink = { version = "0.20.0", path = "nex-datalink" } +nex-packet = { version = "0.20.0", path = "nex-packet" } +nex-sys = { version = "0.20.0", path = "nex-sys" } +nex-socket = { version = "0.20.0", path = "nex-socket" } serde = { version = "1" } libc = "0.2" +netdev = { version = "0.36" } +bytes = "1" +tokio = { version = "1" } rand = "0.8" -netdev = { version = "0.34" } diff --git a/Cross.toml b/Cross.toml deleted file mode 100644 index 4131c2a..0000000 --- a/Cross.toml +++ /dev/null @@ -1,9 +0,0 @@ -[build] -build-std = false # do not build the std library. has precedence over xargo -xargo = false # do not use xargo for the builds -zig = false # do not use zig cc for the builds -default-target = "x86_64-unknown-linux-gnu" # use this target if none is explicitly provided -pre-build = [ # additional commands to run prior to building the package - "dpkg --add-architecture $CROSS_DEB_ARCH", - "apt-get update && apt-get --assume-yes install libssl-dev:$CROSS_DEB_ARCH" -] diff --git a/LICENSE b/LICENSE index cf93579..7a8f511 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2023 shellrow +Copyright (c) 2023-2025 shellrow Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 50ae131..741a5e7 100644 --- a/README.md +++ b/README.md @@ -12,8 +12,7 @@ Cross-platform low-level networking library in Rust It includes a set of modules, each with a specific focus: - `datalink`: Datalink layer networking. -- `packet`: Low-level packet parsing and building. -- `packet-builder`: High-level packet building. +- `packet`: Low-level packet parsing and building. - `socket`: Socket-related functionality. ## Upcoming Features @@ -27,14 +26,13 @@ To use `nex`, add it as a dependency in your `Cargo.toml`: ```toml [dependencies] -nex = "0.19" +nex = "0.20" ``` ## Using Specific Sub-crates You can also directly use specific sub-crates by importing them individually. - `nex-datalink` - `nex-packet` -- `nex-packet-builder` - `nex-socket` If you want to focus on network interfaces, you can use the [netdev](https://github.com/shellrow/netdev). @@ -55,13 +53,3 @@ Please note that in order to send and receive raw packets using `nex-datalink` o On macOS, managing access to the Berkeley Packet Filter (BPF) devices is necessary for send and receive raw packets using `nex-datalink`. You can use [chmod-bpf](https://github.com/shellrow/chmod-bpf) to automatically manage permissions for BPF devices. Alternatively, of course, you can also use `sudo` to temporarily grant the necessary permissions. - -## Build time requirements for optional feature -The cryptography provider for `nex-socket`'s optional `tls-aws-lc` feature use `aws-lc-rs`. Note that this has some implications on [build-time tool requirements](https://aws.github.io/aws-lc-rs/requirements/index.html), such as requiring cmake on all platforms and nasm on Windows. -**You can use `ring` as the cryptography provider (without additional dependencies) by specifying the `tls` feature.** - -## Acknowledgment -This library was heavily inspired by `pnet`, which catalyzed my journey into Rust development. -I am grateful to everyone involved in `pnet` for their pioneering efforts and significant contributions to networking in Rust. - -Additionally, thank you to all contributors and maintainers of the projects `nex` depends on for your invaluable work and support. diff --git a/examples/arp.rs b/examples/arp.rs index 4bbcdb7..c03ee77 100644 --- a/examples/arp.rs +++ b/examples/arp.rs @@ -1,132 +1,105 @@ -//! This example sends ARP request packet to the target and waits for ARP reply packets. +//! Sends ARP request to the target and waits for ARP reply. //! -//! e.g. +//! Usage: +//! arp //! -//! arp 192.168.1.1 eth0 +//! Example: +//! arp 192.168.1.1 eth0 use nex::datalink; use nex::datalink::Channel::Ethernet; -use nex::net::interface::Interface; +use nex::net::interface::{get_interfaces, Interface}; use nex::net::mac::MacAddr; use nex::packet::ethernet::EtherType; -use nex::packet::frame::Frame; -use nex::packet::frame::ParseOption; -use nex::util::packet_builder::builder::PacketBuilder; -use nex::util::packet_builder::ethernet::EthernetPacketBuilder; +use nex::packet::frame::{Frame, ParseOption}; +use nex::packet::builder::ethernet::EthernetPacketBuilder; use nex_packet::arp::ArpOperation; -use nex_packet_builder::arp::ArpPacketBuilder; +use nex_packet::builder::arp::ArpPacketBuilder; +use nex_packet::packet::Packet; use std::env; -use std::net::IpAddr; -use std::net::Ipv4Addr; +use std::net::{IpAddr, Ipv4Addr}; use std::process; -const USAGE: &str = "USAGE: arp "; - fn main() { - let interface: Interface = match env::args().nth(2) { - Some(n) => { - // Use interface specified by the user - let interfaces: Vec = nex::net::interface::get_interfaces(); - let interface: Interface = interfaces - .into_iter() - .find(|interface| interface.name == n) - .expect("Failed to get interface information"); - interface - } - None => { - // Use the default interface - match Interface::default() { - Ok(interface) => interface, - Err(e) => { - println!("Failed to get default interface: {}", e); - process::exit(1); - } - } + let args: Vec = env::args().collect(); + if args.len() < 2 { + eprintln!("Usage: arp "); + process::exit(1); + } + + let target_ip: Ipv4Addr = match args[1].parse() { + Ok(IpAddr::V4(ipv4)) => ipv4, + Ok(_) => { + eprintln!("IPv6 is not supported. Use ndp instead."); + process::exit(1); } - }; - let dst_ip: Ipv4Addr = match env::args().nth(1) { - Some(target_ip) => match target_ip.parse::() { - Ok(ip) => match ip { - IpAddr::V4(ipv4) => ipv4, - IpAddr::V6(_ipv6) => { - println!("IPv6 is not supported. Use ndp instead."); - eprintln!("{USAGE}"); - process::exit(1); - } - }, - Err(e) => { - println!("Failed to parse target ip: {}", e); - eprintln!("{USAGE}"); - process::exit(1); - } - }, - None => { - println!("Failed to get target ip"); - eprintln!("{USAGE}"); + Err(e) => { + eprintln!("Failed to parse target IP: {}", e); process::exit(1); } }; - let src_ip: Ipv4Addr = interface.ipv4[0].addr(); + let interface = match env::args().nth(2) { + Some(name) => get_interfaces() + .into_iter() + .find(|i| i.name == name) + .expect("Failed to get interface"), + None => Interface::default().expect("Failed to get default interface"), + }; + + let src_mac = interface.mac_addr.clone().expect("No MAC address on interface"); + let src_ip = interface.ipv4.get(0).expect("No IPv4 address").addr(); - // Create a channel to send/receive packet let (mut tx, mut rx) = match datalink::channel(&interface, Default::default()) { Ok(Ethernet(tx, rx)) => (tx, rx), - Ok(_) => panic!("parse_frame: unhandled channel type"), - Err(e) => panic!("parse_frame: unable to create channel: {}", e), + Ok(_) => panic!("Unhandled channel type"), + Err(e) => panic!("Failed to create channel: {}", e), }; - // Packet builder for ARP Request - let mut packet_builder = PacketBuilder::new(); - let ethernet_packet_builder = EthernetPacketBuilder { - src_mac: interface.mac_addr.clone().unwrap(), - dst_mac: MacAddr::broadcast(), - ether_type: EtherType::Arp, - }; - packet_builder.set_ethernet(ethernet_packet_builder); + let eth_builder = EthernetPacketBuilder::new() + .source(src_mac) + .destination(MacAddr::broadcast()) + .ethertype(EtherType::Arp); - let arp_packet = ArpPacketBuilder { - src_mac: interface.mac_addr.clone().unwrap(), - dst_mac: MacAddr::broadcast(), - src_ip: src_ip, - dst_ip: dst_ip, - }; - packet_builder.set_arp(arp_packet); + let arp_builder = ArpPacketBuilder::new(src_mac, src_ip, target_ip); + + let packet = eth_builder + .payload(arp_builder.build().to_bytes()) + .build(); - // Send ARP Request packet - match tx.send(&packet_builder.packet()) { - Some(_) => println!("ARP Packet sent"), - None => println!("Failed to send packet"), + match tx.send(&packet.to_bytes()) { + Some(_) => println!("ARP Request sent to {}", target_ip), + None => { + eprintln!("Failed to send ARP packet"); + process::exit(1); + } } - // Receive ARP Reply packet - println!("Waiting for ARP Reply packet..."); + println!("Waiting for ARP Reply..."); loop { match rx.next() { Ok(packet) => { - let parse_option: ParseOption = ParseOption::default(); - let frame: Frame = Frame::from_bytes(&packet, parse_option); - if let Some(datalik_layer) = &frame.datalink { - if let Some(arp_packet) = &datalik_layer.arp { - if arp_packet.operation == ArpOperation::Reply { - println!("ARP Reply packet received"); - println!( - "Received ARP Reply packet from {}", - arp_packet.sender_proto_addr - ); - println!("MAC address: {}", arp_packet.sender_hw_addr); - println!( - "---- Interface: {}, Total Length: {} bytes ----", - interface.name, - packet.len() - ); - println!("Packet Frame: {:?}", frame); - break; + let frame = Frame::from_buf(&packet, ParseOption::default()).unwrap(); + match &frame.datalink { + Some(dlink) => { + if let Some(arp) = &dlink.arp { + if arp.operation == ArpOperation::Reply && arp.sender_proto_addr == target_ip { + println!("Received ARP Reply from {}", arp.sender_proto_addr); + println!("MAC address: {}", arp.sender_hw_addr); + println!( + "---- Interface: {}, Total Length: {} bytes ----", + interface.name, + packet.len() + ); + println!("Frame: {:?}", frame); + break; + } } } + None => continue, // No datalink layer } } - Err(e) => println!("Failed to receive packet: {}", e), + Err(e) => eprintln!("Receive failed: {}", e), } } } diff --git a/examples/async_icmp_socket.rs b/examples/async_icmp_socket.rs new file mode 100644 index 0000000..ac55cd4 --- /dev/null +++ b/examples/async_icmp_socket.rs @@ -0,0 +1,91 @@ +//! Simple IPv4 ping scanner using AsyncIcmpSocket +//! +//! Usage: async_icmp_socket +//! Example: async_icmp_socket 192.168.1 + +use bytes::Bytes; +use nex_socket::icmp::{AsyncIcmpSocket, IcmpConfig, IcmpKind}; +use nex_packet::builder::icmp::IcmpPacketBuilder; +use nex_packet::icmp::{self, IcmpType}; +use std::collections::HashMap; +use std::env; +use nex::net::interface::{Interface, get_interfaces}; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::sync::Arc; +use std::time::Duration; +use rand::{Rng, thread_rng}; +use tokio::sync::Mutex; +use tokio::time; + +#[tokio::main] +async fn main() -> std::io::Result<()> { + let prefix = env::args().nth(1).expect("prefix like 192.168.1"); + let parts: Vec = prefix.split('.').map(|s| s.parse().expect("num")).collect(); + assert!(parts.len() == 3, "prefix must be a.b.c"); + + let interface = match env::args().nth(2) { + Some(name) => get_interfaces().into_iter().find(|i| i.name == name).expect("interface not found"), + None => Interface::default().expect("default interface"), + }; + + let src_ip = interface + .ipv4 + .get(0) + .map(|v| v.addr()) + .expect("No IPv4 address on interface"); + + let config = IcmpConfig::new(IcmpKind::V4); + let socket = Arc::new(AsyncIcmpSocket::new(&config).await.unwrap()); + + // map from (id, seq) to target IP + let replies = Arc::new(Mutex::new(HashMap::new())); + + // Receiver task + let socket_clone = socket.clone(); + + tokio::spawn(async move { + let mut buf = [0u8; 512]; + loop { + if let Ok((n, from)) = socket_clone.recv_from(&mut buf).await { + println!("Received {} bytes from {}", n, from.ip()); + } + } + }); + + let mut handles = Vec::new(); + for i in 1u8..=254 { + let addr = Ipv4Addr::new(parts[0], parts[1], parts[2], i); + let id: u16 = thread_rng().gen(); + let seq: u16 = 1; + let socket = socket.clone(); + let replies = replies.clone(); + + handles.push(tokio::spawn(async move { + let pkt = IcmpPacketBuilder::new(src_ip, addr) + .icmp_type(IcmpType::EchoRequest) + .icmp_code(icmp::echo_request::IcmpCodes::NoCode) + .echo_fields(id, seq) + .payload(Bytes::from_static(b"ping")) + .culculate_checksum() + .to_bytes(); + let target = SocketAddr::new(IpAddr::V4(addr), 0); + let _ = socket.send_to(&pkt, target).await; + { + let mut lock = replies.lock().await; + lock.insert((id, seq), addr); + } + time::sleep(Duration::from_millis(500)).await; + let mut lock = replies.lock().await; + if lock.remove(&(id, seq)).is_some() { + // already handled in receiver + } else { + println!("{} is not responding", addr); + } + })); + } + + for h in handles { + let _ = h.await; + } + Ok(()) +} diff --git a/examples/async_tcp_connect.rs b/examples/async_tcp_connect.rs deleted file mode 100644 index 895c034..0000000 --- a/examples/async_tcp_connect.rs +++ /dev/null @@ -1,62 +0,0 @@ -use futures::stream::{self, StreamExt}; -use nex_socket::AsyncSocket; -use std::net::SocketAddr; -use std::str::FromStr; -use std::time::Duration; - -fn main() { - // List of destination for TCP connect test. - let dst_sockets = vec![ - "1.0.0.2:53", - "1.0.0.3:53", - "1.1.1.1:53", - "1.1.1.3:53", - "4.2.2.6:53", - "8.0.7.0:53", - "8.8.8.8:53", - "77.88.8.1:53", - "77.88.8.3:53", - "77.88.8.88:53", - ]; - let conn_timeout = Duration::from_millis(300); - let concurrency: usize = 10; - let start_time = std::time::Instant::now(); - async_io::block_on(async { - let fut = stream::iter(dst_sockets).for_each_concurrent( - concurrency, - |socket_addr_str| async move { - let socket_addr: SocketAddr = SocketAddr::from_str(socket_addr_str).unwrap(); - let conn_start_time = std::time::Instant::now(); - match AsyncSocket::new_with_async_connect_timeout(&socket_addr, conn_timeout).await - { - Ok(async_socket) => { - let local_socket_addr = async_socket.local_addr().await.unwrap(); - let remote_socket_addr = async_socket.peer_addr().await.unwrap(); - println!( - "Connected {} -> {} in {}ms", - local_socket_addr, - remote_socket_addr, - conn_start_time.elapsed().as_millis() - ); - match async_socket.shutdown(std::net::Shutdown::Both).await { - Ok(_) => { - println!( - "Connection closed ({} -> {})", - local_socket_addr, remote_socket_addr - ); - } - Err(e) => { - println!("shutdown error (for {}): {}", socket_addr, e); - } - } - } - Err(e) => { - println!("connection error (for {}): {}", socket_addr, e); - } - } - }, - ); - fut.await; - }); - println!("Total time: {}ms", start_time.elapsed().as_millis()); -} diff --git a/examples/async_tcp_socket.rs b/examples/async_tcp_socket.rs new file mode 100644 index 0000000..5ffb183 --- /dev/null +++ b/examples/async_tcp_socket.rs @@ -0,0 +1,32 @@ +//! Simple TCP port scanner using AsyncTcpSocket +//! +//! Usage: async_tcp_socket ... + +use nex_socket::tcp::{AsyncTcpSocket, TcpConfig}; +use std::env; +use std::net::{IpAddr, SocketAddr}; +use std::time::Duration; + +#[tokio::main] +async fn main() -> std::io::Result<()> { + let mut args = env::args().skip(1); + let ip: IpAddr = args.next().expect("IP").parse().expect("ip"); + let ports: Vec = args.map(|p| p.parse().expect("port")).collect(); + + let mut handles = Vec::new(); + for port in ports { + let addr = SocketAddr::new(ip, port); + handles.push(tokio::spawn(async move { + let cfg = if ip.is_ipv4() { TcpConfig::v4_stream() } else { TcpConfig::v6_stream() }; + let sock = AsyncTcpSocket::from_config(&cfg).unwrap(); + match sock.connect_timeout(addr, Duration::from_millis(500)).await { + Ok(_) => println!("Port {} is open", port), + Err(e) => println!("Port {} is closed: {}", port, e), + } + })); + } + for h in handles { + let _ = h.await; + } + Ok(()) +} diff --git a/examples/async_tcp_stream.rs b/examples/async_tcp_stream.rs deleted file mode 100644 index 77a682f..0000000 --- a/examples/async_tcp_stream.rs +++ /dev/null @@ -1,45 +0,0 @@ -use nex_socket::AsyncTcpStream; -use std::{ - net::{IpAddr, Ipv4Addr, Shutdown, SocketAddr}, - time::Duration, -}; - -fn main() { - let ip_addr: IpAddr = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)); - println!("Connecting to 1.1.1.1:80 ..."); - async_io::block_on(async { - match AsyncTcpStream::connect_timeout( - &SocketAddr::new(ip_addr, 80), - Duration::from_millis(200), - ) - .await - { - Ok(stream) => { - println!("Connected to 1.1.1.1:80"); - let req = format!("GET / HTTP/1.1\r\nHost: {}\r\n\r\n", ip_addr.to_string()); - println!("Sending data (HTTP Request) ..."); - match stream.write(req.as_bytes()).await { - Ok(n) => println!("{} bytes sent (payload)", n), - Err(e) => println!("{}", e), - } - let mut res = vec![0; 1024]; - println!("Receiving data ..."); - match stream.read(&mut res).await { - Ok(n) => { - println!("{} bytes received (HTTP Response):", n); - println!("----------------------------------------"); - println!("{}", String::from_utf8_lossy(&res[..n])); - println!("----------------------------------------"); - } - Err(e) => println!("{}", e), - } - println!("Closing socket ..."); - match stream.shutdown(Shutdown::Both).await { - Ok(_) => println!("Socket closed"), - Err(e) => println!("{}", e), - } - } - Err(e) => println!("{}", e), - } - }); -} diff --git a/examples/dump.rs b/examples/dump.rs index eed9dba..c3bef37 100644 --- a/examples/dump.rs +++ b/examples/dump.rs @@ -1,19 +1,22 @@ //! Basic packet capture using nex +use bytes::Bytes; use nex::datalink; use nex::datalink::Channel::Ethernet; use nex::net::interface::Interface; use nex::net::mac::MacAddr; use nex::packet::arp::ArpPacket; -use nex::packet::ethernet::{EtherType, EthernetPacket, MutableEthernetPacket}; -use nex::packet::icmp::{echo_reply, echo_request, IcmpPacket, IcmpType}; +use nex::packet::ethernet::{EtherType, EthernetPacket}; +use nex::packet::icmp::{IcmpPacket, IcmpType}; use nex::packet::icmpv6::Icmpv6Packet; -use nex::packet::ip::IpNextLevelProtocol; +use nex::packet::ip::IpNextProtocol; use nex::packet::ipv4::Ipv4Packet; use nex::packet::ipv6::Ipv6Packet; use nex::packet::tcp::TcpPacket; use nex::packet::udp::UdpPacket; -use nex::packet::Packet; +use nex::packet::packet::Packet; +use nex_packet::ethernet::EthernetHeader; +use nex_packet::{icmp, icmpv6}; use std::env; use std::net::IpAddr; use std::process; @@ -49,8 +52,6 @@ fn main() { }; let mut capture_no: usize = 0; loop { - let mut buf: [u8; 4096] = [0u8; 4096]; - let mut fake_ethernet_frame = MutableEthernetPacket::new(&mut buf[..]).unwrap(); match rx.next() { Ok(packet) => { capture_no += 1; @@ -60,38 +61,34 @@ fn main() { capture_no, packet.len() ); - let payload_offset; + if interface.is_tun() || (cfg!(any(target_os = "macos", target_os = "ios")) && interface.is_loopback()) { + let payload_offset: usize; if interface.is_loopback() { payload_offset = 14; } else { payload_offset = 0; } + let payload = Bytes::copy_from_slice(&packet[payload_offset..]); if packet.len() > payload_offset { - let version = Ipv4Packet::new(&packet[payload_offset..]) + let version = Ipv4Packet::from_buf(packet) .unwrap() - .get_version(); - if version == 4 { - fake_ethernet_frame.set_destination(MacAddr(0, 0, 0, 0, 0, 0)); - fake_ethernet_frame.set_source(MacAddr(0, 0, 0, 0, 0, 0)); - fake_ethernet_frame.set_ethertype(EtherType::Ipv4); - fake_ethernet_frame.set_payload(&packet[payload_offset..]); - handle_ethernet_frame(&fake_ethernet_frame.to_immutable()); - continue; - } else if version == 6 { - fake_ethernet_frame.set_destination(MacAddr(0, 0, 0, 0, 0, 0)); - fake_ethernet_frame.set_source(MacAddr(0, 0, 0, 0, 0, 0)); - fake_ethernet_frame.set_ethertype(EtherType::Ipv6); - fake_ethernet_frame.set_payload(&packet[payload_offset..]); - handle_ethernet_frame(&fake_ethernet_frame.to_immutable()); - continue; - } + .header.version; + let fake_eth = EthernetPacket { + header: EthernetHeader { + destination: MacAddr::zero(), + source: MacAddr::zero(), + ethertype: if version == 4 { EtherType::Ipv4 } else { EtherType::Ipv6 }, + }, + payload, + }; + handle_ethernet_frame(fake_eth); } } else { - handle_ethernet_frame(&EthernetPacket::new(packet).unwrap()); + handle_ethernet_frame(EthernetPacket::from_buf(packet).unwrap()); } } Err(e) => panic!("dump: unable to receive packet: {}", e), @@ -99,63 +96,61 @@ fn main() { } } -fn handle_ethernet_frame(ethernet: &EthernetPacket) { - match ethernet.get_ethertype() { - EtherType::Ipv4 => handle_ipv4_packet(ethernet), - EtherType::Ipv6 => handle_ipv6_packet(ethernet), - EtherType::Arp => handle_arp_packet(ethernet), +fn handle_ethernet_frame(ethernet: EthernetPacket) { + let total_len = ethernet.total_len(); + let (header, payload) = ethernet.into_parts(); + match header.ethertype { + EtherType::Ipv4 => handle_ipv4_packet(payload), + EtherType::Ipv6 => handle_ipv6_packet(payload), + EtherType::Arp => handle_arp_packet(payload), _ => { - let ether_type = ethernet.get_ethertype(); println!( "{} packet: {} > {}; ethertype: {:?} length: {}", - ether_type.name(), - ethernet.get_source(), - ethernet.get_destination(), - ethernet.get_ethertype(), - ethernet.packet().len() + header.ethertype.name(), + header.source, + header.destination, + header.ethertype, + total_len, ) } } } -fn handle_arp_packet(ethernet: &EthernetPacket) { - let header = ArpPacket::new(ethernet.payload()); - if let Some(header) = header { +fn handle_arp_packet(packet: Bytes) { + if let Some(arp) = ArpPacket::from_bytes(packet) { println!( "ARP packet: {}({}) > {}({}); operation: {:?}", - ethernet.get_source(), - header.get_sender_proto_addr(), - ethernet.get_destination(), - header.get_target_proto_addr(), - header.get_operation() + arp.header.sender_hw_addr, + arp.header.sender_proto_addr, + arp.header.target_hw_addr, + arp.header.target_proto_addr, + arp.header.operation ); } else { println!("Malformed ARP Packet"); } } -fn handle_ipv4_packet(ethernet: &EthernetPacket) { - let header = Ipv4Packet::new(ethernet.payload()); - if let Some(header) = header { +fn handle_ipv4_packet(packet: Bytes) { + if let Some(ipv4) = Ipv4Packet::from_bytes(packet) { handle_transport_protocol( - IpAddr::V4(header.get_source()), - IpAddr::V4(header.get_destination()), - header.get_next_level_protocol(), - header.payload(), + IpAddr::V4(ipv4.header.source), + IpAddr::V4(ipv4.header.destination), + ipv4.header.next_level_protocol, + ipv4.payload, ); } else { println!("Malformed IPv4 Packet"); } } -fn handle_ipv6_packet(ethernet: &EthernetPacket) { - let header = Ipv6Packet::new(ethernet.payload()); - if let Some(header) = header { +fn handle_ipv6_packet(packet: Bytes) { + if let Some(ipv6) = Ipv6Packet::from_bytes(packet) { handle_transport_protocol( - IpAddr::V6(header.get_source()), - IpAddr::V6(header.get_destination()), - header.get_next_header(), - header.payload(), + IpAddr::V6(ipv6.header.source), + IpAddr::V6(ipv6.header.destination), + ipv6.header.next_header, + ipv6.payload, ); } else { println!("Malformed IPv6 Packet"); @@ -165,14 +160,14 @@ fn handle_ipv6_packet(ethernet: &EthernetPacket) { fn handle_transport_protocol( source: IpAddr, destination: IpAddr, - protocol: IpNextLevelProtocol, - packet: &[u8], + protocol: IpNextProtocol, + packet: Bytes, ) { match protocol { - IpNextLevelProtocol::Tcp => handle_tcp_packet(source, destination, packet), - IpNextLevelProtocol::Udp => handle_udp_packet(source, destination, packet), - IpNextLevelProtocol::Icmp => handle_icmp_packet(source, destination, packet), - IpNextLevelProtocol::Icmpv6 => handle_icmpv6_packet(source, destination, packet), + IpNextProtocol::Tcp => handle_tcp_packet(source, destination, packet), + IpNextProtocol::Udp => handle_udp_packet(source, destination, packet), + IpNextProtocol::Icmp => handle_icmp_packet(source, destination, packet), + IpNextProtocol::Icmpv6 => handle_icmpv6_packet(source, destination, packet), _ => println!( "Unknown {} packet: {} > {}; protocol: {:?} length: {}", match source { @@ -187,88 +182,155 @@ fn handle_transport_protocol( } } -fn handle_tcp_packet(source: IpAddr, destination: IpAddr, packet: &[u8]) { - let tcp = TcpPacket::new(packet); - if let Some(tcp) = tcp { +fn handle_tcp_packet(source: IpAddr, destination: IpAddr, packet: Bytes) { + if let Some(tcp) = TcpPacket::from_bytes(packet) { println!( "TCP Packet: {}:{} > {}:{}; length: {}", source, - tcp.get_source(), + tcp.header.source, destination, - tcp.get_destination(), - packet.len() + tcp.header.destination, + tcp.total_len(), ); } else { println!("Malformed TCP Packet"); } } -fn handle_udp_packet(source: IpAddr, destination: IpAddr, packet: &[u8]) { - let udp = UdpPacket::new(packet); +fn handle_udp_packet(source: IpAddr, destination: IpAddr, packet: Bytes) { + let udp = UdpPacket::from_bytes(packet); if let Some(udp) = udp { println!( "UDP Packet: {}:{} > {}:{}; length: {}", source, - udp.get_source(), + udp.header.source, destination, - udp.get_destination(), - udp.get_length() + udp.header.destination, + udp.total_len(), ); } else { println!("Malformed UDP Packet"); } } -fn handle_icmp_packet(source: IpAddr, destination: IpAddr, packet: &[u8]) { - let icmp_packet = IcmpPacket::new(packet); +fn handle_icmp_packet(source: IpAddr, destination: IpAddr, packet: Bytes) { + let icmp_packet = IcmpPacket::from_bytes(packet); if let Some(icmp_packet) = icmp_packet { - match icmp_packet.get_icmp_type() { + let total_len = icmp_packet.total_len(); + match icmp_packet.header.icmp_type { + IcmpType::EchoRequest => { + let echo_request_packet = icmp::echo_request::EchoRequestPacket::try_from(icmp_packet).unwrap(); + println!( + "ICMP echo request {} -> {} (seq={:?}, id={:?}), length: {}", + source, + destination, + echo_request_packet.sequence_number, + echo_request_packet.identifier, + total_len + ); + } IcmpType::EchoReply => { - let echo_reply_packet = echo_reply::EchoReplyPacket::new(packet).unwrap(); + let echo_reply_packet = icmp::echo_reply::EchoReplyPacket::try_from(icmp_packet).unwrap(); println!( "ICMP echo reply {} -> {} (seq={:?}, id={:?}), length: {}", source, destination, - echo_reply_packet.get_sequence_number(), - echo_reply_packet.get_identifier(), - packet.len() + echo_reply_packet.sequence_number, + echo_reply_packet.identifier, + total_len, ); } - IcmpType::EchoRequest => { - let echo_request_packet = echo_request::EchoRequestPacket::new(packet).unwrap(); + IcmpType::DestinationUnreachable => { + let unreachable_packet = icmp::destination_unreachable::DestinationUnreachablePacket::try_from(icmp_packet).unwrap(); println!( - "ICMP echo request {} -> {} (seq={:?}, id={:?}), length: {}", + "ICMP destination unreachable {} -> {} (code={:?}), next_hop_mtu={}, length: {}", source, destination, - echo_request_packet.get_sequence_number(), - echo_request_packet.get_identifier(), - packet.len() + unreachable_packet.header.icmp_code, + unreachable_packet.next_hop_mtu, + total_len ); } - _ => println!( - "ICMP packet {} -> {} (type={:?}), length: {}", - source, - destination, - icmp_packet.get_icmp_type(), - packet.len() - ), + IcmpType::TimeExceeded => { + let time_exceeded_packet = icmp::time_exceeded::TimeExceededPacket::try_from(icmp_packet).unwrap(); + println!( + "ICMP time exceeded {} -> {} (code={:?}), length: {}", + source, + destination, + time_exceeded_packet.header.icmp_code, + total_len + ); + } + _ => { + println!( + "ICMP packet {} -> {} (type={:?}), length: {}", + source, + destination, + icmp_packet.header.icmp_type, + total_len + ) + } } } else { println!("Malformed ICMP Packet"); } } -fn handle_icmpv6_packet(source: IpAddr, destination: IpAddr, packet: &[u8]) { - let icmpv6_packet = Icmpv6Packet::new(packet); +fn handle_icmpv6_packet(source: IpAddr, destination: IpAddr, packet: Bytes) { + let icmpv6_packet = Icmpv6Packet::from_bytes(packet); if let Some(icmpv6_packet) = icmpv6_packet { - println!( - "ICMPv6 packet {} -> {} (type={:?}), length: {}", - source, - destination, - icmpv6_packet.get_icmpv6_type(), - packet.len() - ) + match icmpv6_packet.header.icmpv6_type { + nex::packet::icmpv6::Icmpv6Type::EchoRequest => { + let echo_request_packet = icmpv6::echo_request::EchoRequestPacket::try_from(icmpv6_packet).unwrap(); + println!( + "ICMPv6 echo request {} -> {} (type={:?}), length: {}", + source, + destination, + echo_request_packet.header.icmpv6_type, + echo_request_packet.total_len(), + ); + } + nex::packet::icmpv6::Icmpv6Type::EchoReply => { + let echo_reply_packet = icmpv6::echo_reply::EchoReplyPacket::try_from(icmpv6_packet).unwrap(); + println!( + "ICMPv6 echo reply {} -> {} (type={:?}), length: {}", + source, + destination, + echo_reply_packet.header.icmpv6_type, + echo_reply_packet.total_len(), + ); + } + nex::packet::icmpv6::Icmpv6Type::NeighborSolicitation => { + let ns_packet = icmpv6::ndp::NeighborSolicitPacket::try_from(icmpv6_packet).unwrap(); + println!( + "ICMPv6 neighbor solicitation {} -> {} (type={:?}), length: {}", + source, + destination, + ns_packet.header.icmpv6_type, + ns_packet.total_len(), + ); + } + nex::packet::icmpv6::Icmpv6Type::NeighborAdvertisement => { + let na_packet = icmpv6::ndp::NeighborAdvertPacket::try_from(icmpv6_packet).unwrap(); + println!( + "ICMPv6 neighbor advertisement {} -> {} (type={:?}), length: {}", + source, + destination, + na_packet.header.icmpv6_type, + na_packet.total_len(), + ); + } + _ => { + println!( + "ICMPv6 packet {} -> {} (type={:?}), length: {}", + source, + destination, + icmpv6_packet.header.icmpv6_type, + icmpv6_packet.total_len(), + ) + } + } } else { println!("Malformed ICMPv6 Packet"); } diff --git a/examples/icmp_ping.rs b/examples/icmp_ping.rs index 65b7182..e8b5707 100644 --- a/examples/icmp_ping.rs +++ b/examples/icmp_ping.rs @@ -1,181 +1,153 @@ -//! This example sends ICMP Echo request packet to the target socket and waits for ICMP Echo reply packet. +//! Sends ICMP Echo Request and waits for ICMP Echo Reply. //! -//! e.g. +//! Usage: +//! icmp_ping //! -//! IPv4: icmp_ping 1.1.1.1 eth0 -//! -//! IPv6: icmp_ping "2606:4700:4700::1111" eth0 +//! Example: +//! IPv4: icmp_ping 1.1.1.1 eth0 +//! IPv6: icmp_ping "2606:4700:4700::1111" eth0 +use bytes::Bytes; use nex::datalink; use nex::datalink::Channel::Ethernet; use nex::net::interface::Interface; use nex::net::mac::MacAddr; use nex::packet::ethernet::EtherType; -use nex::packet::frame::Frame; -use nex::packet::frame::ParseOption; +use nex::packet::frame::{Frame, ParseOption}; +use nex::packet::builder::icmp::IcmpPacketBuilder; +use nex::packet::builder::icmpv6::Icmpv6PacketBuilder; use nex::packet::icmp::IcmpType; -use nex::packet::ip::IpNextLevelProtocol; -use nex::util::packet_builder::builder::PacketBuilder; -use nex::util::packet_builder::ethernet::EthernetPacketBuilder; -use nex::util::packet_builder::icmp::IcmpPacketBuilder; -use nex::util::packet_builder::icmpv6::Icmpv6PacketBuilder; -use nex::util::packet_builder::ipv4::Ipv4PacketBuilder; -use nex::util::packet_builder::ipv6::Ipv6PacketBuilder; -use nex_packet::icmpv6::Icmpv6Type; +use nex::packet::icmpv6::Icmpv6Type; +use nex::packet::builder::ethernet::EthernetPacketBuilder; +use nex::packet::builder::ipv4::Ipv4PacketBuilder; +use nex::packet::builder::ipv6::Ipv6PacketBuilder; +use nex_packet::{icmp, icmpv6}; +use nex_packet::ip::IpNextProtocol; +use nex_packet::ipv4::Ipv4Flags; +use nex_packet::packet::Packet; use std::env; use std::net::IpAddr; -use std::net::Ipv6Addr; -use std::process; -const USAGE: &str = "USAGE: icmp_ping "; +fn main() { + let interface = match env::args().nth(2) { + Some(name) => nex::net::interface::get_interfaces() + .into_iter() + .find(|i| i.name == name) + .expect("Failed to get interface"), + None => Interface::default().expect("Failed to get default interface"), + }; + let use_tun = interface.is_tun(); -fn get_global_ipv6(interface: &Interface) -> Option { - interface - .ipv6 - .iter() - .find(|ipv6| nex::net::ip::is_global_ipv6(&ipv6.addr())) - .map(|ipv6| ipv6.addr()) -} + let target_ip: IpAddr = env::args() + .nth(1) + .expect("Missing target IP") + .parse() + .expect("Failed to parse target IP"); -fn main() { - let interface: Interface = match env::args().nth(2) { - Some(n) => { - // Use interface specified by the user - let interfaces: Vec = nex::net::interface::get_interfaces(); - let interface: Interface = interfaces - .into_iter() - .find(|interface| interface.name == n) - .expect("Failed to get interface information"); - interface - } - None => { - // Use the default interface - match Interface::default() { - Ok(interface) => interface, - Err(e) => { - println!("Failed to get default interface: {}", e); - process::exit(1); - } - } - } + let (mut tx, mut rx) = match datalink::channel(&interface, Default::default()) { + Ok(Ethernet(tx, rx)) => (tx, rx), + Ok(_) => panic!("Unhandled channel type"), + Err(e) => panic!("Failed to create channel: {}", e), }; - let dst_ip: IpAddr = match env::args().nth(1) { - Some(target_ip) => match target_ip.parse::() { - Ok(ip) => ip, - Err(e) => { - println!("Failed to parse target ip: {}", e); - eprintln!("{USAGE}"); - process::exit(1); - } - }, - None => { - println!("Failed to get target ip"); - eprintln!("{USAGE}"); - process::exit(1); - } + + let src_ip: IpAddr = match target_ip { + IpAddr::V4(_) => interface + .ipv4 + .get(0) + .map(|v| IpAddr::V4(v.addr())) + .expect("No IPv4 address"), + IpAddr::V6(_) => interface + .ipv6 + .iter() + .find(|v| nex::net::ip::is_global_ipv6(&v.addr())) + .map(|v| IpAddr::V6(v.addr())) + .expect("No global IPv6 address"), }; - let use_tun: bool = interface.is_tun(); - let src_ip: IpAddr = match dst_ip { - IpAddr::V4(_) => interface.ipv4[0].addr().into(), - IpAddr::V6(_) => { - let ipv6 = get_global_ipv6(&interface).expect("Failed to get global IPv6 address"); - ipv6.into() - } + + let icmp_packet: Bytes = match (src_ip, target_ip) { + (IpAddr::V4(src), IpAddr::V4(dst)) => IcmpPacketBuilder::new(src, dst) + .icmp_type(IcmpType::EchoRequest) + .icmp_code(icmp::echo_request::IcmpCodes::NoCode) + .echo_fields(0x1234, 0x1) + .payload(Bytes::from_static(b"hello")) + .culculate_checksum() + .build() + .to_bytes(), + (IpAddr::V6(src), IpAddr::V6(dst)) => Icmpv6PacketBuilder::new(src, dst) + .icmpv6_type(Icmpv6Type::EchoRequest) + .icmpv6_code(icmpv6::echo_request::Icmpv6Codes::NoCode) + .echo_fields(0x1234, 0x1) + .payload(Bytes::from_static(b"hello")) + .culculate_checksum() + .build() + .to_bytes(), + _ => panic!("Source and destination IP version mismatch"), }; - // Create a channel to send/receive packet - let (mut tx, mut rx) = match datalink::channel(&interface, Default::default()) { - Ok(Ethernet(tx, rx)) => (tx, rx), - Ok(_) => panic!("parse_frame: unhandled channel type"), - Err(e) => panic!("parse_frame: unable to create channel: {}", e), + let ip_packet = match (src_ip, target_ip) { + (IpAddr::V4(src), IpAddr::V4(dst)) => Ipv4PacketBuilder::new() + .source(src) + .destination(dst) + .protocol(IpNextProtocol::Icmp) + .flags(Ipv4Flags::DontFragment) + .payload(icmp_packet) + .build() + .to_bytes(), + (IpAddr::V6(src), IpAddr::V6(dst)) => Ipv6PacketBuilder::new() + .source(src) + .destination(dst) + .next_header(IpNextProtocol::Icmpv6) + .payload(icmp_packet) + .build() + .to_bytes(), + _ => unreachable!(), }; - // Packet builder for ICMP Echo Request - let mut packet_builder = PacketBuilder::new(); - let ethernet_packet_builder = EthernetPacketBuilder { - src_mac: if use_tun { + let ethernet_packet = EthernetPacketBuilder::new() + .source(if use_tun { MacAddr::zero() } else { interface.mac_addr.clone().unwrap() - }, - dst_mac: if use_tun { + }) + .destination(if use_tun { MacAddr::zero() } else { interface.gateway.clone().unwrap().mac_addr - }, - ether_type: match dst_ip { + }) + .ethertype(match target_ip { IpAddr::V4(_) => EtherType::Ipv4, IpAddr::V6(_) => EtherType::Ipv6, - }, - }; - packet_builder.set_ethernet(ethernet_packet_builder); - - match dst_ip { - IpAddr::V4(dst_ipv4) => match src_ip { - IpAddr::V4(src_ipv4) => { - let ipv4_packet_builder = - Ipv4PacketBuilder::new(src_ipv4, dst_ipv4, IpNextLevelProtocol::Icmp); - packet_builder.set_ipv4(ipv4_packet_builder); - } - IpAddr::V6(_) => {} - }, - IpAddr::V6(dst_ipv6) => match src_ip { - IpAddr::V4(_) => {} - IpAddr::V6(src_ipv4) => { - let ipv6_packet_builder = - Ipv6PacketBuilder::new(src_ipv4, dst_ipv6, IpNextLevelProtocol::Icmpv6); - packet_builder.set_ipv6(ipv6_packet_builder); - } - }, - } + }) + .payload(ip_packet) + .build(); - match dst_ip { - IpAddr::V4(dst_ipv4) => match src_ip { - IpAddr::V4(src_ipv4) => { - let mut icmp_packet_builder = IcmpPacketBuilder::new(src_ipv4, dst_ipv4); - icmp_packet_builder.icmp_type = IcmpType::EchoRequest; - packet_builder.set_icmp(icmp_packet_builder); - } - IpAddr::V6(_) => {} - }, - IpAddr::V6(dst_ipv6) => match src_ip { - IpAddr::V4(_) => {} - IpAddr::V6(src_ipv6) => { - let mut icmpv6_packet_builder = Icmpv6PacketBuilder::new(src_ipv6, dst_ipv6); - icmpv6_packet_builder.icmpv6_type = Icmpv6Type::EchoRequest; - packet_builder.set_icmpv6(icmpv6_packet_builder); - } - }, - } - - // Send ICMP Echo Request packets - let packet: Vec = if use_tun { - packet_builder.ip_packet() + let packet = if use_tun { + ethernet_packet.ip_packet().unwrap() } else { - packet_builder.packet() + ethernet_packet.to_bytes() }; + match tx.send(&packet) { Some(_) => println!("Packet sent"), None => println!("Failed to send packet"), } - // Receive ICMP Echo Reply packets - println!("Waiting for ICMP Echo Reply packets..."); + println!("Waiting for ICMP Echo Reply..."); loop { match rx.next() { Ok(packet) => { - let mut parse_option: ParseOption = ParseOption::default(); + let mut parse_option = ParseOption::default(); if interface.is_tun() { - let payload_offset = if interface.is_loopback() { 14 } else { 0 }; parse_option.from_ip_packet = true; - parse_option.offset = payload_offset; + parse_option.offset = if interface.is_loopback() { 14 } else { 0 }; } - let frame: Frame = Frame::from_bytes(&packet, parse_option); + let frame = Frame::from_buf(&packet, parse_option).unwrap(); + if let Some(ip_layer) = &frame.ip { - if let Some(icmp_packet) = &ip_layer.icmp { - if icmp_packet.icmp_type == IcmpType::EchoReply { + if let Some(icmp) = &ip_layer.icmp { + if icmp.icmp_type == IcmpType::EchoReply { println!( - "Received ICMP Echo Reply packet from {}", + "Received ICMP Echo Reply from {}", ip_layer.ipv4.as_ref().unwrap().source ); println!( @@ -183,14 +155,14 @@ fn main() { interface.name, packet.len() ); - println!("Packet Frame: {:?}", frame); + println!("Frame: {:?}", frame); break; } } - if let Some(icmpv6_packet) = &ip_layer.icmpv6 { - if icmpv6_packet.icmpv6_type == Icmpv6Type::EchoReply { + if let Some(icmpv6) = &ip_layer.icmpv6 { + if icmpv6.icmpv6_type == Icmpv6Type::EchoReply { println!( - "Received ICMPv6 Echo Reply packet from {}", + "Received ICMPv6 Echo Reply from {}", ip_layer.ipv6.as_ref().unwrap().source ); println!( @@ -198,13 +170,13 @@ fn main() { interface.name, packet.len() ); - println!("Packet Frame: {:?}", frame); + println!("Frame: {:?}", frame); break; } } } } - Err(e) => println!("Failed to receive packet: {}", e), + Err(e) => eprintln!("Failed to receive: {}", e), } } } diff --git a/examples/icmp_socket.rs b/examples/icmp_socket.rs new file mode 100644 index 0000000..486e197 --- /dev/null +++ b/examples/icmp_socket.rs @@ -0,0 +1,59 @@ +//! Ping using IcmpSocket +//! +//! Usage: icmp_socket + +use bytes::Bytes; +use nex_socket::icmp::{IcmpConfig, IcmpKind, IcmpSocket}; +use nex_packet::builder::icmp::IcmpPacketBuilder; +use nex_packet::builder::icmpv6::Icmpv6PacketBuilder; +use nex_packet::{icmp, icmpv6}; +use nex::net::interface::{Interface, get_interfaces}; +use std::env; +use std::net::{IpAddr, SocketAddr}; + +fn main() -> std::io::Result<()> { + let target_ip: IpAddr = env::args().nth(1).expect("Missing target IP").parse().expect("parse ip"); + let interface = match env::args().nth(2) { + Some(name) => get_interfaces().into_iter().find(|i| i.name == name).expect("interface not found"), + None => Interface::default().expect("default interface"), + }; + + let src_ip = match target_ip { + IpAddr::V4(_) => interface.ipv4.get(0).map(|v| IpAddr::V4(v.addr())).expect("No IPv4 address"), + IpAddr::V6(_) => interface + .ipv6 + .iter() + .find(|v| nex::net::ip::is_global_ipv6(&v.addr())) + .map(|v| IpAddr::V6(v.addr())) + .expect("No global IPv6 address"), + }; + + let kind = if target_ip.is_ipv4() { IcmpKind::V4 } else { IcmpKind::V6 }; + let socket = IcmpSocket::new(&IcmpConfig::new(kind))?; + + let packet = match (src_ip, target_ip) { + (IpAddr::V4(src), IpAddr::V4(dst)) => IcmpPacketBuilder::new(src, dst) + .icmp_type(nex_packet::icmp::IcmpType::EchoRequest) + .icmp_code(icmp::echo_request::IcmpCodes::NoCode) + .echo_fields(0x1234, 1) + .payload(Bytes::from_static(b"hello")) + .culculate_checksum() + .to_bytes(), + (IpAddr::V6(src), IpAddr::V6(dst)) => Icmpv6PacketBuilder::new(src, dst) + .icmpv6_type(nex_packet::icmpv6::Icmpv6Type::EchoRequest) + .icmpv6_code(icmpv6::echo_request::Icmpv6Codes::NoCode) + .echo_fields(0x1234, 1) + .payload(Bytes::from_static(b"hello")) + .culculate_checksum() + .to_bytes(), + _ => unreachable!(), + }; + + socket.send_to(&packet, SocketAddr::new(target_ip, 0))?; + println!("Sent echo request to {}", target_ip); + + let mut buf = [0u8; 1500]; + let (_n, from) = socket.recv_from(&mut buf)?; + println!("Received reply from {}", from.ip()); + Ok(()) +} diff --git a/examples/list_interfaces.rs b/examples/list_interfaces.rs deleted file mode 100644 index c1b3b52..0000000 --- a/examples/list_interfaces.rs +++ /dev/null @@ -1,47 +0,0 @@ -//! This example shows all interfaces and their properties. -//! -//! If you want to focus on network interfaces, -//! you can use the netdev -//! https://github.com/shellrow/netdev - -fn main() { - let interfaces = nex::net::interface::get_interfaces(); - for interface in interfaces { - println!("Interface:"); - println!("\tIndex: {}", interface.index); - println!("\tName: {}", interface.name); - println!("\tFriendly Name: {:?}", interface.friendly_name); - println!("\tDescription: {:?}", interface.description); - println!("\tType: {}", interface.if_type.name()); - println!("\tFlags: {:?}", interface.flags); - println!("\t\tis UP {}", interface.is_up()); - println!("\t\tis LOOPBACK {}", interface.is_loopback()); - println!("\t\tis MULTICAST {}", interface.is_multicast()); - println!("\t\tis BROADCAST {}", interface.is_broadcast()); - println!("\t\tis POINT TO POINT {}", interface.is_point_to_point()); - println!("\t\tis TUN {}", interface.is_tun()); - println!("\t\tis RUNNING {}", interface.is_running()); - println!("\t\tis PHYSICAL {}", interface.is_physical()); - if let Some(mac_addr) = interface.mac_addr { - println!("\tMAC Address: {}", mac_addr); - } else { - println!("\tMAC Address: (Failed to get mac address)"); - } - println!("\tIPv4: {:?}", interface.ipv4); - - // Print the IPv6 addresses with the scope ID after them as a suffix - let ipv6_strs: Vec = interface - .ipv6 - .iter() - .zip(interface.ipv6_scope_ids) - .map(|(ipv6, scope_id)| format!("{:?}%{}", ipv6, scope_id)) - .collect(); - println!("\tIPv6: [{}]", ipv6_strs.join(", ")); - - println!("\tTransmit Speed: {:?}", interface.transmit_speed); - println!("\tReceive Speed: {:?}", interface.receive_speed); - println!("MTU: {:?}", interface.mtu); - println!("Default: {}", interface.default); - println!(); - } -} diff --git a/examples/ndp.rs b/examples/ndp.rs index 7d47afe..30583a6 100644 --- a/examples/ndp.rs +++ b/examples/ndp.rs @@ -1,154 +1,140 @@ -//! This example sends NDP packet to the target and waits for NDP NeighborAdvertisement packets. +//! Sends NDP Neighbor Solicitation and waits for Neighbor Advertisement. //! -//! e.g. +//! Usage: +//! ndp //! -//! ndp "fe80::6284:bdff:fe95:ca80" eth0 +//! Example: +//! ndp "fe80::6284:bdff:fe95:ca80" eth0 use nex::datalink; use nex::datalink::Channel::Ethernet; -use nex::net::interface::Interface; +use nex::net::interface::{get_interfaces, Interface}; use nex::net::mac::MacAddr; use nex::packet::ethernet::EtherType; -use nex::packet::ethernet::MAC_ADDR_LEN; -use nex::packet::frame::Frame; -use nex::packet::frame::ParseOption; -use nex::packet::icmpv6::ndp::{NDP_OPT_PACKET_LEN, NDP_SOL_PACKET_LEN}; +use nex::packet::frame::{Frame, ParseOption}; use nex::packet::icmpv6::Icmpv6Type; -use nex::packet::ip::IpNextLevelProtocol; -use nex::util::packet_builder::builder::PacketBuilder; -use nex::util::packet_builder::ethernet::EthernetPacketBuilder; -use nex::util::packet_builder::ipv6::Ipv6PacketBuilder; -use nex::util::packet_builder::ndp::NdpPacketBuilder; +use nex::packet::ip::IpNextProtocol; +use nex::packet::builder::ethernet::EthernetPacketBuilder; +use nex::packet::builder::ipv6::Ipv6PacketBuilder; +use nex_packet::packet::Packet; +use nex_packet::builder::ndp::NdpPacketBuilder; use std::env; -use std::net::IpAddr; -use std::net::Ipv6Addr; +use std::net::{IpAddr, Ipv6Addr}; use std::process; -const USAGE: &str = "USAGE: ndp "; +/// Compute multicast MAC address from solicited-node multicast IPv6 address +fn ipv6_multicast_mac(ipv6: &Ipv6Addr) -> MacAddr { + let segments = ipv6.segments(); + MacAddr::new( + 0x33, + 0x33, + ((segments[6] >> 8) & 0xff) as u8, + (segments[6] & 0xff) as u8, + ((segments[7] >> 8) & 0xff) as u8, + (segments[7] & 0xff) as u8, + ) +} fn main() { - let interface: Interface = match env::args().nth(2) { - Some(n) => { - // Use interface specified by the user - let interfaces: Vec = nex::net::interface::get_interfaces(); - let interface: Interface = interfaces - .into_iter() - .find(|interface| interface.name == n) - .expect("Failed to get interface information"); - interface - } - None => { - // Use the default interface - match Interface::default() { - Ok(interface) => interface, - Err(e) => { - println!("Failed to get default interface: {}", e); - process::exit(1); - } - } - } - }; - let dst_ip: Ipv6Addr = match env::args().nth(1) { - Some(target_ip) => match target_ip.parse::() { - Ok(ip) => match ip { - IpAddr::V4(_) => { - println!("IPv4 is not supported"); - eprintln!("{USAGE}"); - process::exit(1); - } - IpAddr::V6(ipv6) => ipv6, - }, - Err(e) => { - println!("Failed to parse target ip: {}", e); - eprintln!("{USAGE}"); - process::exit(1); - } - }, - None => { - println!("Failed to get target ip"); - eprintln!("{USAGE}"); + let args: Vec = env::args().collect(); + if args.len() < 2 { + eprintln!("Usage: ndp "); + process::exit(1); + } + + let target_ip: Ipv6Addr = match args[1].parse() { + Ok(IpAddr::V6(addr)) => addr, + _ => { + eprintln!("Please provide a valid IPv6 address"); process::exit(1); } }; + let interface = match env::args().nth(2) { + Some(name) => get_interfaces() + .into_iter() + .find(|i| i.name == name) + .expect("Failed to get interface"), + None => Interface::default().expect("Failed to get default interface"), + }; + let src_ip: Ipv6Addr = interface.ipv6[0].addr(); - // Create a channel to send/receive packet + let src_mac = interface.mac_addr.expect("No MAC address on interface"); + let dst_mac = ipv6_multicast_mac(&target_ip); + let (mut tx, mut rx) = match datalink::channel(&interface, Default::default()) { Ok(Ethernet(tx, rx)) => (tx, rx), - Ok(_) => panic!("parse_frame: unhandled channel type"), - Err(e) => panic!("parse_frame: unable to create channel: {}", e), + Ok(_) => panic!("Unsupported channel type"), + Err(e) => panic!("Failed to create datalink channel: {}", e), }; - // Packet builder for ICMP Echo Request - let mut packet_builder = PacketBuilder::new(); - let ethernet_packet_builder = EthernetPacketBuilder { - src_mac: interface.mac_addr.clone().unwrap(), - dst_mac: MacAddr::broadcast(), - ether_type: EtherType::Ipv6, - }; - packet_builder.set_ethernet(ethernet_packet_builder); - - let mut ipv6_packet_builder = - Ipv6PacketBuilder::new(src_ip, dst_ip, IpNextLevelProtocol::Icmpv6); - ipv6_packet_builder.payload_length = - Some((NDP_SOL_PACKET_LEN + NDP_OPT_PACKET_LEN + MAC_ADDR_LEN) as u16); - ipv6_packet_builder.hop_limit = Some(u8::MAX); - packet_builder.set_ipv6(ipv6_packet_builder); - - let ndp_packet_builder = - NdpPacketBuilder::new(interface.mac_addr.clone().unwrap(), src_ip, dst_ip); - packet_builder.set_ndp(ndp_packet_builder); - - // Send NDP NeighborSolicitation packets - match tx.send(&packet_builder.packet()) { - Some(_) => println!("NDP Packet sent"), - None => println!("Failed to send packet"), + // Build NDP packet + //let ndp_payload_len = (NDP_SOL_PACKET_LEN + NDP_OPT_PACKET_LEN + MAC_ADDR_LEN) as u16; + + let ipv6 = Ipv6PacketBuilder::new() + .source(src_ip) + .destination(target_ip) + .next_header(IpNextProtocol::Icmpv6) + .hop_limit(255); + + let ndp = NdpPacketBuilder::new(src_mac, src_ip, target_ip); + + let ethernet = EthernetPacketBuilder::new() + .source(src_mac) + .destination(dst_mac) + .ethertype(EtherType::Ipv6) + .payload(ipv6.payload(ndp.build().to_bytes()).build().to_bytes()); + + // Send NDP Neighbor Solicitation + let packet = ethernet.build().to_bytes(); + + if tx.send(&packet).is_some() { + println!("NDP Neighbor Solicitation sent to {}", target_ip); + } else { + eprintln!("Failed to send NDP packet"); + return; } - // Receive NDP Neighbor Advertisement packets - println!("Waiting for NDP Neighbor Advertisement packets..."); + println!("Waiting for Neighbor Advertisement..."); + loop { match rx.next() { Ok(packet) => { - let mut parse_option: ParseOption = ParseOption::default(); + let mut parse_option = ParseOption::default(); if interface.is_tun() { - let payload_offset = if interface.is_loopback() { 14 } else { 0 }; parse_option.from_ip_packet = true; - parse_option.offset = payload_offset; + parse_option.offset = if interface.is_loopback() { 14 } else { 0 }; } - let frame: Frame = Frame::from_bytes(&packet, parse_option); - if let Some(ip_layer) = &frame.ip { - if let Some(icmpv6_packet) = &ip_layer.icmpv6 { - if icmpv6_packet.icmpv6_type == Icmpv6Type::NeighborAdvertisement { - println!( - "Received NDP Neighbor Advertisement packet from {}", - ip_layer.ipv6.as_ref().unwrap().source - ); - println!( - "MAC address: {}", - frame - .datalink - .as_ref() - .unwrap() - .ethernet - .as_ref() - .unwrap() - .source - .address() - ); - println!( - "---- Interface: {}, Total Length: {} bytes ----", - interface.name, - packet.len() - ); - println!("Packet Frame: {:?}", frame); - break; + + if let Some(frame) = Frame::from_buf(&packet, parse_option) { + if let Some(ip_layer) = &frame.ip { + if let Some(icmpv6) = &ip_layer.icmpv6 { + if icmpv6.icmpv6_type == Icmpv6Type::NeighborAdvertisement { + if let Some(ipv6_hdr) = &ip_layer.ipv6 { + println!( + "Received Neighbor Advertisement from {}", + ipv6_hdr.source + ); + if let Some(dlink) = &frame.datalink { + if let Some(eth) = &dlink.ethernet { + println!("MAC address: {}", eth.source.address()); + } + } + println!( + "---- Interface: {}, Total Length: {} bytes ----", + interface.name, + packet.len() + ); + println!("Frame: {:?}", frame); + break; + } + } } } } } - Err(e) => println!("Failed to receive packet: {}", e), + Err(e) => eprintln!("Receive failed: {}", e), } } } diff --git a/examples/parse_frame.rs b/examples/parse_frame.rs index 75d3782..80f7194 100644 --- a/examples/parse_frame.rs +++ b/examples/parse_frame.rs @@ -64,10 +64,70 @@ fn main() { parse_option.from_ip_packet = true; parse_option.offset = payload_offset; } - let frame: Frame = Frame::from_bytes(&packet, parse_option); - println!("Packet Frame: {:?}", frame); + match Frame::from_buf(&packet, parse_option) { + Some(frame) => { + display_frame(&frame); + } + None => { + println!("Failed to parse packet as Frame"); + } + } } Err(e) => panic!("parse_frame: unable to receive packet: {}", e), } } } + +pub fn display_frame(frame: &Frame) { + println!("Packet Frame ({} bytes)", frame.packet_len); + + if let Some(dl) = &frame.datalink { + if let Some(eth) = &dl.ethernet { + println!(" Ethernet: {} > {} ({:?})", eth.source, eth.destination, eth.ethertype); + } + if let Some(arp) = &dl.arp { + println!( + " ARP: {}({}) > {}({}); operation: {:?}", + arp.sender_hw_addr, + arp.sender_proto_addr, + arp.target_hw_addr, + arp.target_proto_addr, + arp.operation + ); + } + } + + if let Some(ip) = &frame.ip { + if let Some(ipv4) = &ip.ipv4 { + println!( + " IPv4: {} -> {} (protocol: {:?})", + ipv4.source, ipv4.destination, ipv4.next_level_protocol + ); + } + if let Some(ipv6) = &ip.ipv6 { + println!( + " IPv6: {} -> {} (next header: {:?})", + ipv6.source, ipv6.destination, ipv6.next_header + ); + } + if ip.icmp.is_some() { + println!(" ICMP: present"); + } + if ip.icmpv6.is_some() { + println!(" ICMPv6: present"); + } + } + + if let Some(tp) = &frame.transport { + if let Some(tcp) = &tp.tcp { + println!(" TCP: {} -> {}", tcp.source, tcp.destination); + } + if let Some(udp) = &tp.udp { + println!(" UDP: {} -> {}", udp.source, udp.destination); + } + } + + if !frame.payload.is_empty() { + println!(" Payload: {} bytes", frame.payload.len()); + } +} diff --git a/examples/serialize.rs b/examples/serialize.rs deleted file mode 100644 index 8542678..0000000 --- a/examples/serialize.rs +++ /dev/null @@ -1,77 +0,0 @@ -//! Basic packet capture using nex -//! -//! Parse packet as Frame and print it as JSON format - -use nex::datalink; -use nex::net::interface::Interface; -use nex::packet::frame::Frame; -use nex::packet::frame::ParseOption; -use std::env; -use std::process; - -fn main() { - use nex::datalink::Channel::Ethernet; - let interface: Interface = match env::args().nth(1) { - Some(n) => { - // Use interface specified by user - let interfaces: Vec = nex::net::interface::get_interfaces(); - let interface: Interface = interfaces - .into_iter() - .find(|interface| interface.name == n) - .expect("Failed to get interface information"); - interface - } - None => { - // Use default interface - match Interface::default() { - Ok(interface) => interface, - Err(e) => { - println!("Failed to get default interface: {}", e); - process::exit(1); - } - } - } - }; - - // Create a channel to receive packet - let (mut _tx, mut rx) = match datalink::channel(&interface, Default::default()) { - Ok(Ethernet(tx, rx)) => (tx, rx), - Ok(_) => panic!("parse_frame: unhandled channel type"), - Err(e) => panic!("parse_frame: unable to create channel: {}", e), - }; - let mut capture_no: usize = 0; - loop { - match rx.next() { - Ok(packet) => { - capture_no += 1; - println!( - "---- Interface: {}, No.: {}, Total Length: {} bytes ----", - interface.name, - capture_no, - packet.len() - ); - let mut parse_option: ParseOption = ParseOption::default(); - if interface.is_tun() { - let payload_offset; - if interface.is_loopback() { - payload_offset = 14; - } else { - payload_offset = 0; - } - parse_option.from_ip_packet = true; - parse_option.offset = payload_offset; - } - let frame: Frame = Frame::from_bytes(&packet, parse_option); - match serde_json::to_string(&frame) { - Ok(json) => { - println!("{}", json); - } - Err(e) => { - println!("Serialization Error: {}", e); - } - } - } - Err(e) => panic!("parse_frame: unable to receive packet: {}", e), - } - } -} diff --git a/examples/tcp_ping.rs b/examples/tcp_ping.rs index 5a65597..67d4c24 100644 --- a/examples/tcp_ping.rs +++ b/examples/tcp_ping.rs @@ -1,24 +1,29 @@ -//! This example sends TCP SYN packet to the target socket and waits for TCP SYN+ACK or RST+ACK packet. +//! Sends TCP SYN packet to the target socket and waits for TCP SYN+ACK or RST+ACK packet. //! -//! e.g. +//! Usage: +//! tcp_ping +//! +//! Example: //! //! IPv4: tcp_ping 1.1.1.1:80 eth0 //! //! IPv6: tcp_ping "[2606:4700:4700::1111]:80" eth0 +use bytes::Bytes; use nex::datalink; use nex::datalink::Channel::Ethernet; use nex::net::interface::Interface; use nex::net::mac::MacAddr; use nex::packet::ethernet::EtherType; use nex::packet::frame::{Frame, ParseOption}; -use nex::packet::ip::IpNextLevelProtocol; -use nex::packet::tcp::{TcpFlags, TcpOption}; -use nex::util::packet_builder::builder::PacketBuilder; -use nex::util::packet_builder::ethernet::EthernetPacketBuilder; -use nex::util::packet_builder::ipv4::Ipv4PacketBuilder; -use nex::util::packet_builder::ipv6::Ipv6PacketBuilder; -use nex::util::packet_builder::tcp::TcpPacketBuilder; +use nex::packet::ip::IpNextProtocol; +use nex::packet::tcp::{TcpFlags, TcpOptionPacket}; +use nex::packet::builder::ethernet::EthernetPacketBuilder; +use nex::packet::builder::ipv4::Ipv4PacketBuilder; +use nex::packet::builder::ipv6::Ipv6PacketBuilder; +use nex::packet::builder::tcp::TcpPacketBuilder; +use nex_packet::ipv4::Ipv4Flags; +use nex_packet::packet::Packet; use std::env; use std::net::{IpAddr, SocketAddr}; use std::process; @@ -71,49 +76,27 @@ fn main() { Err(e) => panic!("Failed to create channel: {}", e), }; - // Packet builder for TCP SYN - let mut packet_builder = PacketBuilder::new(); - let ethernet_packet_builder = EthernetPacketBuilder { - src_mac: if use_tun { - MacAddr::zero() - } else { - interface.mac_addr.clone().unwrap() - }, - dst_mac: if use_tun { - MacAddr::zero() - } else { - interface.gateway.clone().unwrap().mac_addr - }, - ether_type: match target_socket.ip() { - IpAddr::V4(_) => EtherType::Ipv4, - IpAddr::V6(_) => EtherType::Ipv6, - }, - }; - packet_builder.set_ethernet(ethernet_packet_builder); - - match target_socket.ip() { - IpAddr::V4(dst_ipv4) => match interface.ipv4.get(0) { - Some(src_ipv4) => { - let ipv4_packet_builder = - Ipv4PacketBuilder::new(src_ipv4.addr(), dst_ipv4, IpNextLevelProtocol::Tcp); - packet_builder.set_ipv4(ipv4_packet_builder); - } - None => { - println!("No IPv4 address on the interface"); - process::exit(1); + let dst_ip = target_socket.ip(); + let src_ip: IpAddr; + match dst_ip { + IpAddr::V4(_) => { + // For IPv4, use the first IPv4 address of the interface + match interface.ipv4.get(0) { + Some(ipv4) => src_ip = IpAddr::V4(ipv4.addr()), + None => { + println!("No IPv4 address on the interface"); + process::exit(1); + } } - }, - IpAddr::V6(dst_ipv6) => { + } + IpAddr::V6(_) => { + // For IPv6, use the first global IPv6 address of the interface match interface .ipv6 .iter() .find(|ipv6| nex::net::ip::is_global_ipv6(&ipv6.addr())) { - Some(src_ipv6) => { - let ipv6_packet_builder = - Ipv6PacketBuilder::new(src_ipv6.addr(), dst_ipv6, IpNextLevelProtocol::Tcp); - packet_builder.set_ipv6(ipv6_packet_builder); - } + Some(ipv6) => src_ip = IpAddr::V6(ipv6.addr()), None => { println!("No global IPv6 address on the interface"); process::exit(1); @@ -122,62 +105,87 @@ fn main() { } } - match target_socket.ip() { - IpAddr::V4(_dst_ipv4) => match interface.ipv4.get(0) { - Some(src_ipv4) => { - let mut tcp_packet_builder = TcpPacketBuilder::new( - SocketAddr::new(IpAddr::V4(src_ipv4.addr()), 53443), - target_socket, - ); - tcp_packet_builder.flags = TcpFlags::SYN; - tcp_packet_builder.options = vec![ - TcpOption::mss(1460), - TcpOption::sack_perm(), - TcpOption::nop(), - TcpOption::nop(), - TcpOption::wscale(7), - ]; - packet_builder.set_tcp(tcp_packet_builder); - } - None => { - println!("No IPv4 address on the interface"); - process::exit(1); + // Packet builder for TCP SYN + let tcp_packet = TcpPacketBuilder::new() + .source(53443) + .destination(target_socket.port()) + .flags(TcpFlags::SYN) + .window(64240) + .options(vec![ + TcpOptionPacket::mss(1460), + TcpOptionPacket::sack_perm(), + TcpOptionPacket::nop(), + TcpOptionPacket::nop(), + TcpOptionPacket::wscale(7), + ]) + .culculate_checksum(&src_ip, &dst_ip) + .build(); + + let ip_packet: Bytes; + match dst_ip { + IpAddr::V4(dst_ipv4) => { + match src_ip { + IpAddr::V4(src_ipv4) => { + // Use the source IPv4 address + let ipv4_packet = Ipv4PacketBuilder::new() + .source(src_ipv4) + .destination(dst_ipv4) + .protocol(IpNextProtocol::Tcp) + .flags(Ipv4Flags::DontFragment) + .payload(tcp_packet.to_bytes()) + .build(); + ip_packet = ipv4_packet.to_bytes(); + } + IpAddr::V6(_) => { + println!("Source IP must be IPv4 for IPv4 destination"); + process::exit(1); + } } + }, - IpAddr::V6(_dst_ipv6) => { - match interface - .ipv6 - .iter() - .find(|ipv6| nex::net::ip::is_global_ipv6(&ipv6.addr())) - { - Some(src_ipv6) => { - let mut tcp_packet_builder = TcpPacketBuilder::new( - SocketAddr::new(IpAddr::V6(src_ipv6.addr()), 53443), - target_socket, - ); - tcp_packet_builder.flags = TcpFlags::SYN; - tcp_packet_builder.options = vec![ - TcpOption::mss(1460), - TcpOption::sack_perm(), - TcpOption::nop(), - TcpOption::nop(), - TcpOption::wscale(7), - ]; - packet_builder.set_tcp(tcp_packet_builder); - } - None => { - println!("No global IPv6 address on the interface"); + IpAddr::V6(dst_ipv6) => { + match src_ip { + IpAddr::V4(_) => { + println!("Source IP must be IPv6 for IPv6 destination"); process::exit(1); } + IpAddr::V6(src_ipv6) => { + // Use the source IPv6 address + let ipv6_packet = Ipv6PacketBuilder::new() + .source(src_ipv6) + .destination(dst_ipv6) + .next_header(IpNextProtocol::Tcp) + .payload(tcp_packet.to_bytes()) + .build(); + ip_packet = ipv6_packet.to_bytes(); + } } } } + let ethernet_packet = EthernetPacketBuilder::new() + .source(if use_tun { + MacAddr::zero() + } else { + interface.mac_addr.clone().unwrap() + }) + .destination(if use_tun { + MacAddr::zero() + } else { + interface.gateway.clone().unwrap().mac_addr + }) + .ethertype(match target_socket.ip() { + IpAddr::V4(_) => EtherType::Ipv4, + IpAddr::V6(_) => EtherType::Ipv6, + }) + .payload(ip_packet) + .build(); + // Send TCP SYN packets - let packet: Vec = if use_tun { - packet_builder.ip_packet() + let packet: Bytes = if use_tun { + ethernet_packet.ip_packet().unwrap() } else { - packet_builder.packet() + ethernet_packet.to_bytes() }; match tx.send(&packet) { Some(_) => println!("Packet sent"), @@ -195,7 +203,7 @@ fn main() { parse_option.from_ip_packet = true; parse_option.offset = payload_offset; } - let frame: Frame = Frame::from_bytes(&packet, parse_option); + let frame: Frame = Frame::from_buf(&packet, parse_option).unwrap(); // Check each layer. If the packet is TCP SYN+ACK or RST+ACK, print it out if let Some(ip_layer) = &frame.ip { if let Some(transport_layer) = &frame.transport { diff --git a/examples/tcp_socket.rs b/examples/tcp_socket.rs new file mode 100644 index 0000000..d729610 --- /dev/null +++ b/examples/tcp_socket.rs @@ -0,0 +1,29 @@ +//! Simple TCP connect using TcpSocket +//! +//! Usage: tcp_socket + +use nex_socket::tcp::TcpSocket; +use std::env; +use std::io::{Read, Write}; +use std::net::{IpAddr, SocketAddr}; + +fn main() -> std::io::Result<()> { + let ip: IpAddr = env::args().nth(1).expect("IP").parse().expect("ip"); + let port: u16 = env::args().nth(2).unwrap_or_else(|| "80".into()).parse().expect("port"); + let addr = SocketAddr::new(ip, port); + + let socket = match addr { + SocketAddr::V4(_) => TcpSocket::v4_stream()?, + SocketAddr::V6(_) => TcpSocket::v6_stream()?, + }; + socket.connect(addr)?; + let mut stream = socket.to_tcp_stream()?; + + let req = format!("GET / HTTP/1.1\r\nHost: {}\r\n\r\n", ip); + stream.write_all(req.as_bytes())?; + + let mut buf = [0u8; 512]; + let n = stream.read(&mut buf)?; + println!("Received {} bytes:\n{}", n, String::from_utf8_lossy(&buf[..n])); + Ok(()) +} diff --git a/examples/tcp_stream.rs b/examples/tcp_stream.rs deleted file mode 100644 index 2667639..0000000 --- a/examples/tcp_stream.rs +++ /dev/null @@ -1,47 +0,0 @@ -use std::net::{IpAddr, Ipv4Addr, Shutdown, SocketAddr}; - -use nex_packet::ip::IpNextLevelProtocol; -use nex_socket::{IpVersion, Socket, SocketOption, SocketType}; - -fn main() { - let socket_option = SocketOption { - ip_version: IpVersion::V4, - socket_type: SocketType::Stream, - protocol: Some(IpNextLevelProtocol::Tcp), - non_blocking: false, - }; - let socket = Socket::new(socket_option).unwrap(); - println!("Socket created"); - println!("Connecting to 1.1.1.1:80 ..."); - let ip_addr: IpAddr = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)); - match socket.connect(&SocketAddr::new(ip_addr, 80)) { - Ok(_) => { - println!("Connected to 1.1.1.1:80"); - let req = format!("GET / HTTP/1.1\r\nHost: {}\r\n\r\n", ip_addr.to_string()); - println!("Sending data (HTTP Request) ..."); - match socket.write(req.as_bytes()) { - Ok(n) => println!("{} bytes sent (payload)", n), - Err(e) => println!("{}", e), - } - let mut res = vec![0; 1024]; - println!("Receiving data ..."); - - match socket.read(&mut res) { - Ok(n) => { - println!("{} bytes received (HTTP Response):", n); - println!("----------------------------------------"); - println!("{}", String::from_utf8_lossy(&res[..n])); - println!("----------------------------------------"); - } - Err(e) => println!("{}", e), - } - - println!("Closing socket ..."); - match socket.shutdown(Shutdown::Both) { - Ok(_) => println!("Socket closed"), - Err(e) => println!("{}", e), - } - } - Err(e) => println!("{}", e), - } -} diff --git a/examples/udp_ping.rs b/examples/udp_ping.rs index 2c52d50..27f75cb 100644 --- a/examples/udp_ping.rs +++ b/examples/udp_ping.rs @@ -5,8 +5,9 @@ //! //! Example: //! IPv4: udp_ping 1.1.1.1 eth0 -//! IPv6: udp_ping 2606:4700:4700::1111 eth0 +//! IPv6: udp_ping "2606:4700:4700::1111" eth0 +use bytes::Bytes; use nex::datalink; use nex::datalink::Channel::Ethernet; use nex::net::interface::Interface; @@ -15,208 +16,147 @@ use nex::packet::ethernet::EtherType; use nex::packet::frame::{Frame, ParseOption}; use nex::packet::icmp::IcmpType; use nex::packet::icmpv6::Icmpv6Type; -use nex::packet::ip::IpNextLevelProtocol; -use nex::util::packet_builder::builder::PacketBuilder; -use nex::util::packet_builder::ethernet::EthernetPacketBuilder; -use nex::util::packet_builder::ipv4::Ipv4PacketBuilder; -use nex::util::packet_builder::ipv6::Ipv6PacketBuilder; -use nex::util::packet_builder::udp::UdpPacketBuilder; +use nex::packet::builder::ethernet::EthernetPacketBuilder; +use nex::packet::builder::ipv4::Ipv4PacketBuilder; +use nex::packet::builder::ipv6::Ipv6PacketBuilder; +use nex::packet::builder::udp::UdpPacketBuilder; +use nex_packet::ip::IpNextProtocol; +use nex_packet::ipv4::Ipv4Flags; +use nex_packet::packet::Packet; use std::env; -use std::net::{IpAddr, SocketAddr}; -use std::process; - -const USAGE: &str = "USAGE: udp_ping "; +use std::net::IpAddr; const SRC_PORT: u16 = 53443; const DST_PORT: u16 = 33435; fn main() { let interface: Interface = match env::args().nth(2) { - Some(n) => { - // Use the interface specified by the user - let interfaces: Vec = nex::net::interface::get_interfaces(); - let interface = interfaces - .into_iter() - .find(|interface| interface.name == n) - .expect("Failed to get interface information"); - interface - } - None => { - // Use the default interface - match Interface::default() { - Ok(interface) => interface, - Err(e) => { - println!("Failed to get the default interface: {}", e); - process::exit(1); - } - } - } - }; - let use_tun: bool = interface.is_tun(); - let target_ip: IpAddr = match env::args().nth(1) { - Some(target_ip_str) => match target_ip_str.parse() { - Ok(ip) => ip, - Err(e) => { - println!("Failed to parse the target IP: {}", e); - eprintln!("{USAGE}"); - process::exit(1); - } - }, - None => { - println!("Failed to get the target IP"); - eprintln!("{USAGE}"); - process::exit(1); - } + Some(n) => nex::net::interface::get_interfaces() + .into_iter() + .find(|i| i.name == n) + .expect("Failed to get interface information"), + None => Interface::default().expect("Failed to get default interface"), }; + let use_tun = interface.is_tun(); + + let target_ip: IpAddr = env::args() + .nth(1) + .expect("Missing target IP") + .parse() + .expect("Failed to parse target IP"); - // Create a new channel let (mut tx, mut rx) = match datalink::channel(&interface, Default::default()) { Ok(Ethernet(tx, rx)) => (tx, rx), Ok(_) => panic!("Unhandled channel type"), - Err(e) => panic!("Failed to create a channel: {}", e), + Err(e) => panic!("Failed to create channel: {}", e), }; - // Packet builder for UDP Ping - let mut packet_builder = PacketBuilder::new(); - let ethernet_packet_builder = EthernetPacketBuilder { - src_mac: if use_tun { + let src_ip: IpAddr = match target_ip { + IpAddr::V4(_) => interface + .ipv4 + .get(0) + .map(|v| IpAddr::V4(v.addr())) + .expect("No IPv4 address on interface"), + IpAddr::V6(_) => interface + .ipv6 + .iter() + .find(|v| nex::net::ip::is_global_ipv6(&v.addr())) + .map(|v| IpAddr::V6(v.addr())) + .expect("No global IPv6 address on interface"), + }; + + let udp_packet = UdpPacketBuilder::new() + .source(SRC_PORT) + .destination(DST_PORT) + .culculate_checksum(&src_ip, &target_ip) + .build(); + + let ip_packet: Bytes = match (src_ip, target_ip) { + (IpAddr::V4(src), IpAddr::V4(dst)) => Ipv4PacketBuilder::new() + .source(src) + .destination(dst) + .protocol(IpNextProtocol::Udp) + .flags(Ipv4Flags::DontFragment) + .payload(udp_packet.to_bytes()) + .build() + .to_bytes(), + (IpAddr::V6(src), IpAddr::V6(dst)) => Ipv6PacketBuilder::new() + .source(src) + .destination(dst) + .next_header(IpNextProtocol::Udp) + .payload(udp_packet.to_bytes()) + .build() + .to_bytes(), + _ => panic!("Source and destination IP version mismatch"), + }; + + let ethernet_packet = EthernetPacketBuilder::new() + .source(if use_tun { MacAddr::zero() } else { interface.mac_addr.clone().unwrap() - }, - dst_mac: if use_tun { + }) + .destination(if use_tun { MacAddr::zero() } else { interface.gateway.clone().unwrap().mac_addr - }, - ether_type: match target_ip { + }) + .ethertype(match target_ip { IpAddr::V4(_) => EtherType::Ipv4, IpAddr::V6(_) => EtherType::Ipv6, - }, - }; - packet_builder.set_ethernet(ethernet_packet_builder); + }) + .payload(ip_packet) + .build(); - match target_ip { - IpAddr::V4(dst_ipv4) => match interface.ipv4.get(0) { - Some(src_ipv4) => { - let ipv4_packet_builder = - Ipv4PacketBuilder::new(src_ipv4.addr(), dst_ipv4, IpNextLevelProtocol::Udp); - packet_builder.set_ipv4(ipv4_packet_builder); - } - None => { - println!("No IPv4 address on the interface"); - process::exit(1); - } - }, - IpAddr::V6(dst_ipv6) => { - match interface - .ipv6 - .iter() - .find(|ipv6| nex::net::ip::is_global_ipv6(&ipv6.addr())) - { - Some(src_ipv6) => { - let ipv6_packet_builder = - Ipv6PacketBuilder::new(src_ipv6.addr(), dst_ipv6, IpNextLevelProtocol::Udp); - packet_builder.set_ipv6(ipv6_packet_builder); - } - None => { - println!("No global IPv6 address on the interface"); - process::exit(1); - } - } - } - } - - match target_ip { - IpAddr::V4(_dst_ipv4) => match interface.ipv4.get(0) { - Some(src_ipv4) => { - let udp_packet_builder = UdpPacketBuilder::new( - SocketAddr::new(IpAddr::V4(src_ipv4.addr()), SRC_PORT), - SocketAddr::new(target_ip, DST_PORT), - ); - packet_builder.set_udp(udp_packet_builder); - } - None => { - println!("No IPv4 address on the interface"); - process::exit(1); - } - }, - IpAddr::V6(_dst_ipv6) => { - match interface - .ipv6 - .iter() - .find(|ipv6| nex::net::ip::is_global_ipv6(&ipv6.addr())) - { - Some(src_ipv6) => { - let udp_packet_builder = UdpPacketBuilder::new( - SocketAddr::new(IpAddr::V6(src_ipv6.addr()), SRC_PORT), - SocketAddr::new(target_ip, DST_PORT), - ); - packet_builder.set_udp(udp_packet_builder); - } - None => { - println!("No global IPv6 address on the interface"); - process::exit(1); - } - } - } - } - - // Send UDP Ping packet - let packet: Vec = if use_tun { - packet_builder.ip_packet() + let packet: Bytes = if use_tun { + ethernet_packet.ip_packet().unwrap() } else { - packet_builder.packet() + ethernet_packet.to_bytes() }; match tx.send(&packet) { Some(_) => println!("UDP Ping packet sent"), None => println!("Failed to send UDP Ping packet"), } - // Receive ICMP Port Unreachable println!("Waiting for ICMP Port Unreachable..."); loop { match rx.next() { Ok(packet) => { - let mut parse_option: ParseOption = ParseOption::default(); + let mut parse_option = ParseOption::default(); if interface.is_tun() { - let payload_offset = if interface.is_loopback() { 14 } else { 0 }; parse_option.from_ip_packet = true; - parse_option.offset = payload_offset; + parse_option.offset = if interface.is_loopback() { 14 } else { 0 }; } - let frame: Frame = Frame::from_bytes(&packet, parse_option); - // Check each layer. If the packet is a ICMP Port Unreachable, print it out + let frame = Frame::from_buf(&packet, parse_option).unwrap(); + if let Some(ip_layer) = &frame.ip { - if let Some(icmp_header) = &ip_layer.icmp { - if icmp_header.icmp_type == IcmpType::DestinationUnreachable { - if let Some(ipv4) = &ip_layer.ipv4 { - println!("Received ICMP Port Unreachable from {}", ipv4.source); - println!( - "---- Interface: {}, Total Length: {} bytes ----", - interface.name, - packet.len() - ); - println!("Packet Frame: {:?}", frame); - break; - } + if let Some(icmp) = &ip_layer.icmp { + if icmp.icmp_type == IcmpType::DestinationUnreachable { + println!("Received ICMP Port Unreachable (v4) from {}", ip_layer.ipv4.as_ref().unwrap().source); + println!( + "---- Interface: {}, Total Length: {} bytes ----", + interface.name, + packet.len() + ); + println!("Packet Frame: {:?}", frame); + break; } - } else if let Some(icmpv6_header) = &ip_layer.icmpv6 { - if icmpv6_header.icmpv6_type == Icmpv6Type::DestinationUnreachable { - if let Some(ipv6) = &ip_layer.ipv6 { - println!("Received ICMP Port Unreachable from {}", ipv6.source); - println!( - "---- Interface: {}, Total Length: {} bytes ----", - interface.name, - packet.len() - ); - println!("Packet Frame: {:?}", frame); - break; - } + } + if let Some(icmpv6) = &ip_layer.icmpv6 { + if icmpv6.icmpv6_type == Icmpv6Type::DestinationUnreachable { + println!("Received ICMP Port Unreachable (v6) from {}", ip_layer.ipv6.as_ref().unwrap().source); + println!( + "---- Interface: {}, Total Length: {} bytes ----", + interface.name, + packet.len() + ); + println!("Packet Frame: {:?}", frame); + break; } } } } - Err(e) => println!("Failed to receive packet: {}", e), + Err(e) => println!("Receive failed: {}", e), } } } diff --git a/examples/udp_socket.rs b/examples/udp_socket.rs new file mode 100644 index 0000000..b2424fb --- /dev/null +++ b/examples/udp_socket.rs @@ -0,0 +1,30 @@ +//! UDP echo using UdpSocket +//! +//! This example starts a small UDP echo server and client using nex-socket. + +use nex_socket::udp::{UdpConfig, UdpSocket}; +use std::thread; + +fn main() -> std::io::Result<()> { + let server_cfg = UdpConfig { bind_addr: Some("127.0.0.1:0".parse().unwrap()), ..Default::default() }; + let server = UdpSocket::from_config(&server_cfg)?; + let server_addr = server.local_addr()?; + + let handle = thread::spawn(move || -> std::io::Result<()> { + let mut buf = [0u8; 512]; + let (n, peer) = server.recv_from(&mut buf)?; + println!("Server received: {}", String::from_utf8_lossy(&buf[..n])); + server.send_to(&buf[..n], peer)?; + Ok(()) + }); + + let client = UdpSocket::v4_dgram()?; + let msg = b"hello via udp"; + client.send_to(msg, server_addr)?; + let mut buf = [0u8; 512]; + let (n, _) = client.recv_from(&mut buf)?; + println!("Client received: {}", String::from_utf8_lossy(&buf[..n])); + + handle.join().unwrap()?; + Ok(()) +} diff --git a/nex-core/src/bitfield.rs b/nex-core/src/bitfield.rs new file mode 100644 index 0000000..723aa5f --- /dev/null +++ b/nex-core/src/bitfield.rs @@ -0,0 +1,387 @@ +//! Provides type aliases for various primitive integer types +//! +//! These types are aliased to the next largest of \[`u8`, `u16`, `u32`, `u64`\] +//! +//! All aliases for types larger than `u8` contain a `be` or `le` suffix. These specify whether the +//! value is big or little endian, respectively. + +#![allow(non_camel_case_types)] +/// Represents an unsigned, 1-bit integer. +pub type u1 = u8; +/// Represents an unsigned, 2-bit integer. +pub type u2 = u8; +/// Represents an unsigned, 3-bit integer. +pub type u3 = u8; +/// Represents an unsigned, 4-bit integer. +pub type u4 = u8; +/// Represents an unsigned, 5-bit integer. +pub type u5 = u8; +/// Represents an unsigned, 6-bit integer. +pub type u6 = u8; +/// Represents an unsigned, 7-bit integer. +pub type u7 = u8; +/// Represents an unsigned 9-bit integer. +pub type u9be = u16; +/// Represents an unsigned 10-bit integer. +pub type u10be = u16; +/// Represents an unsigned 11-bit integer. +pub type u11be = u16; +/// Represents an unsigned 12-bit integer. +pub type u12be = u16; +/// Represents an unsigned 13-bit integer. +pub type u13be = u16; +/// Represents an unsigned 14-bit integer. +pub type u14be = u16; +/// Represents an unsigned 15-bit integer. +pub type u15be = u16; +/// Represents an unsigned 16-bit integer. +pub type u16be = u16; +/// Represents an unsigned 17-bit integer. +pub type u17be = u32; +/// Represents an unsigned 18-bit integer. +pub type u18be = u32; +/// Represents an unsigned 19-bit integer. +pub type u19be = u32; +/// Represents an unsigned 20-bit integer. +pub type u20be = u32; +/// Represents an unsigned 21-bit integer. +pub type u21be = u32; +/// Represents an unsigned 22-bit integer. +pub type u22be = u32; +/// Represents an unsigned 23-bit integer. +pub type u23be = u32; +/// Represents an unsigned 24-bit integer. +pub type u24be = u32; +/// Represents an unsigned 25-bit integer. +pub type u25be = u32; +/// Represents an unsigned 26-bit integer. +pub type u26be = u32; +/// Represents an unsigned 27-bit integer. +pub type u27be = u32; +/// Represents an unsigned 28-bit integer. +pub type u28be = u32; +/// Represents an unsigned 29-bit integer. +pub type u29be = u32; +/// Represents an unsigned 30-bit integer. +pub type u30be = u32; +/// Represents an unsigned 31-bit integer. +pub type u31be = u32; +/// Represents an unsigned 32-bit integer. +pub type u32be = u32; +/// Represents an unsigned 33-bit integer. +pub type u33be = u64; +/// Represents an unsigned 34-bit integer. +pub type u34be = u64; +/// Represents an unsigned 35-bit integer. +pub type u35be = u64; +/// Represents an unsigned 36-bit integer. +pub type u36be = u64; +/// Represents an unsigned 37-bit integer. +pub type u37be = u64; +/// Represents an unsigned 38-bit integer. +pub type u38be = u64; +/// Represents an unsigned 39-bit integer. +pub type u39be = u64; +/// Represents an unsigned 40-bit integer. +pub type u40be = u64; +/// Represents an unsigned 41-bit integer. +pub type u41be = u64; +/// Represents an unsigned 42-bit integer. +pub type u42be = u64; +/// Represents an unsigned 43-bit integer. +pub type u43be = u64; +/// Represents an unsigned 44-bit integer. +pub type u44be = u64; +/// Represents an unsigned 45-bit integer. +pub type u45be = u64; +/// Represents an unsigned 46-bit integer. +pub type u46be = u64; +/// Represents an unsigned 47-bit integer. +pub type u47be = u64; +/// Represents an unsigned 48-bit integer. +pub type u48be = u64; +/// Represents an unsigned 49-bit integer. +pub type u49be = u64; +/// Represents an unsigned 50-bit integer. +pub type u50be = u64; +/// Represents an unsigned 51-bit integer. +pub type u51be = u64; +/// Represents an unsigned 52-bit integer. +pub type u52be = u64; +/// Represents an unsigned 53-bit integer. +pub type u53be = u64; +/// Represents an unsigned 54-bit integer. +pub type u54be = u64; +/// Represents an unsigned 55-bit integer. +pub type u55be = u64; +/// Represents an unsigned 56-bit integer. +pub type u56be = u64; +/// Represents an unsigned 57-bit integer. +pub type u57be = u64; +/// Represents an unsigned 58-bit integer. +pub type u58be = u64; +/// Represents an unsigned 59-bit integer. +pub type u59be = u64; +/// Represents an unsigned 60-bit integer. +pub type u60be = u64; +/// Represents an unsigned 61-bit integer. +pub type u61be = u64; +/// Represents an unsigned 62-bit integer. +pub type u62be = u64; +/// Represents an unsigned 63-bit integer. +pub type u63be = u64; +/// Represents an unsigned 64-bit integer. +pub type u64be = u64; +/// Represents an unsigned 9-bit integer. +pub type u9le = u16; +/// Represents an unsigned 10-bit integer. +pub type u10le = u16; +/// Represents an unsigned 11-bit integer. +pub type u11le = u16; +/// Represents an unsigned 12-bit integer. +pub type u12le = u16; +/// Represents an unsigned 13-bit integer. +pub type u13le = u16; +/// Represents an unsigned 14-bit integer. +pub type u14le = u16; +/// Represents an unsigned 15-bit integer. +pub type u15le = u16; +/// Represents an unsigned 16-bit integer. +pub type u16le = u16; +/// Represents an unsigned 17-bit integer. +pub type u17le = u32; +/// Represents an unsigned 18-bit integer. +pub type u18le = u32; +/// Represents an unsigned 19-bit integer. +pub type u19le = u32; +/// Represents an unsigned 20-bit integer. +pub type u20le = u32; +/// Represents an unsigned 21-bit integer. +pub type u21le = u32; +/// Represents an unsigned 22-bit integer. +pub type u22le = u32; +/// Represents an unsigned 23-bit integer. +pub type u23le = u32; +/// Represents an unsigned 24-bit integer. +pub type u24le = u32; +/// Represents an unsigned 25-bit integer. +pub type u25le = u32; +/// Represents an unsigned 26-bit integer. +pub type u26le = u32; +/// Represents an unsigned 27-bit integer. +pub type u27le = u32; +/// Represents an unsigned 28-bit integer. +pub type u28le = u32; +/// Represents an unsigned 29-bit integer. +pub type u29le = u32; +/// Represents an unsigned 30-bit integer. +pub type u30le = u32; +/// Represents an unsigned 31-bit integer. +pub type u31le = u32; +/// Represents an unsigned 32-bit integer. +pub type u32le = u32; +/// Represents an unsigned 33-bit integer. +pub type u33le = u64; +/// Represents an unsigned 34-bit integer. +pub type u34le = u64; + +/// Represents an unsigned 35-bit integer. +pub type u35le = u64; + +/// Represents an unsigned 36-bit integer. +pub type u36le = u64; +/// Represents an unsigned 37-bit integer. +pub type u37le = u64; +/// Represents an unsigned 38-bit integer. +pub type u38le = u64; +/// Represents an unsigned 39-bit integer. +pub type u39le = u64; +/// Represents an unsigned 40-bit integer. +pub type u40le = u64; +/// Represents an unsigned 41-bit integer. +pub type u41le = u64; +/// Represents an unsigned 42-bit integer. +pub type u42le = u64; +/// Represents an unsigned 43-bit integer. +pub type u43le = u64; +/// Represents an unsigned 44-bit integer. +pub type u44le = u64; +/// Represents an unsigned 45-bit integer. +pub type u45le = u64; +/// Represents an unsigned 46-bit integer. +pub type u46le = u64; +/// Represents an unsigned 47-bit integer. +pub type u47le = u64; +/// Represents an unsigned 48-bit integer. +pub type u48le = u64; +/// Represents an unsigned 49-bit integer. +pub type u49le = u64; +/// Represents an unsigned 50-bit integer. +pub type u50le = u64; +/// Represents an unsigned 51-bit integer. +pub type u51le = u64; +/// Represents an unsigned 52-bit integer. +pub type u52le = u64; +/// Represents an unsigned 53-bit integer. +pub type u53le = u64; +/// Represents an unsigned 54-bit integer. +pub type u54le = u64; +/// Represents an unsigned 55-bit integer. +pub type u55le = u64; +/// Represents an unsigned 56-bit integer. +pub type u56le = u64; +/// Represents an unsigned 57-bit integer. +pub type u57le = u64; +/// Represents an unsigned 58-bit integer. +pub type u58le = u64; +/// Represents an unsigned 59-bit integer. +pub type u59le = u64; +/// Represents an unsigned 60-bit integer. +pub type u60le = u64; +/// Represents an unsigned 61-bit integer. +pub type u61le = u64; +/// Represents an unsigned 62-bit integer. +pub type u62le = u64; +/// Represents an unsigned 63-bit integer. +pub type u63le = u64; +/// Represents an unsigned 64-bit integer. +pub type u64le = u64; +/// Represents an unsigned 9-bit integer in host endianness. +pub type u9he = u16; +/// Represents an unsigned 10-bit integer in host endianness. +pub type u10he = u16; +/// Represents an unsigned 11-bit integer in host endianness. +pub type u11he = u16; +/// Represents an unsigned 12-bit integer in host endianness. +pub type u12he = u16; +/// Represents an unsigned 13-bit integer in host endianness. +pub type u13he = u16; +/// Represents an unsigned 14-bit integer in host endianness. +pub type u14he = u16; +/// Represents an unsigned 15-bit integer in host endianness. +pub type u15he = u16; +/// Represents an unsigned 16-bit integer in host endianness. +pub type u16he = u16; +/// Represents an unsigned 17-bit integer in host endianness. +pub type u17he = u32; +/// Represents an unsigned 18-bit integer in host endianness. +pub type u18he = u32; +/// Represents an unsigned 19-bit integer in host endianness. +pub type u19he = u32; +/// Represents an unsigned 20-bit integer in host endianness. +pub type u20he = u32; +/// Represents an unsigned 21-bit integer in host endianness. +pub type u21he = u32; +/// Represents an unsigned 22-bit integer in host endianness. +pub type u22he = u32; +/// Represents an unsigned 23-bit integer in host endianness. +pub type u23he = u32; +/// Represents an unsigned 24-bit integer in host endianness. +pub type u24he = u32; +/// Represents an unsigned 25-bit integer in host endianness. +pub type u25he = u32; +/// Represents an unsigned 26-bit integer in host endianness. +pub type u26he = u32; +/// Represents an unsigned 27-bit integer in host endianness. +pub type u27he = u32; +/// Represents an unsigned 28-bit integer in host endianness. +pub type u28he = u32; +/// Represents an unsigned 29-bit integer in host endianness. +pub type u29he = u32; +/// Represents an unsigned 30-bit integer in host endianness. +pub type u30he = u32; +/// Represents an unsigned 31-bit integer in host endianness. +pub type u31he = u32; +/// Represents an unsigned 32-bit integer in host endianness. +pub type u32he = u32; +/// Represents an unsigned 33-bit integer in host endianness. +pub type u33he = u64; +/// Represents an unsigned 34-bit integer in host endianness. +pub type u34he = u64; +/// Represents an unsigned 35-bit integer in host endianness. +pub type u35he = u64; +/// Represents an unsigned 36-bit integer in host endianness. +pub type u36he = u64; +/// Represents an unsigned 37-bit integer in host endianness. +pub type u37he = u64; +/// Represents an unsigned 38-bit integer in host endianness. +pub type u38he = u64; +/// Represents an unsigned 39-bit integer in host endianness. +pub type u39he = u64; +/// Represents an unsigned 40-bit integer in host endianness. +pub type u40he = u64; +/// Represents an unsigned 41-bit integer in host endianness. +pub type u41he = u64; +/// Represents an unsigned 42-bit integer in host endianness. +pub type u42he = u64; +/// Represents an unsigned 43-bit integer in host endianness. +pub type u43he = u64; +/// Represents an unsigned 44-bit integer in host endianness. +pub type u44he = u64; +/// Represents an unsigned 45-bit integer in host endianness. +pub type u45he = u64; +/// Represents an unsigned 46-bit integer in host endianness. +pub type u46he = u64; +/// Represents an unsigned 47-bit integer in host endianness. +pub type u47he = u64; +/// Represents an unsigned 48-bit integer in host endianness. +pub type u48he = u64; +/// Represents an unsigned 49-bit integer in host endianness. +pub type u49he = u64; +/// Represents an unsigned 50-bit integer in host endianness. +pub type u50he = u64; +/// Represents an unsigned 51-bit integer in host endianness. +pub type u51he = u64; +/// Represents an unsigned 52-bit integer in host endianness. +pub type u52he = u64; +/// Represents an unsigned 53-bit integer in host endianness. +pub type u53he = u64; +/// Represents an unsigned 54-bit integer in host endianness. +pub type u54he = u64; +/// Represents an unsigned 55-bit integer in host endianness. +pub type u55he = u64; +/// Represents an unsigned 56-bit integer in host endianness. +pub type u56he = u64; +/// Represents an unsigned 57-bit integer in host endianness. +pub type u57he = u64; +/// Represents an unsigned 58-bit integer in host endianness. +pub type u58he = u64; +/// Represents an unsigned 59-bit integer in host endianness. +pub type u59he = u64; +/// Represents an unsigned 60-bit integer in host endianness. +pub type u60he = u64; +/// Represents an unsigned 61-bit integer in host endianness. +pub type u61he = u64; +/// Represents an unsigned 62-bit integer in host endianness. +pub type u62he = u64; +/// Represents an unsigned 63-bit integer in host endianness. +pub type u63he = u64; +/// Represents an unsigned 64-bit integer in host endianness. +pub type u64he = u64; + +pub mod utils { + + pub fn u24be_from_bytes(bytes: [u8; 3]) -> super::u24be { + (u32::from(bytes[0]) << 16) | (u32::from(bytes[1]) << 8) | u32::from(bytes[2]) + } + + pub fn u24be_to_bytes(value: u32) -> [u8; 3] { + [ + ((value >> 16) & 0xFF) as u8, + ((value >> 8) & 0xFF) as u8, + (value & 0xFF) as u8, + ] + } +} +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn alias_sizes() { + assert_eq!(std::mem::size_of::(), 1); + assert_eq!(std::mem::size_of::(), 2); + assert_eq!(std::mem::size_of::(), 4); + assert_eq!(std::mem::size_of::(), 8); + } +} + diff --git a/nex-core/src/ip.rs b/nex-core/src/ip.rs index c6decdb..575ceac 100644 --- a/nex-core/src/ip.rs +++ b/nex-core/src/ip.rs @@ -1,21 +1,35 @@ -use std::net::{Ipv4Addr, Ipv6Addr}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; -pub use netdev::ipnet::*; +/// Returns [`true`] if the address appears to be globally routable. +pub fn is_global_ip(ip_addr: &IpAddr) -> bool { + match ip_addr { + IpAddr::V4(ip) => is_global_ipv4(ip), + IpAddr::V6(ip) => is_global_ipv6(ip), + } +} +/// Returns [`true`] if the address appears to be globally reachable +/// as specified by the [IANA IPv4 Special-Purpose Address Registry]. pub fn is_global_ipv4(ipv4_addr: &Ipv4Addr) -> bool { !(ipv4_addr.octets()[0] == 0 // "This network" || ipv4_addr.is_private() - || matches!(ipv4_addr.octets(), [169, 254, ..]) + || is_shared_ipv4(ipv4_addr) || ipv4_addr.is_loopback() || ipv4_addr.is_link_local() // addresses reserved for future protocols (`192.0.0.0/24`) - ||(ipv4_addr.octets()[0] == 192 && ipv4_addr.octets()[1] == 0 && ipv4_addr.octets()[2] == 0) + // .9 and .10 are documented as globally reachable so they're excluded + || ( + ipv4_addr.octets()[0] == 192 && ipv4_addr.octets()[1] == 0 && ipv4_addr.octets()[2] == 0 + && ipv4_addr.octets()[3] != 9 && ipv4_addr.octets()[3] != 10 + ) || ipv4_addr.is_documentation() - || ipv4_addr.octets()[0] == 198 && (ipv4_addr.octets()[1] & 0xfe) == 18 - || ipv4_addr.octets()[0] & 240 == 240 && !ipv4_addr.is_broadcast() + || is_benchmarking_ipv4(ipv4_addr) + || is_reserved_ipv4(ipv4_addr) || ipv4_addr.is_broadcast()) } +/// Returns [`true`] if the address appears to be globally reachable +/// as specified by the [IANA IPv6 Special-Purpose Address Registry]. pub fn is_global_ipv6(ipv6_addr: &Ipv6Addr) -> bool { !(ipv6_addr.is_unspecified() || ipv6_addr.is_loopback() @@ -37,12 +51,90 @@ pub fn is_global_ipv6(ipv6_addr: &Ipv6Addr) -> bool { // AS112-v6 (`2001:4:112::/48`) || matches!(ipv6_addr.segments(), [0x2001, 4, 0x112, _, _, _, _, _]) // ORCHIDv2 (`2001:20::/28`) - || matches!(ipv6_addr.segments(), [0x2001, b, _, _, _, _, _, _] if b >= 0x20 && b <= 0x2F) + // Drone Remote ID Protocol Entity Tags (DETs) Prefix (`2001:30::/28`)` + || matches!(ipv6_addr.segments(), [0x2001, b, _, _, _, _, _, _] if b >= 0x20 && b <= 0x3F) )) - // Reserved for documentation - || ((ipv6_addr.segments()[0] == 0x2001) && (ipv6_addr.segments()[1] == 0x2) && (ipv6_addr.segments()[2] == 0)) - // Unique Local Address - || ((ipv6_addr.segments()[0] & 0xfe00) == 0xfc00) - // unicast address with link-local scope (`fc00::/7`) - || ((ipv6_addr.segments()[0] & 0xffc0) == 0xfe80)) + // 6to4 (`2002::/16`) – it's not explicitly documented as globally reachable, + // IANA says N/A. + || matches!(ipv6_addr.segments(), [0x2002, _, _, _, _, _, _, _]) + || is_documentation_ipv6(ipv6_addr) + || ipv6_addr.is_unique_local() + || ipv6_addr.is_unicast_link_local()) +} + +/// Returns [`true`] if this address is part of the Shared Address Space defined in +/// [IETF RFC 6598] (`100.64.0.0/10`). +/// +/// [IETF RFC 6598]: https://tools.ietf.org/html/rfc6598 +fn is_shared_ipv4(ipv4_addr: &Ipv4Addr) -> bool { + ipv4_addr.octets()[0] == 100 && (ipv4_addr.octets()[1] & 0b1100_0000 == 0b0100_0000) +} + +/// Returns [`true`] if this address part of the `198.18.0.0/15` range, which is reserved for +/// network devices benchmarking. +fn is_benchmarking_ipv4(ipv4_addr: &Ipv4Addr) -> bool { + ipv4_addr.octets()[0] == 198 && (ipv4_addr.octets()[1] & 0xfe) == 18 +} + +/// Returns [`true`] if this address is reserved by IANA for future use. +fn is_reserved_ipv4(ipv4_addr: &Ipv4Addr) -> bool { + ipv4_addr.octets()[0] & 240 == 240 && !ipv4_addr.is_broadcast() +} + +/// Returns [`true`] if this is an address reserved for documentation +/// (`2001:db8::/32` and `3fff::/20`). +fn is_documentation_ipv6(ipv6_addr: &Ipv6Addr) -> bool { + matches!( + ipv6_addr.segments(), + [0x2001, 0xdb8, ..] | [0x3fff, 0..=0x0fff, ..] + ) } + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + + #[test] + fn test_is_global_ipv4() { + let global = Ipv4Addr::new(1, 1, 1, 1); // Cloudflare + let private = Ipv4Addr::new(192, 168, 1, 1); + let loopback = Ipv4Addr::new(127, 0, 0, 1); + let shared = Ipv4Addr::new(100, 64, 0, 1); // RFC6598 + let doc = Ipv4Addr::new(192, 0, 2, 1); // Documentation + + assert!(is_global_ipv4(&global)); + assert!(!is_global_ipv4(&private)); + assert!(!is_global_ipv4(&loopback)); + assert!(!is_global_ipv4(&shared)); + assert!(!is_global_ipv4(&doc)); + } + + #[test] + fn test_is_global_ipv6() { + let global = Ipv6Addr::new(0x2606, 0x4700, 0, 0, 0, 0, 0, 0x1111); // Cloudflare + let loopback = Ipv6Addr::LOCALHOST; + let unspecified = Ipv6Addr::UNSPECIFIED; + let unique_local = Ipv6Addr::new(0xfd00, 0, 0, 0, 0, 0, 0, 1); + let doc = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1); // Documentation + + assert!(is_global_ipv6(&global)); + assert!(!is_global_ipv6(&loopback)); + assert!(!is_global_ipv6(&unspecified)); + assert!(!is_global_ipv6(&unique_local)); + assert!(!is_global_ipv6(&doc)); + } + + #[test] + fn test_is_global_ip() { + let ip_v4 = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)); + let ip_v6 = IpAddr::V6(Ipv6Addr::new(0x2606, 0x4700, 0, 0, 0, 0, 0, 0x1111)); // Cloudflare + let ip_private = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)); + let ip_ula = IpAddr::V6(Ipv6Addr::new(0xfd00, 0, 0, 0, 0, 0, 0, 1)); + + assert!(is_global_ip(&ip_v4)); + assert!(is_global_ip(&ip_v6)); + assert!(!is_global_ip(&ip_private)); + assert!(!is_global_ip(&ip_ula)); + } +} \ No newline at end of file diff --git a/nex-core/src/lib.rs b/nex-core/src/lib.rs index 171bd9e..c107dfb 100644 --- a/nex-core/src/lib.rs +++ b/nex-core/src/lib.rs @@ -8,3 +8,4 @@ pub mod gateway; pub mod interface; pub mod ip; pub mod mac; +pub mod bitfield; diff --git a/nex-datalink/Cargo.toml b/nex-datalink/Cargo.toml index d1a33a8..34d95e2 100644 --- a/nex-datalink/Cargo.toml +++ b/nex-datalink/Cargo.toml @@ -12,6 +12,7 @@ license = "MIT" [dependencies] libc = { workspace = true } +bytes = { workspace = true } netdev = { workspace = true } serde = { workspace = true, features = ["derive"], optional = true } pcap = { version = "2.0", optional = true } diff --git a/nex-datalink/src/bindings/bpf.rs b/nex-datalink/src/bindings/bpf.rs index 6f1fd73..ee594e2 100644 --- a/nex-datalink/src/bindings/bpf.rs +++ b/nex-datalink/src/bindings/bpf.rs @@ -144,4 +144,4 @@ pub struct bpf_hdr { #[cfg(not(windows))] extern "C" { pub fn ioctl(d: libc::c_int, request: libc::c_ulong, ...) -> libc::c_int; -} +} \ No newline at end of file diff --git a/nex-datalink/src/lib.rs b/nex-datalink/src/lib.rs index fb3a038..90f714a 100644 --- a/nex-datalink/src/lib.rs +++ b/nex-datalink/src/lib.rs @@ -179,3 +179,21 @@ pub trait RawReceiver: Send { /// Get the next ethernet frame in the channel. fn next(&mut self) -> io::Result<&[u8]>; } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn config_default_values() { + let cfg = Config::default(); + assert_eq!(cfg.write_buffer_size, 4096); + assert_eq!(cfg.read_buffer_size, 4096); + assert_eq!(cfg.read_timeout, None); + assert_eq!(cfg.write_timeout, None); + assert_eq!(cfg.channel_type, ChannelType::Layer2); + assert_eq!(cfg.bpf_fd_attempts, 1000); + assert!(cfg.linux_fanout.is_none()); + assert!(cfg.promiscuous); + } +} diff --git a/nex-datalink/src/linux.rs b/nex-datalink/src/linux.rs index 719cdd1..7fd2e3c 100644 --- a/nex-datalink/src/linux.rs +++ b/nex-datalink/src/linux.rs @@ -45,8 +45,8 @@ pub struct Config { /// The write timeout. Defaults to None. pub write_timeout: Option, - /// Specifies whether to read packets at the datalink layer or network layer. - /// NOTE FIXME Currently ignored. + /// Selects the socket mode: datalink (Layer2) or network (Layer3). + /// This setting is only consulted when the socket is created. /// Defaults to Layer2. pub channel_type: super::ChannelType, @@ -57,6 +57,17 @@ pub struct Config { pub promiscuous: bool, } +#[inline] +fn poll_timeout_ms(timeout: Option) -> libc::c_int { + timeout + .map(|to| { + let ms = (to.tv_sec as i64 * 1000) + (to.tv_nsec as i64 / 1_000_000); + ms.clamp(i64::from(libc::c_int::MIN), i64::from(libc::c_int::MAX)) + as libc::c_int + }) + .unwrap_or(-1) +} + impl<'a> From<&'a super::Config> for Config { fn from(config: &super::Config) -> Config { Config { @@ -192,7 +203,6 @@ pub fn channel(network_interface: &Interface, config: Config) -> io::Result io::Result io::Result, write_buffer: Vec, - _channel_type: super::ChannelType, send_addr: libc::sockaddr_ll, send_addr_len: usize, timeout: Option, @@ -241,11 +249,7 @@ impl RawSender for RawSenderImpl { // poll timeout in milliseconds // -1: wait indefinitely - let timeout_ms = self - .timeout - .as_ref() - .map(|to| (to.tv_sec as i64 * 1000) + (to.tv_nsec as i64 / 1_000_000)) - .unwrap_or(-1); + let timeout_ms = poll_timeout_ms(self.timeout); for chunk in mut_slice[..min].chunks_mut(packet_size) { func(chunk); @@ -297,11 +301,7 @@ impl RawSender for RawSenderImpl { // poll timeout in milliseconds // -1: wait indefinitely - let timeout_ms = self - .timeout - .as_ref() - .map(|to| (to.tv_sec as i64 * 1000) + (to.tv_nsec as i64 / 1_000_000)) - .unwrap_or(-1); + let timeout_ms = poll_timeout_ms(self.timeout); let ret = unsafe { libc::poll( @@ -338,7 +338,6 @@ impl RawSender for RawSenderImpl { struct RawReceiverImpl { socket: Arc, read_buffer: Vec, - _channel_type: super::ChannelType, timeout: Option, } @@ -353,11 +352,7 @@ impl RawReceiver for RawReceiverImpl { // poll timeout in milliseconds // -1: wait indefinitely - let timeout_ms = self - .timeout - .as_ref() - .map(|to| (to.tv_sec as i64 * 1000) + (to.tv_nsec as i64 / 1_000_000)) - .unwrap_or(-1); + let timeout_ms = poll_timeout_ms(self.timeout); let ret = unsafe { libc::poll( diff --git a/nex-macro-helper/Cargo.toml b/nex-macro-helper/Cargo.toml deleted file mode 100644 index 92a937f..0000000 --- a/nex-macro-helper/Cargo.toml +++ /dev/null @@ -1,14 +0,0 @@ -[package] -name = "nex-macro-helper" -version.workspace = true -edition.workspace = true -authors.workspace = true -description = "A helper crate for nex-macro. Not intended for direct use." -repository = "https://github.com/shellrow/nex" -readme = "../README.md" -keywords = ["network", "packet"] -categories = ["network-programming"] -license = "MIT" - -[dependencies] -nex-core = { workspace = true } diff --git a/nex-macro-helper/src/lib.rs b/nex-macro-helper/src/lib.rs deleted file mode 100644 index c07a706..0000000 --- a/nex-macro-helper/src/lib.rs +++ /dev/null @@ -1,9 +0,0 @@ -//! Helper crate for `nex-macro`. - -#![deny(missing_docs)] -#![deny(warnings)] - -extern crate nex_core; - -pub mod packet; -pub mod types; diff --git a/nex-macro-helper/src/packet.rs b/nex-macro-helper/src/packet.rs deleted file mode 100644 index 2041e69..0000000 --- a/nex-macro-helper/src/packet.rs +++ /dev/null @@ -1,231 +0,0 @@ -//! Packet helpers for `nex-macro`. - -extern crate alloc; -use alloc::vec; - -use core::ops::{Deref, DerefMut, Index, IndexMut, Range, RangeFrom, RangeFull, RangeTo}; -use nex_core; - -/// Represents a generic network packet. -pub trait Packet { - /// Retrieve the underlying buffer for the packet. - fn packet(&self) -> &[u8]; - /// Retrieve the payload for the packet. - fn payload(&self) -> &[u8]; -} - -/// Blanket impl for Boxed objects -impl Packet for alloc::boxed::Box { - /// Retrieve the underlying buffer for the packet. - fn packet(&self) -> &[u8] { - self.deref().packet() - } - /// Retrieve the payload for the packet. - fn payload(&self) -> &[u8] { - self.deref().payload() - } -} - -impl Packet for &T { - /// Retrieve the underlying buffer for the packet. - fn packet(&self) -> &[u8] { - (*self).packet() - } - /// Retrieve the payload for the packet. - fn payload(&self) -> &[u8] { - (*self).payload() - } -} - -/// Represents a generic, mutable, network packet. -pub trait MutablePacket: Packet { - /// Retrieve the underlying, mutable, buffer for the packet. - fn packet_mut(&mut self) -> &mut [u8]; - /// Retrieve the mutable payload for the packet. - fn payload_mut(&mut self) -> &mut [u8]; - /// Initialize this packet by cloning another. - fn clone_from(&mut self, other: &T) { - use core::ptr; - - assert!(self.packet().len() >= other.packet().len()); - unsafe { - ptr::copy_nonoverlapping( - other.packet().as_ptr(), - self.packet_mut().as_mut_ptr(), - other.packet().len(), - ); - } - } -} - -/// Used to convert on-the-wire packets to their #\[packet\] equivalent. -pub trait FromPacket: Packet { - /// The type of the packet to convert from. - type T; - /// Converts a wire-format packet to #\[packet\] struct format. - fn from_packet(&self) -> Self::T; -} - -/// Used to find the calculated size of the packet. This is used for occasions where the underlying -/// buffer is not the same length as the packet itself. -pub trait PacketSize: Packet { - /// Get the calculated size of the packet. - fn packet_size(&self) -> usize; -} - -macro_rules! impl_index { - ($t:ident, $index_t:ty, $output_t:ty) => { - impl<'p> Index<$index_t> for $t<'p> { - type Output = $output_t; - - #[inline] - fn index(&self, index: $index_t) -> &$output_t { - &self.as_slice().index(index) - } - } - }; -} - -macro_rules! impl_index_mut { - ($t:ident, $index_t:ty, $output_t:ty) => { - impl<'p> IndexMut<$index_t> for $t<'p> { - #[inline] - fn index_mut(&mut self, index: $index_t) -> &mut $output_t { - self.as_mut_slice().index_mut(index) - } - } - }; -} - -/// Packet data. -#[derive(PartialEq)] -pub enum PacketData<'p> { - /// A packet owns its contents. - Owned(vec::Vec), - /// A packet borrows its contents. - Borrowed(&'p [u8]), -} - -impl<'p> PacketData<'p> { - /// Get a slice of the packet data. - #[inline] - pub fn as_slice(&self) -> &[u8] { - match self { - &PacketData::Owned(ref data) => data.deref(), - &PacketData::Borrowed(ref data) => data, - } - } - /// No-op - returns `self`. - #[inline] - pub fn to_immutable(self) -> PacketData<'p> { - self - } - /// A length of the packet data. - #[inline] - pub fn len(&self) -> usize { - self.as_slice().len() - } -} - -impl_index!(PacketData, usize, u8); -impl_index!(PacketData, Range, [u8]); -impl_index!(PacketData, RangeTo, [u8]); -impl_index!(PacketData, RangeFrom, [u8]); -impl_index!(PacketData, RangeFull, [u8]); - -/// Mutable packet data. -#[derive(PartialEq)] -pub enum MutPacketData<'p> { - /// Owned mutable packet data. - Owned(vec::Vec), - /// Borrowed mutable packet data. - Borrowed(&'p mut [u8]), -} - -impl<'p> MutPacketData<'p> { - /// Get packet data as a slice. - #[inline] - pub fn as_slice(&self) -> &[u8] { - match self { - &MutPacketData::Owned(ref data) => data.deref(), - &MutPacketData::Borrowed(ref data) => data, - } - } - /// Get packet data as a mutable slice. - #[inline] - pub fn as_mut_slice(&mut self) -> &mut [u8] { - match self { - &mut MutPacketData::Owned(ref mut data) => data.deref_mut(), - &mut MutPacketData::Borrowed(ref mut data) => data, - } - } - /// Get an immutable version of packet data. - #[inline] - pub fn to_immutable(self) -> PacketData<'p> { - match self { - MutPacketData::Owned(data) => PacketData::Owned(data), - MutPacketData::Borrowed(data) => PacketData::Borrowed(data), - } - } - /// Get a length of data in the packet. - #[inline] - pub fn len(&self) -> usize { - self.as_slice().len() - } -} - -impl_index!(MutPacketData, usize, u8); -impl_index!(MutPacketData, Range, [u8]); -impl_index!(MutPacketData, RangeTo, [u8]); -impl_index!(MutPacketData, RangeFrom, [u8]); -impl_index!(MutPacketData, RangeFull, [u8]); - -impl_index_mut!(MutPacketData, usize, u8); -impl_index_mut!(MutPacketData, Range, [u8]); -impl_index_mut!(MutPacketData, RangeTo, [u8]); -impl_index_mut!(MutPacketData, RangeFrom, [u8]); -impl_index_mut!(MutPacketData, RangeFull, [u8]); - -/// Used to convert a type to primitive values representing it. -pub trait PrimitiveValues { - /// A tuple of types, to represent the current value. - type T; - /// Convert a value to primitive types representing it. - fn to_primitive_values(&self) -> Self::T; -} - -impl PrimitiveValues for nex_core::mac::MacAddr { - type T = (u8, u8, u8, u8, u8, u8); - #[inline] - fn to_primitive_values(&self) -> (u8, u8, u8, u8, u8, u8) { - (self.0, self.1, self.2, self.3, self.4, self.5) - } -} - -impl PrimitiveValues for std::net::Ipv4Addr { - type T = (u8, u8, u8, u8); - #[inline] - fn to_primitive_values(&self) -> (u8, u8, u8, u8) { - let octets = self.octets(); - - (octets[0], octets[1], octets[2], octets[3]) - } -} - -impl PrimitiveValues for std::net::Ipv6Addr { - type T = (u16, u16, u16, u16, u16, u16, u16, u16); - #[inline] - fn to_primitive_values(&self) -> (u16, u16, u16, u16, u16, u16, u16, u16) { - let segments = self.segments(); - ( - segments[0], - segments[1], - segments[2], - segments[3], - segments[4], - segments[5], - segments[6], - segments[7], - ) - } -} diff --git a/nex-macro-helper/src/types.rs b/nex-macro-helper/src/types.rs deleted file mode 100644 index f4d1e6b..0000000 --- a/nex-macro-helper/src/types.rs +++ /dev/null @@ -1,477 +0,0 @@ -//! Provides type aliases for various primitive integer types -//! -//! These types are aliased to the next largest of \[`u8`, `u16`, `u32`, `u64`\], and purely serve as -//! hints for the `#[packet]` macro to enable the generation of the correct bit manipulations to -//! get the value out of a packet. -//! -//! They should NOT be used outside of data types marked as `#[packet]`. -//! -//! All aliases for types larger than `u8` contain a `be` or `le` suffix. These specify whether the -//! value is big or little endian, respectively. When using `set_*()` and `get_*()` methods, host -//! endianness should be used - the methods will convert as appropriate. - -#![allow(non_camel_case_types)] -/// Represents an unsigned, 1-bit integer. -pub type u1 = u8; -/// Represents an unsigned, 2-bit integer. -pub type u2 = u8; -/// Represents an unsigned, 3-bit integer. -pub type u3 = u8; -/// Represents an unsigned, 4-bit integer. -pub type u4 = u8; -/// Represents an unsigned, 5-bit integer. -pub type u5 = u8; -/// Represents an unsigned, 6-bit integer. -pub type u6 = u8; -/// Represents an unsigned, 7-bit integer. -pub type u7 = u8; -/// Represents an unsigned 9-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u9be = u16; -/// Represents an unsigned 10-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u10be = u16; -/// Represents an unsigned 11-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u11be = u16; -/// Represents an unsigned 12-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u12be = u16; -/// Represents an unsigned 13-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u13be = u16; -/// Represents an unsigned 14-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u14be = u16; -/// Represents an unsigned 15-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u15be = u16; -/// Represents an unsigned 16-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u16be = u16; -/// Represents an unsigned 17-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u17be = u32; -/// Represents an unsigned 18-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u18be = u32; -/// Represents an unsigned 19-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u19be = u32; -/// Represents an unsigned 20-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u20be = u32; -/// Represents an unsigned 21-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u21be = u32; -/// Represents an unsigned 22-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u22be = u32; -/// Represents an unsigned 23-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u23be = u32; -/// Represents an unsigned 24-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u24be = u32; -/// Represents an unsigned 25-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u25be = u32; -/// Represents an unsigned 26-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u26be = u32; -/// Represents an unsigned 27-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u27be = u32; -/// Represents an unsigned 28-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u28be = u32; -/// Represents an unsigned 29-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u29be = u32; -/// Represents an unsigned 30-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u30be = u32; -/// Represents an unsigned 31-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u31be = u32; -/// Represents an unsigned 32-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u32be = u32; -/// Represents an unsigned 33-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u33be = u64; -/// Represents an unsigned 34-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u34be = u64; -/// Represents an unsigned 35-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u35be = u64; -/// Represents an unsigned 36-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u36be = u64; -/// Represents an unsigned 37-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u37be = u64; -/// Represents an unsigned 38-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u38be = u64; -/// Represents an unsigned 39-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u39be = u64; -/// Represents an unsigned 40-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u40be = u64; -/// Represents an unsigned 41-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u41be = u64; -/// Represents an unsigned 42-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u42be = u64; -/// Represents an unsigned 43-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u43be = u64; -/// Represents an unsigned 44-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u44be = u64; -/// Represents an unsigned 45-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u45be = u64; -/// Represents an unsigned 46-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u46be = u64; -/// Represents an unsigned 47-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u47be = u64; -/// Represents an unsigned 48-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u48be = u64; -/// Represents an unsigned 49-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u49be = u64; -/// Represents an unsigned 50-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u50be = u64; -/// Represents an unsigned 51-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u51be = u64; -/// Represents an unsigned 52-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u52be = u64; -/// Represents an unsigned 53-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u53be = u64; -/// Represents an unsigned 54-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u54be = u64; -/// Represents an unsigned 55-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u55be = u64; -/// Represents an unsigned 56-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u56be = u64; -/// Represents an unsigned 57-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u57be = u64; -/// Represents an unsigned 58-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u58be = u64; -/// Represents an unsigned 59-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u59be = u64; -/// Represents an unsigned 60-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u60be = u64; -/// Represents an unsigned 61-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u61be = u64; -/// Represents an unsigned 62-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u62be = u64; -/// Represents an unsigned 63-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u63be = u64; -/// Represents an unsigned 64-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as big-endian, but accessors/mutators will return/take host-order values. -pub type u64be = u64; -/// Represents an unsigned 9-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u9le = u16; -/// Represents an unsigned 10-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u10le = u16; -/// Represents an unsigned 11-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u11le = u16; -/// Represents an unsigned 12-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u12le = u16; -/// Represents an unsigned 13-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u13le = u16; -/// Represents an unsigned 14-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u14le = u16; -/// Represents an unsigned 15-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u15le = u16; -/// Represents an unsigned 16-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u16le = u16; -/// Represents an unsigned 17-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u17le = u32; -/// Represents an unsigned 18-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u18le = u32; -/// Represents an unsigned 19-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u19le = u32; -/// Represents an unsigned 20-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u20le = u32; -/// Represents an unsigned 21-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u21le = u32; -/// Represents an unsigned 22-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u22le = u32; -/// Represents an unsigned 23-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u23le = u32; -/// Represents an unsigned 24-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u24le = u32; -/// Represents an unsigned 25-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u25le = u32; -/// Represents an unsigned 26-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u26le = u32; -/// Represents an unsigned 27-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u27le = u32; -/// Represents an unsigned 28-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u28le = u32; -/// Represents an unsigned 29-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u29le = u32; -/// Represents an unsigned 30-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u30le = u32; -/// Represents an unsigned 31-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u31le = u32; -/// Represents an unsigned 32-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u32le = u32; -/// Represents an unsigned 33-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u33le = u64; -/// Represents an unsigned 34-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u34le = u64; - -/// Represents an unsigned 35-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u35le = u64; - -/// Represents an unsigned 36-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u36le = u64; -/// Represents an unsigned 37-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u37le = u64; -/// Represents an unsigned 38-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u38le = u64; -/// Represents an unsigned 39-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u39le = u64; -/// Represents an unsigned 40-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u40le = u64; -/// Represents an unsigned 41-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u41le = u64; -/// Represents an unsigned 42-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u42le = u64; -/// Represents an unsigned 43-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u43le = u64; -/// Represents an unsigned 44-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u44le = u64; -/// Represents an unsigned 45-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u45le = u64; -/// Represents an unsigned 46-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u46le = u64; -/// Represents an unsigned 47-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u47le = u64; -/// Represents an unsigned 48-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u48le = u64; -/// Represents an unsigned 49-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u49le = u64; -/// Represents an unsigned 50-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u50le = u64; -/// Represents an unsigned 51-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u51le = u64; -/// Represents an unsigned 52-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u52le = u64; -/// Represents an unsigned 53-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u53le = u64; -/// Represents an unsigned 54-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u54le = u64; -/// Represents an unsigned 55-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u55le = u64; -/// Represents an unsigned 56-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u56le = u64; -/// Represents an unsigned 57-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u57le = u64; -/// Represents an unsigned 58-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u58le = u64; -/// Represents an unsigned 59-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u59le = u64; -/// Represents an unsigned 60-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u60le = u64; -/// Represents an unsigned 61-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u61le = u64; -/// Represents an unsigned 62-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u62le = u64; -/// Represents an unsigned 63-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u63le = u64; -/// Represents an unsigned 64-bit integer. nex #\[packet\]-derived structs using this type will -/// hold it in memory as little-endian, but accessors/mutators will return/take host-order values. -pub type u64le = u64; -/// Represents an unsigned 9-bit integer in host endianness. -pub type u9he = u16; -/// Represents an unsigned 10-bit integer in host endianness. -pub type u10he = u16; -/// Represents an unsigned 11-bit integer in host endianness. -pub type u11he = u16; -/// Represents an unsigned 12-bit integer in host endianness. -pub type u12he = u16; -/// Represents an unsigned 13-bit integer in host endianness. -pub type u13he = u16; -/// Represents an unsigned 14-bit integer in host endianness. -pub type u14he = u16; -/// Represents an unsigned 15-bit integer in host endianness. -pub type u15he = u16; -/// Represents an unsigned 16-bit integer in host endianness. -pub type u16he = u16; -/// Represents an unsigned 17-bit integer in host endianness. -pub type u17he = u32; -/// Represents an unsigned 18-bit integer in host endianness. -pub type u18he = u32; -/// Represents an unsigned 19-bit integer in host endianness. -pub type u19he = u32; -/// Represents an unsigned 20-bit integer in host endianness. -pub type u20he = u32; -/// Represents an unsigned 21-bit integer in host endianness. -pub type u21he = u32; -/// Represents an unsigned 22-bit integer in host endianness. -pub type u22he = u32; -/// Represents an unsigned 23-bit integer in host endianness. -pub type u23he = u32; -/// Represents an unsigned 24-bit integer in host endianness. -pub type u24he = u32; -/// Represents an unsigned 25-bit integer in host endianness. -pub type u25he = u32; -/// Represents an unsigned 26-bit integer in host endianness. -pub type u26he = u32; -/// Represents an unsigned 27-bit integer in host endianness. -pub type u27he = u32; -/// Represents an unsigned 28-bit integer in host endianness. -pub type u28he = u32; -/// Represents an unsigned 29-bit integer in host endianness. -pub type u29he = u32; -/// Represents an unsigned 30-bit integer in host endianness. -pub type u30he = u32; -/// Represents an unsigned 31-bit integer in host endianness. -pub type u31he = u32; -/// Represents an unsigned 32-bit integer in host endianness. -pub type u32he = u32; -/// Represents an unsigned 33-bit integer in host endianness. -pub type u33he = u64; -/// Represents an unsigned 34-bit integer in host endianness. -pub type u34he = u64; -/// Represents an unsigned 35-bit integer in host endianness. -pub type u35he = u64; -/// Represents an unsigned 36-bit integer in host endianness. -pub type u36he = u64; -/// Represents an unsigned 37-bit integer in host endianness. -pub type u37he = u64; -/// Represents an unsigned 38-bit integer in host endianness. -pub type u38he = u64; -/// Represents an unsigned 39-bit integer in host endianness. -pub type u39he = u64; -/// Represents an unsigned 40-bit integer in host endianness. -pub type u40he = u64; -/// Represents an unsigned 41-bit integer in host endianness. -pub type u41he = u64; -/// Represents an unsigned 42-bit integer in host endianness. -pub type u42he = u64; -/// Represents an unsigned 43-bit integer in host endianness. -pub type u43he = u64; -/// Represents an unsigned 44-bit integer in host endianness. -pub type u44he = u64; -/// Represents an unsigned 45-bit integer in host endianness. -pub type u45he = u64; -/// Represents an unsigned 46-bit integer in host endianness. -pub type u46he = u64; -/// Represents an unsigned 47-bit integer in host endianness. -pub type u47he = u64; -/// Represents an unsigned 48-bit integer in host endianness. -pub type u48he = u64; -/// Represents an unsigned 49-bit integer in host endianness. -pub type u49he = u64; -/// Represents an unsigned 50-bit integer in host endianness. -pub type u50he = u64; -/// Represents an unsigned 51-bit integer in host endianness. -pub type u51he = u64; -/// Represents an unsigned 52-bit integer in host endianness. -pub type u52he = u64; -/// Represents an unsigned 53-bit integer in host endianness. -pub type u53he = u64; -/// Represents an unsigned 54-bit integer in host endianness. -pub type u54he = u64; -/// Represents an unsigned 55-bit integer in host endianness. -pub type u55he = u64; -/// Represents an unsigned 56-bit integer in host endianness. -pub type u56he = u64; -/// Represents an unsigned 57-bit integer in host endianness. -pub type u57he = u64; -/// Represents an unsigned 58-bit integer in host endianness. -pub type u58he = u64; -/// Represents an unsigned 59-bit integer in host endianness. -pub type u59he = u64; -/// Represents an unsigned 60-bit integer in host endianness. -pub type u60he = u64; -/// Represents an unsigned 61-bit integer in host endianness. -pub type u61he = u64; -/// Represents an unsigned 62-bit integer in host endianness. -pub type u62he = u64; -/// Represents an unsigned 63-bit integer in host endianness. -pub type u63he = u64; -/// Represents an unsigned 64-bit integer in host endianness. -pub type u64he = u64; diff --git a/nex-macro/Cargo.toml b/nex-macro/Cargo.toml deleted file mode 100644 index 7accdf6..0000000 --- a/nex-macro/Cargo.toml +++ /dev/null @@ -1,23 +0,0 @@ -[package] -name = "nex-macro" -version.workspace = true -edition.workspace = true -authors.workspace = true -description = "A macro for generating packet structures used by nex-packet. Not intended for direct use." -repository = "https://github.com/shellrow/nex" -readme = "../README.md" -keywords = ["network", "packet"] -categories = ["network-programming"] -license = "MIT" - -[lib] -proc-macro = true - -[dependencies] -proc-macro2 = "1.0" -quote = "1.0" -syn = { version = "2.0", features = ["full"] } -regex = "1.11" - -[dev-dependencies] -nex-macro-helper = { workspace = true } diff --git a/nex-macro/src/decorator.rs b/nex-macro/src/decorator.rs deleted file mode 100644 index ccbc511..0000000 --- a/nex-macro/src/decorator.rs +++ /dev/null @@ -1,1740 +0,0 @@ -//! Implements the #[packet] decorator. - -use crate::util::{ - operations, to_little_endian, to_mutator, Endianness, GetOperation, SetOperation, -}; -use core::iter::FromIterator; -use proc_macro2::{Group, Span}; -use quote::{quote, ToTokens}; -use regex::Regex; -use syn::{spanned::Spanned, Error}; - -#[derive(Debug, PartialEq, Eq)] -enum EndiannessSpecified { - No, - Yes, -} - -/// Lower and upper bounds of a payload. -/// Represented as strings since they may involve functions. -struct PayloadBounds { - lower: String, - upper: String, -} - -#[derive(Clone, Debug, PartialEq)] -enum Type { - /// Any of the `u*` types from `nex_macro::types::*`. - Primitive(String, usize, Endianness), - /// Any type of the form `Vec`. - Vector(Box), - /// Any type which isn't a primitive or a vector. - Misc(String), -} - -#[derive(Clone, Debug)] -struct Field { - name: String, - span: Span, - ty: Type, - packet_length: Option, - struct_length: Option, - is_payload: bool, - construct_with: Option>, -} - -#[derive(Clone, Debug)] -pub struct Packet { - base_name: String, - fields: Vec, -} - -impl Packet { - fn packet_name_mut(&self) -> String { - format!("Mutable{}Packet", self.base_name) - } - fn packet_name(&self) -> String { - format!("{}Packet", self.base_name) - } -} - -#[inline] -pub fn generate_packet( - s: &syn::DataStruct, - name: String, -) -> Result { - let packet = make_packet(s, name)?; - let structs = generate_packet_struct(&packet); - let (ts_packet_impls, payload_bounds, packet_size) = generate_packet_impls(&packet)?; - let ts_size_impls = generate_packet_size_impls(&packet, &packet_size)?; - let ts_trait_impls = generate_packet_trait_impls(&packet, &payload_bounds)?; - let ts_iterables = generate_iterables(&packet)?; - let ts_converters = generate_converters(&packet)?; - let ts_debug_impls = generate_debug_impls(&packet)?; - let tts = quote! { - #structs - #ts_packet_impls - #ts_size_impls - #ts_trait_impls - #ts_iterables - #ts_converters - #ts_debug_impls - }; - Ok(tts) -} - -#[inline] -fn generate_packet_struct(packet: &Packet) -> proc_macro2::TokenStream { - let items = &[ - (packet.packet_name(), "PacketData"), - (packet.packet_name_mut(), "MutPacketData"), - ]; - let tts: Vec<_> = items - .iter() - .map(|(name, packet_data)| { - let name = syn::Ident::new(&name, Span::call_site()); - let packet_data = syn::Ident::new(packet_data, Span::call_site()); - quote! { - #[derive(PartialEq)] - /// A structure enabling manipulation of on the wire packets - pub struct #name<'p> { - packet: ::nex_macro_helper::packet::#packet_data<'p>, - } - } - }) - .collect(); - quote! { - #(#tts)* - } -} - -#[inline] -fn make_type(ty_str: String, endianness_important: bool) -> Result { - if let Some((size, endianness, spec)) = parse_ty(&ty_str[..]) { - if !endianness_important || size <= 8 || spec == EndiannessSpecified::Yes { - Ok(Type::Primitive(ty_str, size, endianness)) - } else { - Err("endianness must be specified for types of size >= 8".to_owned()) - } - } else if ty_str.starts_with("Vec<") { - let ty = make_type( - String::from(&ty_str[4..ty_str.len() - 1]), - endianness_important, - ); - match ty { - Ok(ty) => Ok(Type::Vector(Box::new(ty))), - Err(e) => Err(e), - } - } else if ty_str.starts_with("&") { - Err(format!("invalid type: {}", ty_str)) - } else { - Ok(Type::Misc(ty_str)) - } -} - -#[inline] -fn make_packet(s: &syn::DataStruct, name: String) -> Result { - let mut fields = Vec::new(); - let mut payload_span = None; - let sfields = &s.fields; - for field in sfields { - let field_name = match &field.ident { - Some(name) => name.to_string(), - None => { - return Err(Error::new( - field.ty.span(), - "all fields in a packet must be named", - )); - } - }; - let mut construct_with = None; - let mut is_payload = false; - let mut packet_length = None; - let mut struct_length = None; - for attr in &field.attrs { - match attr.meta { - syn::Meta::Path(ref p) => { - if let Some(ident) = p.get_ident() { - if ident == "payload" { - if payload_span.is_some() { - return Err(Error::new( - p.span(), - "packet may not have multiple payloads", - )); - } - is_payload = true; - payload_span = Some(field.span()); - } - } - } - syn::Meta::NameValue(ref name_value) => { - if let Some(ident) = name_value.path.get_ident() { - if ident == "length_fn" { - if let syn::Expr::Lit(syn::ExprLit { - lit: syn::Lit::Str(ref s), - .. - }) = name_value.value - { - packet_length = Some(s.value() + "(&_self.to_immutable())"); - } else { - return Err(Error::new( - name_value.path.span(), - "#[length_fn] should be used as #[length_fn = \ - \"name_of_function\"]", - )); - } - } else if ident == "length" { - // get literal - if let syn::Expr::Lit(syn::ExprLit { - lit: syn::Lit::Str(ref s), - .. - }) = name_value.value - { - let field_names: Vec = sfields - .iter() - .filter_map(|field| { - field.ident.as_ref().map(|name| name.to_string()).and_then( - |name| { - if name == field_name { - None - } else { - Some(name) - } - }, - ) - }) - .collect(); - // Convert to tokens - let expr = s.parse::()?; - let tts = expr.to_token_stream(); - let tt_tokens: Vec<_> = tts.into_iter().collect(); - // Parse and replace fields - let tokens_packet = parse_length_expr(&tt_tokens, &field_names)?; - let parsed = quote! { #(#tokens_packet)* }; - packet_length = Some(parsed.to_string()); - } else { - return Err(Error::new( - name_value.value.span(), - "#[length] should be used as #[length = \ - \"field_name and/or arithmetic expression\"]", - )); - } - } else { - return Err(Error::new( - ident.span(), - &format!("Unknown meta/namevalue option '{}'", ident), - )); - } - } - } - syn::Meta::List(ref l) => { - if let Some(ident) = l.path.get_ident() { - if ident == "construct_with" { - let mut some_construct_with = Vec::new(); - - l.parse_nested_meta(|meta| { - if let Some(ident) = meta.path.get_ident() { - // #[construct_with(,...)] - let ty_str = ident.to_string(); - match make_type(ty_str, false) { - Ok(ty) => { - some_construct_with.push(ty); - Ok(()) - } - Err(e) => Err(meta.error(e)), - } - } else { - // Not an ident. Something else, likely a path. - Err(meta.error("expected ident")) - } - }) - .map_err(|mut err| { - err.combine(Error::new( - l.span(), - "#[construct_with] should be of the form \ - #[construct_with()]", - )); - err - })?; - - if some_construct_with.is_empty() { - return Err(Error::new( - l.span(), - "#[construct_with] must have at least one argument", - )); - } - construct_with = Some(some_construct_with); - } else { - return Err(Error::new( - ident.span(), - &format!("unknown attribute: {}", ident), - )); - } - } else { - return Err(Error::new( - l.path.span(), - "meta-list attribute has unexpected type (not an ident)", - )); - } - } - } - } - - let ty = match make_type(ty_to_string(&field.ty), true) { - Ok(ty) => ty, - Err(e) => { - return Err(Error::new(field.ty.span(), &format!("{}", e))); - } - }; - - match ty { - Type::Vector(_) => { - struct_length = if let Some(construct_with) = construct_with.as_ref() { - let mut inner_size = 0; - for arg in construct_with.iter() { - if let Type::Primitive(ref _ty_str, size, _endianness) = *arg { - inner_size += size; - } else { - return Err(Error::new( - field.span(), - "arguments to #[construct_with] must be primitives", - )); - } - } - if inner_size % 8 != 0 { - return Err(Error::new( - field.span(), - "types in #[construct_with] for vec must be add up to a multiple of 8 bits", - )); - } - inner_size /= 8; // bytes not bits - - Some(format!("_packet.{}.len() * {}", field_name, inner_size).to_owned()) - } else { - Some(format!("_packet.{}.len()", field_name).to_owned()) - }; - if !is_payload && packet_length.is_none() { - return Err(Error::new( - field.ty.span(), - "variable length field must have #[length = \"\"] or \ - #[length_fn = \"\"] attribute", - )); - } - } - Type::Misc(_) => { - if construct_with.is_none() { - return Err(Error::new( - field.ty.span(), - "non-primitive field types must specify #[construct_with]", - )); - } - } - _ => {} - } - - fields.push(Field { - name: field_name, - span: field.span(), - ty, - packet_length, - struct_length, - is_payload, - construct_with, - }); - } - - if payload_span.is_none() { - return Err(Error::new( - Span::call_site(), - "#[packet]'s must contain a payload", - )); - } - - Ok(Packet { - base_name: name, - fields, - }) -} - -/// Return the processed length expression for a packet. -#[inline] -fn parse_length_expr( - tts: &[proc_macro2::TokenTree], - field_names: &[String], -) -> Result, Error> { - use proc_macro2::TokenTree; - let error_msg = "Only field names, constants, integers, basic arithmetic expressions \ - (+ - * / %) and parentheses are allowed in the \"length\" attribute"; - let mut needs_constant: Option = None; - let mut has_constant = false; - let mut tokens_packet = Vec::new(); - for tt_token in tts { - match tt_token { - TokenTree::Ident(name) => { - if name.to_string().chars().any(|c| c.is_lowercase()) { - if field_names.contains(&name.to_string()) { - let tts: syn::Expr = - syn::parse_str(&format!("_self.get_{}() as usize", name))?; - let mut modified_packet_tokens: Vec<_> = - tts.to_token_stream().into_iter().collect(); - tokens_packet.append(&mut modified_packet_tokens); - } else { - if let None = needs_constant { - needs_constant = Some(tt_token.span()); - } - tokens_packet.push(tt_token.clone()); - } - } - // Constants are only recognized if they are all uppercase - else { - let tts: syn::Expr = syn::parse_str(&format!("{} as usize", name))?; - let mut modified_packet_tokens: Vec<_> = - tts.to_token_stream().into_iter().collect(); - tokens_packet.append(&mut modified_packet_tokens); - has_constant = true; - } - } - TokenTree::Punct(_) => { - tokens_packet.push(tt_token.clone()); - } - TokenTree::Literal(lit) => { - // must be an integer - if syn::parse_str::(&lit.to_string()).is_err() { - return Err(Error::new(lit.span(), error_msg)); - } - tokens_packet.push(tt_token.clone()); - } - TokenTree::Group(ref group) => { - let ts: Vec<_> = group.stream().into_iter().collect(); - let tts = parse_length_expr(&ts, field_names)?; - let mut new_group = Group::new( - group.delimiter(), - proc_macro2::TokenStream::from_iter(tts.into_iter()), - ); - new_group.set_span(group.span()); - let tt = TokenTree::Group(new_group); - tokens_packet.push(tt); - } - }; - } - - if let Some(span) = needs_constant { - if !has_constant { - return Err(Error::new( - span, - "Field name must be a member of the struct and not the field itself", - )); - } - } - - Ok(tokens_packet) -} - -#[inline] -fn generate_packet_impl( - packet: &Packet, - mutable: bool, - name: String, -) -> Result<(proc_macro2::TokenStream, PayloadBounds, String), Error> { - let mut bit_offset = 0; - let mut offset_fns_packet = Vec::new(); - let mut offset_fns_struct = Vec::new(); - let mut accessors = "".to_owned(); - let mut mutators = "".to_owned(); - let mut payload_bounds = None; - for (idx, field) in packet.fields.iter().enumerate() { - let mut co = current_offset(bit_offset, &offset_fns_packet[..]); - - if field.is_payload { - let mut upper_bound_str = "".to_owned(); - if field.packet_length.is_some() { - upper_bound_str = - format!("{} + {}", co.clone(), field.packet_length.as_ref().unwrap()); - } else { - if idx != packet.fields.len() - 1 { - return Err(Error::new( - field.span, - "#[payload] must specify a #[length_fn], unless it is the \ - last field of a packet", - )); - } - } - payload_bounds = Some(PayloadBounds { - lower: co.clone(), - upper: upper_bound_str, - }); - } - match field.ty { - Type::Primitive(ref ty_str, size, endianness) => { - let mut ops = operations(bit_offset % 8, size).unwrap(); - let target_endianness = if cfg!(target_endian = "little") { - Endianness::Little - } else { - Endianness::Big - }; - - if endianness == Endianness::Little - || (target_endianness == Endianness::Little && endianness == Endianness::Host) - { - ops = to_little_endian(ops); - } - - mutators = mutators - + &generate_mutator_str( - &field.name[..], - &ty_str[..], - &co[..], - &to_mutator(&ops[..])[..], - None, - )[..]; - accessors = accessors - + &generate_accessor_str(&field.name[..], &ty_str[..], &co[..], &ops[..], None) - [..]; - bit_offset += size; - } - Type::Vector(ref inner_ty) => handle_vector_field( - &field, - &mut bit_offset, - &offset_fns_packet[..], - &mut co, - &name, - &mut mutators, - &mut accessors, - inner_ty, - )?, - Type::Misc(ref ty_str) => handle_misc_field( - &field, - &mut bit_offset, - &offset_fns_packet[..], - &mut co, - &name, - &mut mutators, - &mut accessors, - &ty_str, - )?, - } - if field.packet_length.is_some() { - offset_fns_packet.push(field.packet_length.as_ref().unwrap().clone()); - } - if field.struct_length.is_some() { - offset_fns_struct.push(field.struct_length.as_ref().unwrap().clone()); - } - } - - fn generate_set_fields(packet: &Packet) -> String { - let mut set_fields = String::new(); - for field in &packet.fields { - match field.ty { - Type::Vector(_) => { - set_fields = set_fields - + &format!("_self.set_{field}(&packet.{field});\n", field = field.name)[..]; - } - _ => { - set_fields = set_fields - + &format!("_self.set_{field}(packet.{field});\n", field = field.name)[..]; - } - } - } - - set_fields - } - - let populate = if mutable { - let set_fields = generate_set_fields(&packet); - let imm_name = packet.packet_name(); - format!( - "/// Populates a {name}Packet using a {name} structure - #[inline] - #[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))] - pub fn populate(&mut self, packet: &{name}) {{ - let _self = self; - {set_fields} - }}", - name = &imm_name[..imm_name.len() - 6], - set_fields = set_fields - ) - } else { - "".to_owned() - }; - - // If there are no variable length fields defined, then `_packet` is not used, hence - // the leading underscore - let packet_size_struct = format!( - "/// The size (in bytes) of a {base_name} instance when converted into - /// a byte-array - #[inline] - pub fn packet_size(_packet: \ - &{base_name}) -> usize {{ - {struct_size} - }}", - base_name = packet.base_name, - struct_size = current_offset(bit_offset, &offset_fns_struct[..]) - ); - - let byte_size = if bit_offset % 8 == 0 { - bit_offset / 8 - } else { - (bit_offset / 8) + 1 - }; - - let s = format!("impl<'a> {name}<'a> {{ - /// Constructs a new {name}. If the provided buffer is less than the minimum required - /// packet size, this will return None. - #[inline] - pub fn new<'p>(packet: &'p {mut} [u8]) -> Option<{name}<'p>> {{ - if packet.len() >= {name}::minimum_packet_size() {{ - use ::nex_macro_helper::packet::{cap_mut}PacketData; - Some({name} {{ packet: {cap_mut}PacketData::Borrowed(packet) }}) - }} else {{ - None - }} - }} - - /// Constructs a new {name}. If the provided buffer is less than the minimum required - /// packet size, this will return None. With this constructor the {name} will - /// own its own data and the underlying buffer will be dropped when the {name} is. - pub fn owned(packet: Vec) -> Option<{name}<'static>> {{ - if packet.len() >= {name}::minimum_packet_size() {{ - use ::nex_macro_helper::packet::{cap_mut}PacketData; - Some({name} {{ packet: {cap_mut}PacketData::Owned(packet) }}) - }} else {{ - None - }} - }} - - /// Maps from a {name} to a {imm_name} - #[inline] - pub fn to_immutable<'p>(&'p self) -> {imm_name}<'p> {{ - use ::nex_macro_helper::packet::PacketData; - {imm_name} {{ packet: PacketData::Borrowed(self.packet.as_slice()) }} - }} - - /// Maps from a {name} to a {imm_name} while consuming the source - #[inline] - pub fn consume_to_immutable(self) -> {imm_name}<'a> {{ - {imm_name} {{ packet: self.packet.to_immutable() }} - }} - - /// The minimum size (in bytes) a packet of this type can be. It's based on the total size - /// of the fixed-size fields. - #[inline] - pub const fn minimum_packet_size() -> usize {{ - {byte_size} - }} - - {packet_size_struct} - - {populate} - - {accessors} - - {mutators} - }}", name = name, - imm_name = packet.packet_name(), - mut = if mutable { "mut" } else { "" }, - cap_mut = if mutable { "Mut" } else { "" }, - byte_size = byte_size, - accessors = accessors, - mutators = if mutable { &mutators[..] } else { "" }, - populate = populate, - packet_size_struct = packet_size_struct - ); - - let stmt: syn::Stmt = syn::parse_str(&s).expect("parse fn generate_packet_impl failed"); - let ts = quote! { - #stmt - }; - - Ok(( - ts, - payload_bounds.unwrap(), - current_offset(bit_offset, &offset_fns_packet[..]), - )) -} - -#[inline] -fn generate_packet_impls( - packet: &Packet, -) -> Result<(proc_macro2::TokenStream, PayloadBounds, String), Error> { - let mut ret = None; - let mut tts = Vec::new(); - for (mutable, name) in vec![ - (false, packet.packet_name()), - (true, packet.packet_name_mut()), - ] { - let (tokens, bounds, size) = generate_packet_impl(packet, mutable, name)?; - tts.push(tokens); - ret = Some((bounds, size)); - } - let tokens = quote! { #(#tts)* }; - - ret.map(|(bounds, size)| (tokens, bounds, size)) - .ok_or_else(|| Error::new(Span::call_site(), "generate_packet_impls failed")) -} - -#[inline] -fn generate_packet_size_impls( - packet: &Packet, - size: &str, -) -> Result { - let tts: Result, _> = [packet.packet_name(), packet.packet_name_mut()] - .iter() - .map(|name| { - let s = format!( - " - impl<'a> ::nex_macro_helper::packet::PacketSize for {name}<'a> {{ - #[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))] - fn packet_size(&self) -> usize {{ - let _self = self; - {size} - }} - }} - ", - name = name, - size = size - ); - syn::parse_str::(&s) - }) - .collect(); - let tts = tts?; - Ok(quote! { #(#tts)* }) -} - -#[inline] -fn generate_packet_trait_impls( - packet: &Packet, - payload_bounds: &PayloadBounds, -) -> Result { - let items = [ - (packet.packet_name_mut(), "Mutable", "_mut", "mut"), - (packet.packet_name_mut(), "", "", ""), - (packet.packet_name(), "", "", ""), - ]; - let tts: Result, _> = items - .iter() - .map(|(name, mutable, u_mut, mut_)| { - let mut pre = "".to_owned(); - let mut start = "".to_owned(); - let mut end = "".to_owned(); - if !payload_bounds.lower.is_empty() { - pre = pre + &format!("let start = {};", payload_bounds.lower)[..]; - start = "start".to_owned(); - } - if !payload_bounds.upper.is_empty() { - pre = pre - + &format!( - "let end = ::core::cmp::min({}, _self.packet.len());", - payload_bounds.upper - )[..]; - end = "end".to_owned(); - } - let s = format!( - "impl<'a> ::nex_macro_helper::packet::{mutable}Packet for {name}<'a> {{ - #[inline] - fn packet{u_mut}<'p>(&'p {mut_} self) -> &'p {mut_} [u8] {{ &{mut_} self.packet[..] }} - - #[inline] - #[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))] - fn payload{u_mut}<'p>(&'p {mut_} self) -> &'p {mut_} [u8] {{ - let _self = self; - {pre} - if _self.packet.len() <= {start} {{ - return &{mut_} []; - }} - &{mut_} _self.packet[{start}..{end}] - }} - }}", - name = name, - start = start, - end = end, - pre = pre, - mutable = mutable, - u_mut = u_mut, - mut_ = mut_ - ); - syn::parse_str::(&s) - }) - .collect(); - let tts = tts?; - Ok(quote! { #(#tts)* }) -} - -#[inline] -fn generate_iterables(packet: &Packet) -> Result { - let name = &packet.base_name; - - let ts1 = format!( - " - /// Used to iterate over a slice of `{name}Packet`s - pub struct {name}Iterable<'a> {{ - buf: &'a [u8], - }} - ", - name = name - ); - - let ts2 = format!( - " - impl<'a> Iterator for {name}Iterable<'a> {{ - type Item = {name}Packet<'a>; - - fn next(&mut self) -> Option<{name}Packet<'a>> {{ - use nex_macro_helper::packet::PacketSize; - use core::cmp::min; - if self.buf.len() > 0 {{ - if let Some(ret) = {name}Packet::new(self.buf) {{ - let start = min(ret.packet_size(), self.buf.len()); - self.buf = &self.buf[start..]; - return Some(ret); - }} - }} - - None - }} - - fn size_hint(&self) -> (usize, Option) {{ - (0, None) - }} - }} - ", - name = name - ); - let ts1: syn::Stmt = syn::parse_str(&ts1)?; - let ts2: syn::Stmt = syn::parse_str(&ts2)?; - Ok(quote! { - #ts1 - #ts2 - }) -} - -#[inline] -fn generate_converters(packet: &Packet) -> Result { - let get_fields = generate_get_fields(packet); - - let tts: Result, _> = [packet.packet_name(), packet.packet_name_mut()] - .iter() - .map(|name| { - let s = format!( - " - impl<'p> ::nex_macro_helper::packet::FromPacket for {packet}<'p> {{ - type T = {name}; - #[inline] - fn from_packet(&self) -> {name} {{ - use nex_macro_helper::packet::Packet; - let _self = self; - {name} {{ - {get_fields} - }} - }} - }}", - packet = name, - name = packet.base_name, - get_fields = get_fields - ); - syn::parse_str::(&s) - }) - .collect(); - let tts = tts?; - Ok(quote! { #(#tts)* }) -} - -#[inline] -fn generate_debug_impls(packet: &Packet) -> Result { - let mut field_fmt_str = String::new(); - let mut get_fields = String::new(); - - for field in &packet.fields { - if !field.is_payload { - field_fmt_str = format!("{}{} : {{:?}}, ", field_fmt_str, field.name); - get_fields = format!("{}, _self.get_{}()", get_fields, field.name); - } - } - - let tts: Result, _> = [packet.packet_name(), packet.packet_name_mut()] - .iter() - .map(|packet| { - let s = format!( - " - impl<'p> ::core::fmt::Debug for {packet}<'p> {{ - #[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))] - fn fmt(&self, fmt: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {{ - let _self = self; - write!(fmt, - \"{packet} {{{{ {field_fmt_str} }}}}\" - {get_fields} - ) - }} - }}", - packet = packet, - field_fmt_str = field_fmt_str, - get_fields = get_fields - ); - syn::parse_str::(&s) - }) - .collect(); - let tts = tts?; - Ok(quote! { #(#tts)* }) -} - -#[inline] -fn handle_misc_field( - field: &Field, - bit_offset: &mut usize, - offset_fns: &[String], - co: &mut String, - name: &str, - mutators: &mut String, - accessors: &mut String, - ty_str: &str, -) -> Result<(), Error> { - let mut inner_accessors = String::new(); - let mut inner_mutators = String::new(); - let mut get_args = String::new(); - let mut set_args = String::new(); - for (i, arg) in field - .construct_with - .as_ref() - .expect("misc field as ref") - .iter() - .enumerate() - { - if let Type::Primitive(ref ty_str, size, endianness) = *arg { - let mut ops = operations(*bit_offset % 8, size).unwrap(); - let target_endianness = if cfg!(target_endian = "little") { - Endianness::Little - } else { - Endianness::Big - }; - - if endianness == Endianness::Little - || (target_endianness == Endianness::Little && endianness == Endianness::Host) - { - ops = to_little_endian(ops); - } - - let arg_name = format!("arg{}", i); - inner_accessors = inner_accessors - + &generate_accessor_str( - &arg_name[..], - &ty_str[..], - &co[..], - &ops[..], - Some(&name[..]), - )[..]; - inner_mutators = inner_mutators - + &generate_mutator_str( - &arg_name[..], - &ty_str[..], - &co[..], - &to_mutator(&ops[..])[..], - Some(&name[..]), - )[..]; - get_args = format!("{}get_{}(&self), ", get_args, arg_name); - set_args = format!("{}set_{}(_self, vals.{});\n", set_args, arg_name, i); - *bit_offset += size; - // Current offset needs to be recalculated for each arg - *co = current_offset(*bit_offset, offset_fns); - } else { - return Err(Error::new( - field.span, - "arguments to #[construct_with] must be primitives", - )); - } - } - *mutators = format!( - "{mutators} - /// Set the value of the {name} field. - #[inline] - #[allow(trivial_numeric_casts)] - #[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))] - pub fn set_{name}(&mut self, val: {ty_str}) {{ - use nex_macro_helper::packet::PrimitiveValues; - let _self = self; - {inner_mutators} - - let vals = val.to_primitive_values(); - - {set_args} - }} - ", - mutators = &mutators[..], - name = field.name, - ty_str = ty_str, - inner_mutators = inner_mutators, - set_args = set_args - ); - let ctor = if field.construct_with.is_some() { - format!( - "{} {}::new({})", - inner_accessors, - ty_str, - &get_args[..get_args.len() - 2] - ) - } else { - format!( - "let current_offset = {}; - {}::new(&_self.packet[current_offset..])", - co, ty_str - ) - }; - *accessors = format!( - "{accessors} - /// Get the value of the {name} field - #[inline] - #[allow(trivial_numeric_casts)] - #[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))] - pub fn get_{name}(&self) -> {ty_str} {{ - {ctor} - }} - ", - accessors = accessors, - name = field.name, - ty_str = ty_str, - ctor = ctor - ); - Ok(()) -} - -#[inline] -fn handle_vec_primitive( - inner_ty_str: &str, - size: usize, - field: &Field, - accessors: &mut String, - mutators: &mut String, - co: &mut String, -) -> Result<(), Error> { - if inner_ty_str == "u8" || (size % 8) == 0 { - let ops = operations(0, size).unwrap(); - if !field.is_payload { - let op_strings = generate_accessor_op_str("packet", inner_ty_str, &ops); - *accessors = format!("{accessors} - /// Get the value of the {name} field (copies contents) - #[inline] - #[allow(trivial_numeric_casts, unused_parens, unused_braces)] - #[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))] - pub fn get_{name}(&self) -> Vec<{inner_ty_str}> {{ - use core::cmp::min; - let _self = self; - let current_offset = {co}; - let pkt_len = self.packet.len(); - let end = min(current_offset + {packet_length}, pkt_len); - - let packet = &_self.packet[current_offset..end]; - let mut vec: Vec<{inner_ty_str}> = Vec::with_capacity(packet.len() / {size}); - let mut co = 0; - for _ in 0..vec.capacity() {{ - vec.push({{ - {ops} - }}); - co += {size}; - }} - vec - }} - ", - accessors = accessors, - name = field.name, - co = co, - packet_length = field.packet_length.as_ref().unwrap(), - inner_ty_str = inner_ty_str, - ops = op_strings, - size = size / 8); - } - let check_len = if field.packet_length.is_some() { - format!( - "let len = {packet_length}; - assert!(vals.len() <= len);", - packet_length = field.packet_length.as_ref().unwrap() - ) - } else { - String::new() - }; - - let copy_vals = if inner_ty_str == "u8" { - // Efficient copy_from_slice (memcpy) - format!( - " - _self.packet[current_offset..current_offset + vals.len()] - .copy_from_slice(vals); - " - ) - } else { - // e.g. Vec -> Vec - let sop_strings = generate_sop_strings(&to_mutator(&ops)); - format!( - " - let mut co = current_offset; - for i in 0..vals.len() {{ - let val = vals[i]; - {sop} - co += {size}; - }}", - sop = sop_strings, - size = size / 8 - ) - }; - - *mutators = format!( - "{mutators} - /// Set the value of the {name} field (copies contents) - #[inline] - #[allow(trivial_numeric_casts)] - #[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))] - pub fn set_{name}(&mut self, vals: &[{inner_ty_str}]) {{ - let mut _self = self; - let current_offset = {co}; - - {check_len} - - {copy_vals} - }}", - mutators = mutators, - name = field.name, - co = co, - check_len = check_len, - inner_ty_str = inner_ty_str, - copy_vals = copy_vals - ); - Ok(()) - } else { - Err(Error::new( - field.span, - "unimplemented variable length field", - )) - } -} - -#[inline] -fn handle_vector_field( - field: &Field, - bit_offset: &mut usize, - offset_fns: &[String], - co: &mut String, - name: &str, - mutators: &mut String, - accessors: &mut String, - inner_ty: &Box, -) -> Result<(), Error> { - if !field.is_payload && !field.packet_length.is_some() { - return Err(Error::new( - field.span, - "variable length field must have #[length_fn = \"\"] attribute", - )); - } - if !field.is_payload { - *accessors = format!("{accessors} - /// Get the raw &[u8] value of the {name} field, without copying - #[inline] - #[allow(trivial_numeric_casts)] - #[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))] - pub fn get_{name}_raw(&self) -> &[u8] {{ - use core::cmp::min; - let _self = self; - let current_offset = {co}; - let end = min(current_offset + {packet_length}, _self.packet.len()); - - &_self.packet[current_offset..end] - }} - ", - accessors = accessors, - name = field.name, - co = co, - packet_length = field.packet_length.as_ref().unwrap()); - *mutators = format!("{mutators} - /// Get the raw &mut [u8] value of the {name} field, without copying - #[inline] - #[allow(trivial_numeric_casts)] - #[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))] - pub fn get_{name}_raw_mut(&mut self) -> &mut [u8] {{ - use core::cmp::min; - let _self = self; - let current_offset = {co}; - let end = min(current_offset + {packet_length}, _self.packet.len()); - - &mut _self.packet[current_offset..end] - }} - ", - mutators = mutators, - name = field.name, - co = co, - packet_length = field.packet_length.as_ref().unwrap()); - } - match **inner_ty { - Type::Primitive(ref inner_ty_str, _size, _endianness) => { - handle_vec_primitive(inner_ty_str, _size, field, accessors, mutators, co) - } - Type::Vector(_) => { - return Err(Error::new( - field.span, - "variable length fields may not contain vectors", - )); - } - Type::Misc(ref inner_ty_str) => { - if let Some(construct_with) = field.construct_with.as_ref() { - let mut inner_accessors = String::new(); - let mut inner_mutators = String::new(); - let mut get_args = String::new(); - let mut set_args = String::new(); - let mut inner_size = 0; - for (i, arg) in construct_with.iter().enumerate() { - if let Type::Primitive(ref ty_str, size, endianness) = *arg { - let mut ops = operations(*bit_offset % 8, size).unwrap(); - let target_endianness = if cfg!(target_endian = "little") { - Endianness::Little - } else { - Endianness::Big - }; - - if endianness == Endianness::Little - || (target_endianness == Endianness::Little - && endianness == Endianness::Host) - { - ops = to_little_endian(ops); - } - - inner_size += size; - let arg_name = format!("arg{}", i); - inner_accessors = inner_accessors - + &generate_accessor_with_offset_str( - &arg_name[..], - &ty_str[..], - &co[..], - &ops[..], - &name[..], - )[..]; - inner_mutators = inner_mutators - + &generate_mutator_with_offset_str( - &arg_name[..], - &ty_str[..], - &co[..], - &to_mutator(&ops[..])[..], - &name[..], - )[..]; - get_args = - format!("{}get_{}(&self, additional_offset), ", get_args, arg_name); - set_args = format!( - "{}set_{}(_self, vals.{}, additional_offset);\n", - set_args, arg_name, i - ); - *bit_offset += size; - // Current offset needs to be recalculated for each arg - *co = current_offset(*bit_offset, offset_fns); - } else { - return Err(Error::new( - field.span, - "arguments to #[construct_with] must be primitives", - )); - } - } - if inner_size % 8 != 0 { - return Err(Error::new( - field.span, - "types in #[construct_with] for vec must be add up to a multiple of 8 bits", - )); - } - inner_size /= 8; // bytes not bits - *mutators = format!( - "{mutators} - /// Set the value of the {name} field. - #[inline] - #[allow(trivial_numeric_casts)] - #[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))] - pub fn set_{name}(&mut self, vals: &Vec<{inner_ty_str}>) {{ - use nex_macro_helper::packet::PrimitiveValues; - let _self = self; - {inner_mutators} - let mut additional_offset = 0; - - for val in vals.into_iter() {{ - let vals = val.to_primitive_values(); - - {set_args} - - additional_offset += {inner_size}; - }} - }} - ", - mutators = &mutators[..], - name = field.name, - inner_ty_str = inner_ty_str, - inner_mutators = inner_mutators, - //packet_length = field.packet_length.as_ref().unwrap(), - inner_size = inner_size, - set_args = set_args - ); - *accessors = format!( - "{accessors} - /// Get the value of the {name} field - #[inline] - #[allow(trivial_numeric_casts)] - #[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))] - pub fn get_{name}(&self) -> Vec<{inner_ty_str}> {{ - let _self = self; - let length = {packet_length}; - let vec_length = length.saturating_div({inner_size}); - let mut vec = Vec::with_capacity(vec_length); - - {inner_accessors} - - let mut additional_offset = 0; - - for vec_offset in 0..vec_length {{ - vec.push({inner_ty_str}::new({get_args})); - additional_offset += {inner_size}; - }} - - vec - }} - ", - accessors = accessors, - name = field.name, - inner_ty_str = inner_ty_str, - inner_accessors = inner_accessors, - packet_length = field.packet_length.as_ref().unwrap(), - inner_size = inner_size, - get_args = &get_args[..get_args.len() - 2] - ); - return Ok(()); - } - *accessors = format!("{accessors} - /// Get the value of the {name} field (copies contents) - #[inline] - #[allow(trivial_numeric_casts)] - #[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))] - pub fn get_{name}(&self) -> Vec<{inner_ty_str}> {{ - use nex_macro_helper::packet::FromPacket; - use core::cmp::min; - let _self = self; - let current_offset = {co}; - let end = min(current_offset + {packet_length}, _self.packet.len()); - - {inner_ty_str}Iterable {{ - buf: &_self.packet[current_offset..end] - }}.map(|packet| packet.from_packet()) - .collect::>() - }} - - /// Get the value of the {name} field as iterator - #[inline] - #[allow(trivial_numeric_casts)] - #[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))] - pub fn get_{name}_iter(&self) -> {inner_ty_str}Iterable {{ - use core::cmp::min; - let _self = self; - let current_offset = {co}; - let end = min(current_offset + {packet_length}, _self.packet.len()); - - {inner_ty_str}Iterable {{ - buf: &_self.packet[current_offset..end] - }} - }} - ", - accessors = accessors, - name = field.name, - co = co, - packet_length = field.packet_length.as_ref().unwrap(), - inner_ty_str = inner_ty_str); - *mutators = format!("{mutators} - /// Set the value of the {name} field (copies contents) - #[inline] - #[allow(trivial_numeric_casts)] - #[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))] - pub fn set_{name}(&mut self, vals: &[{inner_ty_str}]) {{ - use nex_macro_helper::packet::PacketSize; - let _self = self; - let mut current_offset = {co}; - let end = current_offset + {packet_length}; - for val in vals.into_iter() {{ - let mut packet = Mutable{inner_ty_str}Packet::new(&mut _self.packet[current_offset..]).unwrap(); - packet.populate(val); - current_offset += packet.packet_size(); - assert!(current_offset <= end); - }} - }} - ", - mutators = mutators, - name = field.name, - co = co, - packet_length = field.packet_length.as_ref().unwrap(), - inner_ty_str = inner_ty_str); - Ok(()) - } - } -} - -/// Given a type in the form `u([0-9]+)(be|le)?`, return a tuple of it's size and endianness -/// -/// If 1 <= size <= 8, Endianness will be Big. -fn parse_ty(ty: &str) -> Option<(usize, Endianness, EndiannessSpecified)> { - let re = Regex::new(r"^u([0-9]+)(be|le|he)?$").unwrap(); - let iter = match re.captures_iter(ty).next() { - Some(c) => c, - None => return None, - }; - if iter.len() == 3 || iter.len() == 2 { - let size = iter.get(1).unwrap().as_str(); - let (endianness, has_end) = if let Some(e) = iter.get(2) { - let e = e.as_str(); - if e == "be" { - (Endianness::Big, EndiannessSpecified::Yes) - } else if e == "he" { - (Endianness::Host, EndiannessSpecified::Yes) - } else { - (Endianness::Little, EndiannessSpecified::Yes) - } - } else { - (Endianness::Big, EndiannessSpecified::No) - }; - - if let Ok(sz) = size.parse() { - Some((sz, endianness, has_end)) - } else { - None - } - } else { - None - } -} - -fn ty_to_string(ty: &syn::Type) -> String { - // XXX this inserts extra spaces (ex: "Vec < u8 >") - let s = quote!(#ty).to_string(); - s.replace(" < ", "<").replace(" > ", ">").replace(" >", ">") -} - -#[test] -fn test_parse_ty() { - assert_eq!( - parse_ty("u8"), - Some((8, Endianness::Big, EndiannessSpecified::No)) - ); - assert_eq!( - parse_ty("u21be"), - Some((21, Endianness::Big, EndiannessSpecified::Yes)) - ); - assert_eq!( - parse_ty("u21le"), - Some((21, Endianness::Little, EndiannessSpecified::Yes)) - ); - assert_eq!( - parse_ty("u21he"), - Some((21, Endianness::Host, EndiannessSpecified::Yes)) - ); - assert_eq!( - parse_ty("u9"), - Some((9, Endianness::Big, EndiannessSpecified::No)) - ); - assert_eq!( - parse_ty("u16"), - Some((16, Endianness::Big, EndiannessSpecified::No)) - ); - assert_eq!(parse_ty("uable"), None); - assert_eq!(parse_ty("u21re"), None); - assert_eq!(parse_ty("i21be"), None); -} - -fn generate_sop_strings(operations: &[SetOperation]) -> String { - let mut op_strings = String::new(); - for (idx, sop) in operations.iter().enumerate() { - let pkt_replace = format!("_self.packet[co + {}]", idx); - let val_replace = "val"; - let sop = sop - .to_string() - .replace("{packet}", &pkt_replace[..]) - .replace("{val}", val_replace); - op_strings = op_strings + &sop[..] + ";\n"; - } - - op_strings -} - -enum AccessorMutator { - Accessor, - Mutator, -} - -fn generate_accessor_or_mutator_comment(name: &str, ty: &str, op_type: AccessorMutator) -> String { - let get_or_set = match op_type { - AccessorMutator::Accessor => "Get", - AccessorMutator::Mutator => "Set", - }; - if let Some((_, endianness, end_specified)) = parse_ty(ty) { - if end_specified == EndiannessSpecified::Yes { - let return_or_want = match op_type { - AccessorMutator::Accessor => "accessor returns", - AccessorMutator::Mutator => "mutator wants", - }; - let endian_str = match endianness { - Endianness::Big => "big-endian", - Endianness::Little => "little-endian", - Endianness::Host => "host-endian", - }; - - return format!( - "/// {get_or_set} the {name} field. This field is always stored {endian} - /// within the struct, but this {return_or_want} host order.", - get_or_set = get_or_set, - name = name, - endian = endian_str, - return_or_want = return_or_want - ); - } - } - format!( - "/// {get_or_set} the {name} field.", - get_or_set = get_or_set, - name = name - ) -} - -/// Given the name of a field, and a set of operations required to set that field, return -/// the Rust code required to set the field -fn generate_mutator_str( - name: &str, - ty: &str, - offset: &str, - operations: &[SetOperation], - inner: Option<&str>, -) -> String { - let op_strings = generate_sop_strings(operations); - - let mutator = if let Some(struct_name) = inner { - format!( - "#[inline] - #[allow(trivial_numeric_casts)] - #[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))] - fn set_{name}(_self: &mut {struct_name}, val: {ty}) {{ - let co = {co}; - {operations} - }}", - struct_name = struct_name, - name = name, - ty = ty, - co = offset, - operations = op_strings - ) - } else { - let comment = generate_accessor_or_mutator_comment(name, ty, AccessorMutator::Mutator); - format!( - "{comment} - #[inline] - #[allow(trivial_numeric_casts)] - #[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))] - pub fn set_{name}(&mut self, val: {ty}) {{ - let _self = self; - let co = {co}; - {operations} - }}", - comment = comment, - name = name, - ty = ty, - co = offset, - operations = op_strings - ) - }; - - mutator -} - -fn generate_mutator_with_offset_str( - name: &str, - ty: &str, - offset: &str, - operations: &[SetOperation], - inner: &str, -) -> String { - let op_strings = generate_sop_strings(operations); - - format!( - "#[inline] - #[allow(trivial_numeric_casts)] - #[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))] - fn set_{name}(_self: &mut {struct_name}, val: {ty}, offset: usize) {{ - let co = {co} + offset; - {operations} - }}", - struct_name = inner, - name = name, - ty = ty, - co = offset, - operations = op_strings - ) -} - -/// Used to turn something like a u16be into -/// "let b0 = ((_self.packet[co + 0] as u16be) << 8) as u16be; -/// let b1 = ((_self.packet[co + 1] as u16be) as u16be; -/// b0 | b1" -fn generate_accessor_op_str(name: &str, ty: &str, operations: &[GetOperation]) -> String { - fn build_return(max: usize) -> String { - let mut ret = "".to_owned(); - for i in 0..max { - ret = ret + &format!("b{} | ", i)[..]; - } - let new_len = ret.len() - 3; - ret.truncate(new_len); - - ret - } - - let op_strings = if operations.len() == 1 { - let replacement_str = format!("({}[co] as {})", name, ty); - operations - .first() - .unwrap() - .to_string() - .replace("{}", &replacement_str[..]) - } else { - let mut op_strings = "".to_owned(); - for (idx, operation) in operations.iter().enumerate() { - let replacement_str = format!("({}[co + {}] as {})", name, idx, ty); - let operation = operation.to_string().replace("{}", &replacement_str[..]); - op_strings = op_strings + &format!("let b{} = ({}) as {};\n", idx, operation, ty)[..]; - } - op_strings = op_strings + &format!("\n{}\n", build_return(operations.len()))[..]; - - op_strings - }; - - op_strings -} - -#[test] -fn test_generate_accessor_op_str() { - { - let ops = operations(0, 24).unwrap(); - let result = generate_accessor_op_str("test", "u24be", &ops); - let expected = "let b0 = ((test[co + 0] as u24be) << 16) as u24be;\n\ - let b1 = ((test[co + 1] as u24be) << 8) as u24be;\n\ - let b2 = ((test[co + 2] as u24be)) as u24be;\n\n\ - b0 | b1 | b2\n"; - - assert_eq!(result, expected); - } - - { - let ops = operations(0, 16).unwrap(); - let result = generate_accessor_op_str("test", "u16be", &ops); - let expected = "let b0 = ((test[co + 0] as u16be) << 8) as u16be;\n\ - let b1 = ((test[co + 1] as u16be)) as u16be;\n\n\ - b0 | b1\n"; - assert_eq!(result, expected); - } - - { - let ops = operations(0, 8).unwrap(); - let result = generate_accessor_op_str("test", "u8", &ops); - let expected = "(test[co] as u8)"; - assert_eq!(result, expected); - } -} - -/// Given the name of a field, and a set of operations required to get the value of that field, -/// return the Rust code required to get the field. -#[inline] -fn generate_accessor_str( - name: &str, - ty: &str, - offset: &str, - operations: &[GetOperation], - inner: Option<&str>, -) -> String { - let op_strings = generate_accessor_op_str("_self.packet", ty, operations); - - let accessor = if let Some(struct_name) = inner { - format!( - "#[inline(always)] - #[allow(trivial_numeric_casts, unused_parens)] - #[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))] - fn get_{name}(_self: &{struct_name}) -> {ty} {{ - let co = {co}; - {operations} - }}", - struct_name = struct_name, - name = name, - ty = ty, - co = offset, - operations = op_strings - ) - } else { - let comment = generate_accessor_or_mutator_comment(name, ty, AccessorMutator::Accessor); - format!( - "{comment} - #[inline] - #[allow(trivial_numeric_casts, unused_parens)] - #[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))] - pub fn get_{name}(&self) -> {ty} {{ - let _self = self; - let co = {co}; - {operations} - }}", - comment = comment, - name = name, - ty = ty, - co = offset, - operations = op_strings - ) - }; - - accessor -} - -#[inline] -fn generate_accessor_with_offset_str( - name: &str, - ty: &str, - offset: &str, - operations: &[GetOperation], - inner: &str, -) -> String { - let op_strings = generate_accessor_op_str("_self.packet", ty, operations); - - format!( - "#[inline(always)] - #[allow(trivial_numeric_casts, unused_parens)] - #[cfg_attr(feature = \"clippy\", allow(used_underscore_binding))] - fn get_{name}(_self: &{struct_name}, offset: usize) -> {ty} {{ - let co = {co} + offset; - {operations} - }}", - struct_name = inner, - name = name, - ty = ty, - co = offset, - operations = op_strings - ) -} - -#[inline] -fn current_offset(bit_offset: usize, offset_fns: &[String]) -> String { - let base_offset = bit_offset / 8; - - offset_fns - .iter() - .fold(base_offset.to_string(), |a, b| a + " + " + &b[..]) -} - -#[inline] -fn generate_get_fields(packet: &Packet) -> String { - let mut gets = String::new(); - - for field in &packet.fields { - if field.is_payload { - gets = gets - + &format!( - "{field} : {{ - let payload = self.payload(); - let mut vec = Vec::with_capacity(payload.len()); - vec.extend_from_slice(payload); - - vec - }},\n", - field = field.name - )[..] - } else { - gets = gets + &format!("{field} : _self.get_{field}(),\n", field = field.name)[..] - } - } - - gets -} diff --git a/nex-macro/src/lib.rs b/nex-macro/src/lib.rs deleted file mode 100644 index f918867..0000000 --- a/nex-macro/src/lib.rs +++ /dev/null @@ -1,46 +0,0 @@ -#![deny(warnings)] - -use proc_macro::TokenStream; -use quote::quote; -use syn::{parse_macro_input, DeriveInput, Visibility}; - -mod decorator; -mod util; - -/// The entry point for the `derive(Packet)` custom derive -#[proc_macro_derive(Packet, attributes(construct_with, length, length_fn, payload))] -pub fn derive_packet(input: TokenStream) -> TokenStream { - let ast = parse_macro_input!(input as DeriveInput); - // ensure struct is public - match ast.vis { - Visibility::Public(_) => (), - _ => { - let ts = syn::Error::new(ast.ident.span(), "#[packet] structs must be public") - .to_compile_error(); - return ts.into(); - } - } - let name = &ast.ident; - let s = match &ast.data { - syn::Data::Struct(ref s) => decorator::generate_packet(s, name.to_string()), - _ => panic!("Only structs are supported"), - }; - match s { - Ok(ts) => ts.into(), - Err(e) => e.to_compile_error().into(), - } -} - -/// The entry point for the `packet` proc_macro_attribute -#[proc_macro_attribute] -pub fn packet(_attrs: TokenStream, code: TokenStream) -> TokenStream { - // let _attrs = parse_macro_input!(attrs as AttributeArgs); - let input = parse_macro_input!(code as DeriveInput); - // enhancement: if input already has Clone and/or Debug, do not add them - let s = quote! { - #[derive(::nex_macro::Packet, Clone, Debug)] - #[allow(unused_attributes)] - #input - }; - s.into() -} diff --git a/nex-macro/src/util.rs b/nex-macro/src/util.rs deleted file mode 100644 index 1200df0..0000000 --- a/nex-macro/src/util.rs +++ /dev/null @@ -1,1062 +0,0 @@ -//! Utility functions for bit manipulation operations - -use core::fmt; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum Endianness { - Big, - Little, - Host, -} - -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub struct GetOperation { - mask: u8, - shiftl: u8, - shiftr: u8, -} - -impl fmt::Display for GetOperation { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - let should_mask = self.mask != 0xFF; - let shift = (self.shiftr as i16) - (self.shiftl as i16); - - let mask_str = if should_mask { - format!("({{}} & 0x{})", radix16_u8(self.mask)) - } else { - "{}".to_owned() - }; - - if shift == 0 { - write!(fmt, "{}", mask_str) - } else if shift < 0 { - write!(fmt, "{} << {}", mask_str, shift.abs()) - } else { - write!(fmt, "{} >> {}", mask_str, shift.abs()) - } - } -} - -#[test] -fn test_display_get_operation() { - type Op = GetOperation; - - assert_eq!( - Op { - mask: 0b00001111, - shiftl: 2, - shiftr: 0, - } - .to_string(), - "({} & 0xf) << 2" - ); - assert_eq!( - Op { - mask: 0b00001111, - shiftl: 2, - shiftr: 2, - } - .to_string(), - "({} & 0xf)" - ); - assert_eq!( - Op { - mask: 0b00001111, - shiftl: 0, - shiftr: 2, - } - .to_string(), - "({} & 0xf) >> 2" - ); - assert_eq!( - Op { - mask: 0b11111111, - shiftl: 0, - shiftr: 2, - } - .to_string(), - "{} >> 2" - ); - assert_eq!( - Op { - mask: 0b11111111, - shiftl: 3, - shiftr: 1, - } - .to_string(), - "{} << 2" - ); -} - -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub struct SetOperation { - /// Bits to save from old byte - save_mask: u8, - /// Bits to mask out of value we're setting - value_mask: u64, - /// Number of places to left shift the value we're setting - shiftl: u8, - /// Number of places to right shift the value we're setting - shiftr: u8, -} - -macro_rules! radix_fn { - ($name:ident, $ty:ty) => { - fn $name(mut val: $ty) -> String { - let mut ret = String::new(); - let vals = "0123456789abcdef".as_bytes(); - while val > 0 { - let remainder = val % 16; - val /= 16; - ret = format!("{}{}", vals[remainder as usize] as char, ret); - } - - ret - } - - mod $name { - #[test] - fn test() { - assert_eq!(super::$name(0xab), "ab".to_owned()); - assert_eq!(super::$name(0x1c), "1c".to_owned()); - } - } - }; -} - -radix_fn!(radix16_u8, u8); -radix_fn!(radix16_u64, u64); - -impl fmt::Display for SetOperation { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - let should_mask = self.value_mask != 0xFF; - let should_save = self.save_mask != 0x00; - let shift = (self.shiftr as i16) - (self.shiftl as i16); - - let save_str = if should_save { - format!("({{packet}} & 0x{})", radix16_u8(self.save_mask)) - } else { - "".to_owned() - }; - - let mask_str = if should_mask { - format!("({{val}} & 0x{})", radix16_u64(self.value_mask)) - } else { - "{val}".to_owned() - }; - - let shift_str = if shift == 0 { - format!("{}", mask_str) - } else if shift < 0 { - format!("{} << {}", mask_str, shift.abs()) - } else { - format!("{} >> {}", mask_str, shift.abs()) - }; - - if should_save { - write!( - fmt, - "{{packet}} = ({} | ({}) as u8) as u8", - save_str, shift_str - ) - } else { - write!(fmt, "{{packet}} = ({}) as u8", shift_str) - } - } -} - -#[test] -fn test_display_set_operation() { - type Sop = SetOperation; - - assert_eq!( - Sop { - save_mask: 0b00000011, - value_mask: 0b00001111, - shiftl: 2, - shiftr: 0, - } - .to_string(), - "{packet} = (({packet} & 0x3) | (({val} & 0xf) << 2) as u8) as u8" - ); - assert_eq!( - Sop { - save_mask: 0b11000000, - value_mask: 0b00001111, - shiftl: 2, - shiftr: 2, - } - .to_string(), - "{packet} = (({packet} & 0xc0) | (({val} & 0xf)) as u8) as u8" - ); - assert_eq!( - Sop { - save_mask: 0b00011100, - value_mask: 0b00001111, - shiftl: 0, - shiftr: 2, - } - .to_string(), - "{packet} = (({packet} & 0x1c) | (({val} & 0xf) >> 2) as u8) as u8" - ); - assert_eq!( - Sop { - save_mask: 0b00000000, - value_mask: 0b11111111, - shiftl: 0, - shiftr: 2, - } - .to_string(), - "{packet} = ({val} >> 2) as u8" - ); - assert_eq!( - Sop { - save_mask: 0b00000011, - value_mask: 0b11111111, - shiftl: 3, - shiftr: 1, - } - .to_string(), - "{packet} = (({packet} & 0x3) | ({val} << 2) as u8) as u8" - ); -} - -/// Gets a mask to get bits_remaining bits from offset bits into a byte -/// If bits_remaining is > 8, it will be truncated as necessary -fn get_mask(offset: usize, bits_remaining: usize) -> (usize, u8) { - fn bits_remaining_in_byte(offset: usize, bits_remaining: usize) -> usize { - fn round_down(max_val: usize, val: usize) -> usize { - if val > max_val { - max_val - } else { - val - } - } - if (bits_remaining / 8) >= 1 { - 8 - offset - } else { - round_down(8 - offset, bits_remaining) - } - } - assert!(offset <= 7); - let mut num_bits_to_mask = bits_remaining_in_byte(offset, bits_remaining); - assert!(num_bits_to_mask <= 8 - offset); - let mut mask = 0; - while num_bits_to_mask > 0 { - mask = mask | (0x80 >> (offset + num_bits_to_mask - 1)); - num_bits_to_mask -= 1; - } - - (bits_remaining_in_byte(offset, bits_remaining), mask) -} - -#[test] -fn test_get_mask() { - assert_eq!(get_mask(0, 1), (1, 0b10000000)); - assert_eq!(get_mask(0, 2), (2, 0b11000000)); - assert_eq!(get_mask(0, 3), (3, 0b11100000)); - assert_eq!(get_mask(0, 4), (4, 0b11110000)); - assert_eq!(get_mask(0, 5), (5, 0b11111000)); - assert_eq!(get_mask(0, 6), (6, 0b11111100)); - assert_eq!(get_mask(0, 7), (7, 0b11111110)); - assert_eq!(get_mask(0, 8), (8, 0b11111111)); - assert_eq!(get_mask(0, 9), (8, 0b11111111)); - assert_eq!(get_mask(0, 100), (8, 0b11111111)); - - assert_eq!(get_mask(1, 1), (1, 0b01000000)); - assert_eq!(get_mask(1, 2), (2, 0b01100000)); - assert_eq!(get_mask(1, 3), (3, 0b01110000)); - assert_eq!(get_mask(1, 4), (4, 0b01111000)); - assert_eq!(get_mask(1, 5), (5, 0b01111100)); - assert_eq!(get_mask(1, 6), (6, 0b01111110)); - assert_eq!(get_mask(1, 7), (7, 0b01111111)); - assert_eq!(get_mask(1, 8), (7, 0b01111111)); - assert_eq!(get_mask(1, 9), (7, 0b01111111)); - assert_eq!(get_mask(1, 100), (7, 0b01111111)); - - assert_eq!(get_mask(5, 1), (1, 0b00000100)); - assert_eq!(get_mask(5, 2), (2, 0b00000110)); - assert_eq!(get_mask(5, 3), (3, 0b00000111)); - assert_eq!(get_mask(5, 4), (3, 0b00000111)); - assert_eq!(get_mask(5, 5), (3, 0b00000111)); - assert_eq!(get_mask(5, 6), (3, 0b00000111)); - assert_eq!(get_mask(5, 7), (3, 0b00000111)); - assert_eq!(get_mask(5, 8), (3, 0b00000111)); - assert_eq!(get_mask(5, 100), (3, 0b00000111)); -} - -fn get_shiftl(offset: usize, size: usize, byte_number: usize, num_bytes: usize) -> u8 { - if num_bytes == 1 || byte_number + 1 == num_bytes { - 0 - } else { - let base_shift = 8 - ((num_bytes * 8) - offset - size); - let bytes_to_shift = num_bytes - byte_number - 2; - - let ret = base_shift + (8 * bytes_to_shift); - - // (ret % 8) as u8 - ret as u8 - } -} - -#[test] -fn test_get_shiftl() { - assert_eq!(get_shiftl(0, 8, 0, 1), 0); - assert_eq!(get_shiftl(0, 9, 0, 2), 1); - assert_eq!(get_shiftl(0, 9, 1, 2), 0); - assert_eq!(get_shiftl(0, 10, 0, 2), 2); - assert_eq!(get_shiftl(0, 10, 1, 2), 0); - assert_eq!(get_shiftl(0, 11, 0, 2), 3); - assert_eq!(get_shiftl(0, 11, 1, 2), 0); - - assert_eq!(get_shiftl(1, 7, 0, 1), 0); - assert_eq!(get_shiftl(1, 8, 0, 2), 1); - assert_eq!(get_shiftl(1, 9, 0, 2), 2); - assert_eq!(get_shiftl(1, 9, 1, 2), 0); - assert_eq!(get_shiftl(1, 10, 0, 2), 3); - assert_eq!(get_shiftl(1, 10, 1, 2), 0); - assert_eq!(get_shiftl(1, 11, 0, 2), 4); - assert_eq!(get_shiftl(1, 11, 1, 2), 0); - - assert_eq!(get_shiftl(0, 35, 0, 5), 27); - assert_eq!(get_shiftl(0, 35, 1, 5), 19); - assert_eq!(get_shiftl(0, 35, 2, 5), 11); - assert_eq!(get_shiftl(0, 35, 3, 5), 3); - assert_eq!(get_shiftl(0, 35, 4, 5), 0); -} - -fn get_shiftr(offset: usize, size: usize, byte_number: usize, num_bytes: usize) -> u8 { - if byte_number + 1 == num_bytes { - ((num_bytes * 8) - offset - size) as u8 - } else { - 0 - } -} - -#[test] -fn test_get_shiftr() { - assert_eq!(get_shiftr(0, 1, 0, 1), 7); - assert_eq!(get_shiftr(0, 2, 0, 1), 6); - assert_eq!(get_shiftr(0, 3, 0, 1), 5); - assert_eq!(get_shiftr(0, 4, 0, 1), 4); - assert_eq!(get_shiftr(0, 5, 0, 1), 3); - assert_eq!(get_shiftr(0, 6, 0, 1), 2); - assert_eq!(get_shiftr(0, 7, 0, 1), 1); - assert_eq!(get_shiftr(0, 8, 0, 1), 0); - assert_eq!(get_shiftr(0, 9, 0, 2), 0); - assert_eq!(get_shiftr(0, 9, 1, 2), 7); - - assert_eq!(get_shiftr(1, 7, 0, 1), 0); - assert_eq!(get_shiftr(1, 8, 0, 2), 0); - assert_eq!(get_shiftr(1, 8, 1, 2), 7); - assert_eq!(get_shiftr(1, 9, 0, 2), 0); - assert_eq!(get_shiftr(1, 9, 1, 2), 6); - assert_eq!(get_shiftr(1, 10, 0, 2), 0); - assert_eq!(get_shiftr(1, 10, 1, 2), 5); - assert_eq!(get_shiftr(1, 11, 0, 2), 0); - assert_eq!(get_shiftr(1, 11, 1, 2), 4); - - assert_eq!(get_shiftr(0, 35, 0, 5), 0); - assert_eq!(get_shiftr(0, 35, 1, 5), 0); - assert_eq!(get_shiftr(0, 35, 2, 5), 0); - assert_eq!(get_shiftr(0, 35, 3, 5), 0); - assert_eq!(get_shiftr(0, 35, 4, 5), 5); -} - -/// Given an offset (number of bits into a chunk of memory), retrieve a list of operations to get -/// size bits. -/// -/// Assumes big endian, and that each byte will be masked, then cast to the next power of two -/// greater than or equal to size bits before shifting. offset should be in the range [0, 7] -pub fn operations(offset: usize, size: usize) -> Option> { - if offset > 7 || size == 0 || size > 64 { - return None; - } - - let start = offset / 8; - let end = (offset + size - 1) / 8; - let num_bytes = (end - start) + 1; - - let mut current_offset = offset; - let mut num_bits_remaining = size; - let mut ops = Vec::with_capacity(num_bytes); - for i in 0..num_bytes { - let (consumed, mask) = get_mask(current_offset, num_bits_remaining); - ops.push(GetOperation { - mask: mask, - shiftl: get_shiftl(offset, size, i, num_bytes), - shiftr: get_shiftr(offset, size, i, num_bytes), - }); - current_offset = 0; - if num_bits_remaining >= consumed { - num_bits_remaining -= consumed; - } - } - - Some(ops) -} - -#[test] -fn operations_test() { - type Op = GetOperation; - assert_eq!( - operations(0, 1).unwrap(), - vec![Op { - mask: 0b10000000, - shiftl: 0, - shiftr: 7, - }] - ); - assert_eq!( - operations(0, 2).unwrap(), - vec![Op { - mask: 0b11000000, - shiftl: 0, - shiftr: 6, - }] - ); - assert_eq!( - operations(0, 3).unwrap(), - vec![Op { - mask: 0b11100000, - shiftl: 0, - shiftr: 5, - }] - ); - assert_eq!( - operations(0, 4).unwrap(), - vec![Op { - mask: 0b11110000, - shiftl: 0, - shiftr: 4, - }] - ); - assert_eq!( - operations(0, 5).unwrap(), - vec![Op { - mask: 0b11111000, - shiftl: 0, - shiftr: 3, - }] - ); - assert_eq!( - operations(0, 6).unwrap(), - vec![Op { - mask: 0b11111100, - shiftl: 0, - shiftr: 2, - }] - ); - assert_eq!( - operations(0, 7).unwrap(), - vec![Op { - mask: 0b11111110, - shiftl: 0, - shiftr: 1, - }] - ); - assert_eq!( - operations(0, 8).unwrap(), - vec![Op { - mask: 0b11111111, - shiftl: 0, - shiftr: 0, - }] - ); - assert_eq!( - operations(0, 9).unwrap(), - vec![ - Op { - mask: 0b11111111, - shiftl: 1, - shiftr: 0, - }, - Op { - mask: 0b10000000, - shiftl: 0, - shiftr: 7, - } - ] - ); - assert_eq!( - operations(0, 10).unwrap(), - vec![ - Op { - mask: 0b11111111, - shiftl: 2, - shiftr: 0, - }, - Op { - mask: 0b11000000, - shiftl: 0, - shiftr: 6, - } - ] - ); - - assert_eq!( - operations(1, 1).unwrap(), - vec![Op { - mask: 0b01000000, - shiftl: 0, - shiftr: 6, - }] - ); - assert_eq!( - operations(1, 2).unwrap(), - vec![Op { - mask: 0b01100000, - shiftl: 0, - shiftr: 5, - }] - ); - assert_eq!( - operations(1, 3).unwrap(), - vec![Op { - mask: 0b01110000, - shiftl: 0, - shiftr: 4, - }] - ); - assert_eq!( - operations(1, 4).unwrap(), - vec![Op { - mask: 0b01111000, - shiftl: 0, - shiftr: 3, - }] - ); - assert_eq!( - operations(1, 5).unwrap(), - vec![Op { - mask: 0b01111100, - shiftl: 0, - shiftr: 2, - }] - ); - assert_eq!( - operations(1, 6).unwrap(), - vec![Op { - mask: 0b01111110, - shiftl: 0, - shiftr: 1, - }] - ); - assert_eq!( - operations(1, 7).unwrap(), - vec![Op { - mask: 0b01111111, - shiftl: 0, - shiftr: 0, - }] - ); - assert_eq!( - operations(1, 8).unwrap(), - vec![ - Op { - mask: 0b01111111, - shiftl: 1, - shiftr: 0, - }, - Op { - mask: 0b10000000, - shiftl: 0, - shiftr: 7, - } - ] - ); - assert_eq!( - operations(1, 9).unwrap(), - vec![ - Op { - mask: 0b01111111, - shiftl: 2, - shiftr: 0, - }, - Op { - mask: 0b11000000, - shiftl: 0, - shiftr: 6, - } - ] - ); - - assert_eq!(operations(8, 1), None); - assert_eq!(operations(3, 0), None); - assert_eq!(operations(3, 65), None); - - assert_eq!( - operations(3, 33).unwrap(), - vec![ - Op { - mask: 0b00011111, - shiftl: 28, - shiftr: 0, - }, - Op { - mask: 0b11111111, - shiftl: 20, - shiftr: 0, - }, - Op { - mask: 0b11111111, - shiftl: 12, - shiftr: 0, - }, - Op { - mask: 0b11111111, - shiftl: 4, - shiftr: 0, - }, - Op { - mask: 0b11110000, - shiftl: 0, - shiftr: 4, - } - ] - ); - - assert_eq!( - operations(6, 6).unwrap(), - vec![ - Op { - mask: 3, - shiftl: 4, - shiftr: 0, - }, - Op { - mask: 240, - shiftl: 0, - shiftr: 4, - } - ] - ); -} - -/// Mask `bits` bits of a byte. eg. mask_high_bits(2) == 0b00000011 -fn mask_high_bits(mut bits: u64) -> u64 { - let mut mask = 0; - while bits > 0 { - mask = mask | (1 << (bits - 1)); - bits -= 1; - } - - mask -} - -/// Converts a set of operations which would get a field, to a set of operations which would set -/// the field -/// -/// In the form of (bits to get, bits to set) -pub fn to_mutator(ops: &[GetOperation]) -> Vec { - fn num_bits_set(n: u8) -> u64 { - let mut count = 0; - for i in 0..8 { - if n & (1 << i) > 0 { - count += 1; - } - } - - count - } - - let mut sops = Vec::with_capacity(ops.len()); - for op in ops { - sops.push(SetOperation { - save_mask: !op.mask, - value_mask: mask_high_bits(num_bits_set(op.mask)) << op.shiftl, - shiftl: op.shiftr, - shiftr: op.shiftl, - }); - } - - sops -} - -#[test] -fn test_to_mutator() { - type Op = GetOperation; - type Sop = SetOperation; - - assert_eq!( - to_mutator(&[Op { - mask: 0b10000000, - shiftl: 0, - shiftr: 7, - }]), - vec![Sop { - save_mask: 0b01111111, - value_mask: 0b00000001, - shiftl: 7, - shiftr: 0, - }] - ); - assert_eq!( - to_mutator(&[Op { - mask: 0b11000000, - shiftl: 0, - shiftr: 6, - }]), - vec![Sop { - save_mask: 0b00111111, - value_mask: 0b00000011, - shiftl: 6, - shiftr: 0, - }] - ); - assert_eq!( - to_mutator(&[Op { - mask: 0b11100000, - shiftl: 0, - shiftr: 5, - }]), - vec![Sop { - save_mask: 0b00011111, - value_mask: 0b00000111, - shiftl: 5, - shiftr: 0, - }] - ); - assert_eq!( - to_mutator(&[Op { - mask: 0b11110000, - shiftl: 0, - shiftr: 4, - }]), - vec![Sop { - save_mask: 0b00001111, - value_mask: 0b00001111, - shiftl: 4, - shiftr: 0, - }] - ); - assert_eq!( - to_mutator(&[Op { - mask: 0b11111000, - shiftl: 0, - shiftr: 3, - }]), - vec![Sop { - save_mask: 0b00000111, - value_mask: 0b00011111, - shiftl: 3, - shiftr: 0, - }] - ); - assert_eq!( - to_mutator(&[Op { - mask: 0b11111100, - shiftl: 0, - shiftr: 2, - }]), - vec![Sop { - save_mask: 0b00000011, - value_mask: 0b00111111, - shiftl: 2, - shiftr: 0, - }] - ); - assert_eq!( - to_mutator(&[Op { - mask: 0b11111110, - shiftl: 0, - shiftr: 1, - }]), - vec![Sop { - save_mask: 0b00000001, - value_mask: 0b01111111, - shiftl: 1, - shiftr: 0, - }] - ); - assert_eq!( - to_mutator(&[Op { - mask: 0b11111111, - shiftl: 0, - shiftr: 0, - }]), - vec![Sop { - save_mask: 0b00000000, - value_mask: 0b11111111, - shiftl: 0, - shiftr: 0, - }] - ); - assert_eq!( - to_mutator(&[ - Op { - mask: 0b11111111, - shiftl: 1, - shiftr: 0, - }, - Op { - mask: 0b10000000, - shiftl: 0, - shiftr: 7, - } - ]), - vec![ - Sop { - save_mask: 0b00000000, - value_mask: 0b111111110, - shiftl: 0, - shiftr: 1, - }, - Sop { - save_mask: 0b01111111, - value_mask: 0b00000001, - shiftl: 7, - shiftr: 0, - } - ] - ); - - assert_eq!( - to_mutator(&[ - Op { - mask: 0b11111111, - shiftl: 2, - shiftr: 0, - }, - Op { - mask: 0b11000000, - shiftl: 0, - shiftr: 6, - } - ]), - vec![ - Sop { - save_mask: 0b00000000, - value_mask: 0b1111111100, - shiftl: 0, - shiftr: 2, - }, - Sop { - save_mask: 0b00111111, - value_mask: 0b00000011, - shiftl: 6, - shiftr: 0, - } - ] - ); - - assert_eq!( - to_mutator(&[Op { - mask: 0b01000000, - shiftl: 0, - shiftr: 6, - }]), - vec![Sop { - save_mask: 0b10111111, - value_mask: 0b00000001, - shiftl: 6, - shiftr: 0, - }] - ); - assert_eq!( - to_mutator(&[Op { - mask: 0b01100000, - shiftl: 0, - shiftr: 5, - }]), - vec![Sop { - save_mask: 0b10011111, - value_mask: 0b00000011, - shiftl: 5, - shiftr: 0, - }] - ); - assert_eq!( - to_mutator(&[Op { - mask: 0b01110000, - shiftl: 0, - shiftr: 4, - }]), - vec![Sop { - save_mask: 0b10001111, - value_mask: 0b00000111, - shiftl: 4, - shiftr: 0, - }] - ); - assert_eq!( - to_mutator(&[Op { - mask: 0b01111000, - shiftl: 0, - shiftr: 3, - }]), - vec![Sop { - save_mask: 0b10000111, - value_mask: 0b00001111, - shiftl: 3, - shiftr: 0, - }] - ); - assert_eq!( - to_mutator(&[Op { - mask: 0b01111100, - shiftl: 0, - shiftr: 2, - }]), - vec![Sop { - save_mask: 0b10000011, - value_mask: 0b00011111, - shiftl: 2, - shiftr: 0, - }] - ); - assert_eq!( - to_mutator(&[Op { - mask: 0b01111110, - shiftl: 0, - shiftr: 1, - }]), - vec![Sop { - save_mask: 0b10000001, - value_mask: 0b00111111, - shiftl: 1, - shiftr: 0, - }] - ); - assert_eq!( - to_mutator(&[Op { - mask: 0b01111111, - shiftl: 0, - shiftr: 0, - }]), - vec![Sop { - save_mask: 0b10000000, - value_mask: 0b01111111, - shiftl: 0, - shiftr: 0, - }] - ); - assert_eq!( - to_mutator(&[ - Op { - mask: 0b01111111, - shiftl: 1, - shiftr: 0, - }, - Op { - mask: 0b10000000, - shiftl: 0, - shiftr: 7, - } - ]), - vec![ - Sop { - save_mask: 0b10000000, - value_mask: 0b11111110, - shiftl: 0, - shiftr: 1, - }, - Sop { - save_mask: 0b01111111, - value_mask: 0b00000001, - shiftl: 7, - shiftr: 0, - } - ] - ); - assert_eq!( - to_mutator(&[ - Op { - mask: 0b01111111, - shiftl: 2, - shiftr: 0, - }, - Op { - mask: 0b11000000, - shiftl: 0, - shiftr: 6, - } - ]), - vec![ - Sop { - save_mask: 0b10000000, - value_mask: 0b0111111100, - shiftl: 0, - shiftr: 2, - }, - Sop { - save_mask: 0b00111111, - value_mask: 0b00000011, - shiftl: 6, - shiftr: 0, - } - ] - ); - - assert_eq!( - to_mutator(&[ - Op { - mask: 0b00011111, - shiftl: 28, - shiftr: 0, - }, - Op { - mask: 0b11111111, - shiftl: 20, - shiftr: 0, - }, - Op { - mask: 0b11111111, - shiftl: 12, - shiftr: 0, - }, - Op { - mask: 0b11111111, - shiftl: 4, - shiftr: 0, - }, - Op { - mask: 0b11110000, - shiftl: 0, - shiftr: 4, - } - ]), - vec![ - Sop { - save_mask: 0b11100000, - value_mask: 0x1F0000000, - shiftl: 0, - shiftr: 28, - }, - Sop { - save_mask: 0b00000000, - value_mask: 0x00FF00000, - shiftl: 0, - shiftr: 20, - }, - Sop { - save_mask: 0b00000000, - value_mask: 0x0000FF000, - shiftl: 0, - shiftr: 12, - }, - Sop { - save_mask: 0b00000000, - value_mask: 0x000000FF0, - shiftl: 0, - shiftr: 4, - }, - Sop { - save_mask: 0b00001111, - value_mask: 0x00000000F, - shiftl: 4, - shiftr: 0, - } - ] - ); -} - -/// Takes a set of operations to get a field in big endian, and converts them to get the field in -/// little endian. -pub fn to_little_endian(_ops: Vec) -> Vec { - let mut ops = _ops.clone(); - for (op, be_op) in ops.iter_mut().zip(_ops.iter().rev()) { - op.shiftl = be_op.shiftl; - } - ops -} diff --git a/nex-packet-builder/Cargo.toml b/nex-packet-builder/Cargo.toml deleted file mode 100644 index c528bed..0000000 --- a/nex-packet-builder/Cargo.toml +++ /dev/null @@ -1,16 +0,0 @@ -[package] -name = "nex-packet-builder" -version.workspace = true -edition.workspace = true -authors.workspace = true -description = "Provides high-level packet building on top of nex-packet. Part of nex project. " -repository = "https://github.com/shellrow/nex" -readme = "../README.md" -keywords = ["network", "packet"] -categories = ["network-programming"] -license = "MIT" - -[dependencies] -nex-core = { workspace = true } -nex-packet = { workspace = true } -rand = { workspace = true } diff --git a/nex-packet-builder/src/arp.rs b/nex-packet-builder/src/arp.rs deleted file mode 100644 index 52e1b75..0000000 --- a/nex-packet-builder/src/arp.rs +++ /dev/null @@ -1,67 +0,0 @@ -use nex_core::mac::MacAddr; -use nex_packet::arp::ArpHardwareType; -use nex_packet::arp::ArpOperation; -use nex_packet::arp::MutableArpPacket; -use nex_packet::arp::ARP_HEADER_LEN; -use nex_packet::ethernet::EtherType; -use nex_packet::Packet; -use std::net::Ipv4Addr; - -/// Build ARP packet. -pub(crate) fn build_arp_packet( - arp_packet: &mut MutableArpPacket, - src_mac: MacAddr, - dst_mac: MacAddr, - src_ip: Ipv4Addr, - dst_ip: Ipv4Addr, -) { - arp_packet.set_hardware_type(ArpHardwareType::Ethernet); - arp_packet.set_protocol_type(EtherType::Ipv4); - arp_packet.set_hw_addr_len(6); - arp_packet.set_proto_addr_len(4); - arp_packet.set_operation(ArpOperation::Request); - arp_packet.set_sender_hw_addr(src_mac); - arp_packet.set_sender_proto_addr(src_ip); - arp_packet.set_target_hw_addr(dst_mac); - arp_packet.set_target_proto_addr(dst_ip); -} - -/// ARP Packet Builder. -#[derive(Clone, Debug)] -pub struct ArpPacketBuilder { - /// Source MAC address. - pub src_mac: MacAddr, - /// Destination MAC address. - pub dst_mac: MacAddr, - /// Source IPv4 address. - pub src_ip: Ipv4Addr, - /// Destination IPv4 address. - pub dst_ip: Ipv4Addr, -} - -impl ArpPacketBuilder { - /// Constructs a new ArpPacketBuilder. - pub fn new() -> ArpPacketBuilder { - ArpPacketBuilder { - src_mac: MacAddr::zero(), - dst_mac: MacAddr::broadcast(), - src_ip: Ipv4Addr::UNSPECIFIED, - dst_ip: Ipv4Addr::UNSPECIFIED, - } - } - /// Builds ARP packet and return bytes. - pub fn build(&self) -> Vec { - let mut buffer = [0u8; ARP_HEADER_LEN]; - let mut arp_packet = MutableArpPacket::new(&mut buffer).unwrap(); - arp_packet.set_hardware_type(ArpHardwareType::Ethernet); - arp_packet.set_protocol_type(EtherType::Ipv4); - arp_packet.set_hw_addr_len(6); - arp_packet.set_proto_addr_len(4); - arp_packet.set_operation(ArpOperation::Request); - arp_packet.set_sender_hw_addr(self.src_mac); - arp_packet.set_sender_proto_addr(self.src_ip); - arp_packet.set_target_hw_addr(self.dst_mac); - arp_packet.set_target_proto_addr(self.dst_ip); - arp_packet.packet().to_vec() - } -} diff --git a/nex-packet-builder/src/builder.rs b/nex-packet-builder/src/builder.rs deleted file mode 100644 index 8792ee9..0000000 --- a/nex-packet-builder/src/builder.rs +++ /dev/null @@ -1,174 +0,0 @@ -use nex_packet::ethernet::ETHERNET_HEADER_LEN; -use nex_packet::ipv4::IPV4_HEADER_LEN; -use nex_packet::ipv6::IPV6_HEADER_LEN; -use nex_packet::udp::UDP_HEADER_LEN; - -use crate::arp::ArpPacketBuilder; -use crate::dhcp::DhcpPacketBuilder; -use crate::ethernet::EthernetPacketBuilder; -use crate::icmp::IcmpPacketBuilder; -use crate::icmpv6::Icmpv6PacketBuilder; -use crate::ipv4::Ipv4PacketBuilder; -use crate::ipv6::Ipv6PacketBuilder; -use crate::ndp::NdpPacketBuilder; -use crate::tcp::TcpPacketBuilder; -use crate::udp::UdpPacketBuilder; - -/// Packet builder for building full packet. -#[derive(Clone, Debug)] -pub struct PacketBuilder { - packet: Vec, -} - -impl PacketBuilder { - /// Constructs a new PacketBuilder. - pub fn new() -> Self { - PacketBuilder { packet: Vec::new() } - } - /// Return packet bytes. - pub fn packet(&self) -> Vec { - self.packet.clone() - } - /// Retern IP packet bytes (without ethernet header). - pub fn ip_packet(&self) -> Vec { - if self.packet.len() < ETHERNET_HEADER_LEN { - return Vec::new(); - } - self.packet[ETHERNET_HEADER_LEN..].to_vec() - } - /// Set ethernet header. - pub fn set_ethernet(&mut self, packet_builder: EthernetPacketBuilder) { - if self.packet.len() < ETHERNET_HEADER_LEN { - self.packet.resize(ETHERNET_HEADER_LEN, 0); - } - self.packet[0..ETHERNET_HEADER_LEN].copy_from_slice(&packet_builder.build()); - } - /// Set arp header. - pub fn set_arp(&mut self, packet_builder: ArpPacketBuilder) { - let arp_packet = packet_builder.build(); - if self.packet.len() < ETHERNET_HEADER_LEN + arp_packet.len() { - self.packet - .resize(ETHERNET_HEADER_LEN + arp_packet.len(), 0); - } - self.packet[ETHERNET_HEADER_LEN..ETHERNET_HEADER_LEN + arp_packet.len()] - .copy_from_slice(&arp_packet); - } - /// Set IPv4 header. - pub fn set_ipv4(&mut self, packet_builder: Ipv4PacketBuilder) { - let ipv4_packet = packet_builder.build(); - if self.packet.len() < ETHERNET_HEADER_LEN + ipv4_packet.len() { - self.packet - .resize(ETHERNET_HEADER_LEN + ipv4_packet.len(), 0); - } - self.packet[ETHERNET_HEADER_LEN..ETHERNET_HEADER_LEN + ipv4_packet.len()] - .copy_from_slice(&ipv4_packet); - } - /// Set IPv6 header. - pub fn set_ipv6(&mut self, packet_builder: Ipv6PacketBuilder) { - let ipv6_packet = packet_builder.build(); - if self.packet.len() < ETHERNET_HEADER_LEN + ipv6_packet.len() { - self.packet - .resize(ETHERNET_HEADER_LEN + ipv6_packet.len(), 0); - } - self.packet[ETHERNET_HEADER_LEN..ETHERNET_HEADER_LEN + ipv6_packet.len()] - .copy_from_slice(&ipv6_packet); - } - /// Set ICMP header. - pub fn set_icmp(&mut self, packet_builder: IcmpPacketBuilder) { - let icmp_packet = packet_builder.build(); - if self.packet.len() < ETHERNET_HEADER_LEN + IPV4_HEADER_LEN + icmp_packet.len() { - self.packet - .resize(ETHERNET_HEADER_LEN + IPV4_HEADER_LEN + icmp_packet.len(), 0); - } - self.packet[ETHERNET_HEADER_LEN + IPV4_HEADER_LEN - ..ETHERNET_HEADER_LEN + IPV4_HEADER_LEN + icmp_packet.len()] - .copy_from_slice(&icmp_packet); - } - /// Set ICMPv6 header. - pub fn set_icmpv6(&mut self, packet_builder: Icmpv6PacketBuilder) { - let icmpv6_packet = packet_builder.build(); - if self.packet.len() < ETHERNET_HEADER_LEN + IPV6_HEADER_LEN + icmpv6_packet.len() { - self.packet.resize( - ETHERNET_HEADER_LEN + IPV6_HEADER_LEN + icmpv6_packet.len(), - 0, - ); - } - self.packet[ETHERNET_HEADER_LEN + IPV6_HEADER_LEN - ..ETHERNET_HEADER_LEN + IPV6_HEADER_LEN + icmpv6_packet.len()] - .copy_from_slice(&icmpv6_packet); - } - /// Set NDP header. - pub fn set_ndp(&mut self, packet_builder: NdpPacketBuilder) { - let ndp_packet = packet_builder.build(); - if self.packet.len() < ETHERNET_HEADER_LEN + IPV6_HEADER_LEN + ndp_packet.len() { - self.packet - .resize(ETHERNET_HEADER_LEN + IPV6_HEADER_LEN + ndp_packet.len(), 0); - } - self.packet[ETHERNET_HEADER_LEN + IPV6_HEADER_LEN - ..ETHERNET_HEADER_LEN + IPV6_HEADER_LEN + ndp_packet.len()] - .copy_from_slice(&ndp_packet); - } - /// Set TCP header and payload. - pub fn set_tcp(&mut self, packet_builder: TcpPacketBuilder) { - let tcp_packet = packet_builder.build(); - if packet_builder.dst_ip.is_ipv4() { - if self.packet.len() < ETHERNET_HEADER_LEN + IPV4_HEADER_LEN + tcp_packet.len() { - self.packet - .resize(ETHERNET_HEADER_LEN + IPV4_HEADER_LEN + tcp_packet.len(), 0); - } - self.packet[ETHERNET_HEADER_LEN + IPV4_HEADER_LEN - ..ETHERNET_HEADER_LEN + IPV4_HEADER_LEN + tcp_packet.len()] - .copy_from_slice(&tcp_packet); - } else if packet_builder.dst_ip.is_ipv6() { - if self.packet.len() < ETHERNET_HEADER_LEN + IPV6_HEADER_LEN + tcp_packet.len() { - self.packet - .resize(ETHERNET_HEADER_LEN + IPV6_HEADER_LEN + tcp_packet.len(), 0); - } - self.packet[ETHERNET_HEADER_LEN + IPV6_HEADER_LEN - ..ETHERNET_HEADER_LEN + IPV6_HEADER_LEN + tcp_packet.len()] - .copy_from_slice(&tcp_packet); - } - } - /// Set UDP header and payload. - pub fn set_udp(&mut self, packet_builder: UdpPacketBuilder) { - let udp_packet = packet_builder.build(); - if packet_builder.dst_ip.is_ipv4() { - if self.packet.len() < ETHERNET_HEADER_LEN + IPV4_HEADER_LEN + udp_packet.len() { - self.packet - .resize(ETHERNET_HEADER_LEN + IPV4_HEADER_LEN + udp_packet.len(), 0); - } - self.packet[ETHERNET_HEADER_LEN + IPV4_HEADER_LEN - ..ETHERNET_HEADER_LEN + IPV4_HEADER_LEN + udp_packet.len()] - .copy_from_slice(&udp_packet); - } else if packet_builder.dst_ip.is_ipv6() { - if self.packet.len() < ETHERNET_HEADER_LEN + IPV6_HEADER_LEN + udp_packet.len() { - self.packet - .resize(ETHERNET_HEADER_LEN + IPV6_HEADER_LEN + udp_packet.len(), 0); - } - self.packet[ETHERNET_HEADER_LEN + IPV6_HEADER_LEN - ..ETHERNET_HEADER_LEN + IPV6_HEADER_LEN + udp_packet.len()] - .copy_from_slice(&udp_packet); - } - } - /// Set DHCP header and payload. - pub fn set_dhcp(&mut self, packet_builder: DhcpPacketBuilder) { - let dhcp_packet = packet_builder.build(); - - let min_offset_ipv4 = ETHERNET_HEADER_LEN + IPV4_HEADER_LEN + UDP_HEADER_LEN; - let min_offset_ipv6 = ETHERNET_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN; - - if self.packet.len() >= min_offset_ipv4 { - if self.packet.len() < min_offset_ipv4 + dhcp_packet.len() { - self.packet.resize(min_offset_ipv4 + dhcp_packet.len(), 0); - } - self.packet[min_offset_ipv4..min_offset_ipv4 + dhcp_packet.len()] - .copy_from_slice(&dhcp_packet); - } else if self.packet.len() >= min_offset_ipv6 { - if self.packet.len() < min_offset_ipv6 + dhcp_packet.len() { - self.packet.resize(min_offset_ipv6 + dhcp_packet.len(), 0); - } - self.packet[min_offset_ipv6..min_offset_ipv6 + dhcp_packet.len()] - .copy_from_slice(&dhcp_packet); - } - } -} diff --git a/nex-packet-builder/src/dhcp.rs b/nex-packet-builder/src/dhcp.rs deleted file mode 100644 index 1ba0fd2..0000000 --- a/nex-packet-builder/src/dhcp.rs +++ /dev/null @@ -1,163 +0,0 @@ -use nex_core::mac::MacAddr; -use nex_packet::dhcp::DHCP_MIN_PACKET_SIZE; -use nex_packet::dhcp::{DhcpHardwareType, DhcpOperation, MutableDhcpPacket}; -use nex_packet::Packet; -use std::net::Ipv4Addr; - -#[derive(Clone, Debug)] -pub struct DhcpPacketBuilder { - pub operation: DhcpOperation, - pub htype: DhcpHardwareType, - pub hlen: u8, - pub hops: u8, - pub xid: u32, - pub secs: u16, - pub flags: u16, - pub ciaddr: Option, - pub yiaddr: Option, - pub siaddr: Option, - pub giaddr: Option, - pub chaddr: MacAddr, - pub options: Vec, -} - -impl DhcpPacketBuilder { - pub fn new(transaction_id: u32, client_mac: MacAddr) -> Self { - Self { - operation: DhcpOperation::Request, - htype: DhcpHardwareType::Ethernet, - hlen: 6, - hops: 0, - xid: transaction_id, - secs: 0, - flags: 0, - ciaddr: None, - yiaddr: None, - siaddr: None, - giaddr: None, - chaddr: client_mac, - options: Vec::new(), - } - } - - /// Set DHCPDISCOVER options - pub fn set_discover_options(&mut self) { - self.operation = DhcpOperation::Request; - self.options.clear(); - self.options.extend_from_slice(&[ - 53, 1, 1, // DHCP Message Type: DHCPDISCOVER (1) - 55, 2, 1, 3, // Parameter Request List: Subnet Mask (1), Router (3) - 255, // End - ]); - } - - /// Set DHCPDISCOVER options with builder pattern - pub fn with_discover_options(mut self) -> Self { - self.set_discover_options(); - self - } - - /// Set DHCPREQUEST options - pub fn set_request_options(&mut self, requested_ip: Ipv4Addr, server_id: Ipv4Addr) { - self.operation = DhcpOperation::Request; - self.options.clear(); - self.options.extend_from_slice(&[ - 53, - 1, - 3, // DHCP Message Type: DHCPREQUEST (3) - 50, - 4, // Requested IP Address - requested_ip.octets()[0], - requested_ip.octets()[1], - requested_ip.octets()[2], - requested_ip.octets()[3], - 54, - 4, // DHCP Server Identifier - server_id.octets()[0], - server_id.octets()[1], - server_id.octets()[2], - server_id.octets()[3], - 55, - 2, - 1, - 3, // Parameter Request List - 255, // End - ]); - } - - /// Set DHCPREQUEST options with builder pattern - pub fn with_request_options(mut self, requested_ip: Ipv4Addr, server_id: Ipv4Addr) -> Self { - self.set_request_options(requested_ip, server_id); - self - } - - pub fn build(&self) -> Vec { - let mut buffer = vec![0u8; DHCP_MIN_PACKET_SIZE + self.options.len()]; - let mut dhcp_packet = MutableDhcpPacket::new(&mut buffer).unwrap(); - - dhcp_packet.set_op(self.operation); - dhcp_packet.set_htype(self.htype); - dhcp_packet.set_hlen(self.hlen); - dhcp_packet.set_hops(self.hops); - dhcp_packet.set_xid(self.xid); - dhcp_packet.set_secs(self.secs); - dhcp_packet.set_flags(self.flags); - dhcp_packet.set_ciaddr(self.ciaddr.unwrap_or(Ipv4Addr::new(0, 0, 0, 0))); - dhcp_packet.set_yiaddr(self.yiaddr.unwrap_or(Ipv4Addr::new(0, 0, 0, 0))); - dhcp_packet.set_siaddr(self.siaddr.unwrap_or(Ipv4Addr::new(0, 0, 0, 0))); - dhcp_packet.set_giaddr(self.giaddr.unwrap_or(Ipv4Addr::new(0, 0, 0, 0))); - dhcp_packet.set_chaddr(self.chaddr); - - dhcp_packet.set_chaddr_pad(&[0u8; 10]); - dhcp_packet.set_sname(&[0u8; 64]); - dhcp_packet.set_file(&[0u8; 128]); - - dhcp_packet.set_options(&self.options); - - dhcp_packet.packet().to_vec() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use nex_core::mac::MacAddr; - use std::net::Ipv4Addr; - - #[test] - fn test_dhcp_discover_builder() { - let transaction_id = 0x12345678; - let client_mac = MacAddr::new(0x00, 0x11, 0x22, 0x33, 0x44, 0x55); - let builder = DhcpPacketBuilder::new(transaction_id, client_mac).with_discover_options(); - let packet = builder.build(); - - assert!(packet.len() >= DHCP_MIN_PACKET_SIZE); - assert_eq!(packet[0], 1); - assert_eq!( - u32::from_be_bytes([packet[4], packet[5], packet[6], packet[7]]), - transaction_id - ); - assert_eq!(&packet[28..34], &client_mac.octets()); - } - - #[test] - fn test_dhcp_request_builder() { - let transaction_id = 0x87654321; - let client_mac = MacAddr::new(0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff); - let requested_ip = Ipv4Addr::new(192, 168, 1, 100); - let server_id = Ipv4Addr::new(192, 168, 1, 1); - let builder = DhcpPacketBuilder::new(transaction_id, client_mac) - .with_request_options(requested_ip, server_id); - let packet = builder.build(); - - assert!(packet.len() >= DHCP_MIN_PACKET_SIZE); - assert_eq!(packet[0], 1); - assert_eq!( - u32::from_be_bytes([packet[4], packet[5], packet[6], packet[7]]), - transaction_id - ); - assert_eq!(&packet[28..34], &client_mac.octets()); - assert_eq!(packet[DHCP_MIN_PACKET_SIZE], 53); - assert_eq!(packet[DHCP_MIN_PACKET_SIZE + 2], 3); - } -} diff --git a/nex-packet-builder/src/ethernet.rs b/nex-packet-builder/src/ethernet.rs deleted file mode 100644 index 5e203a4..0000000 --- a/nex-packet-builder/src/ethernet.rs +++ /dev/null @@ -1,95 +0,0 @@ -use nex_core::mac::MacAddr; -use nex_packet::ethernet::{EtherType, MutableEthernetPacket, ETHERNET_HEADER_LEN}; -use nex_packet::ipv4::Ipv4Packet; -use nex_packet::Packet; - -/// Build Ethernet packet. -pub(crate) fn build_ethernet_packet( - eth_packet: &mut MutableEthernetPacket, - src_mac: MacAddr, - dst_mac: MacAddr, - ether_type: EtherType, -) { - eth_packet.set_source(src_mac); - eth_packet.set_destination(dst_mac); - match ether_type { - EtherType::Arp => { - eth_packet.set_ethertype(EtherType::Arp); - } - EtherType::Ipv4 => { - eth_packet.set_ethertype(EtherType::Ipv4); - } - EtherType::Ipv6 => { - eth_packet.set_ethertype(EtherType::Ipv6); - } - _ => { - // TODO - } - } -} - -/// Build Ethernet ARP packet. -pub(crate) fn build_ethernet_arp_packet(eth_packet: &mut MutableEthernetPacket, src_mac: MacAddr) { - eth_packet.set_source(src_mac); - eth_packet.set_destination(MacAddr::broadcast()); - eth_packet.set_ethertype(EtherType::Arp); -} - -/// Ethernet Packet Builder. -#[derive(Clone, Debug)] -pub struct EthernetPacketBuilder { - /// Source MAC address. - pub src_mac: MacAddr, - /// Destination MAC address. - pub dst_mac: MacAddr, - /// EtherType. - pub ether_type: EtherType, -} - -impl EthernetPacketBuilder { - /// Constructs a new EthernetPacketBuilder. - pub fn new() -> EthernetPacketBuilder { - EthernetPacketBuilder { - src_mac: MacAddr::zero(), - dst_mac: MacAddr::zero(), - ether_type: EtherType::Ipv4, - } - } - /// Build Ethernet packet and return bytes. - pub fn build(&self) -> Vec { - let mut buffer: Vec = vec![0; ETHERNET_HEADER_LEN]; - let mut eth_packet = MutableEthernetPacket::new(&mut buffer).unwrap(); - build_ethernet_packet( - &mut eth_packet, - self.src_mac.clone(), - self.dst_mac.clone(), - self.ether_type, - ); - eth_packet.to_immutable().packet().to_vec() - } -} - -/// Create Dummy Ethernet Frame. -#[allow(dead_code)] -pub(crate) fn create_dummy_ethernet_frame(payload_offset: usize, packet: &[u8]) -> Vec { - if packet.len() <= payload_offset { - return packet.to_vec(); - } - let buffer_size: usize = packet.len() + ETHERNET_HEADER_LEN - payload_offset; - let mut buffer: Vec = vec![0; buffer_size]; - let src_mac: MacAddr = MacAddr::zero(); - let dst_mac: MacAddr = MacAddr::zero(); - let mut ether_type: EtherType = EtherType::Unknown(0); - let mut eth_packet = MutableEthernetPacket::new(&mut buffer).unwrap(); - if let Some(ip_packet) = Ipv4Packet::new(&packet[payload_offset..]) { - let version = ip_packet.get_version(); - if version == 4 { - ether_type = EtherType::Ipv4; - } else if version == 6 { - ether_type = EtherType::Ipv6; - } - } - build_ethernet_packet(&mut eth_packet, src_mac, dst_mac, ether_type); - eth_packet.set_payload(&packet[payload_offset..]); - eth_packet.to_immutable().packet().to_vec() -} diff --git a/nex-packet-builder/src/icmp.rs b/nex-packet-builder/src/icmp.rs deleted file mode 100644 index 934a522..0000000 --- a/nex-packet-builder/src/icmp.rs +++ /dev/null @@ -1,61 +0,0 @@ -use nex_packet::icmp::echo_request::MutableEchoRequestPacket; -use nex_packet::icmp::IcmpType; -use nex_packet::icmp::ICMPV4_HEADER_LEN; -use nex_packet::Packet; -use std::net::Ipv4Addr; - -/// Build ICMP packet. -pub(crate) fn build_icmp_echo_packet(icmp_packet: &mut MutableEchoRequestPacket) { - icmp_packet.set_icmp_type(IcmpType::EchoRequest); - icmp_packet.set_sequence_number(rand::random::()); - icmp_packet.set_identifier(rand::random::()); - let icmp_check_sum = nex_packet::util::checksum(&icmp_packet.packet(), 1); - icmp_packet.set_checksum(icmp_check_sum); -} - -/// ICMP Packet Builder. -#[derive(Clone, Debug)] -pub struct IcmpPacketBuilder { - /// Source IPv4 address. - pub src_ip: Ipv4Addr, - /// Destination IPv4 address. - pub dst_ip: Ipv4Addr, - /// ICMP type. - pub icmp_type: IcmpType, - /// ICMP sequence number. - pub sequence_number: Option, - /// ICMP identifier. - pub identifier: Option, -} - -impl IcmpPacketBuilder { - /// Constructs a new IcmpPacketBuilder. - pub fn new(src_ip: Ipv4Addr, dst_ip: Ipv4Addr) -> IcmpPacketBuilder { - IcmpPacketBuilder { - src_ip: src_ip, - dst_ip: dst_ip, - icmp_type: IcmpType::EchoRequest, - sequence_number: None, - identifier: None, - } - } - /// Build ICMP packet and return bytes. - pub fn build(&self) -> Vec { - let buffer: &mut [u8] = &mut [0u8; ICMPV4_HEADER_LEN]; - let mut icmp_packet = MutableEchoRequestPacket::new(buffer).unwrap(); - icmp_packet.set_icmp_type(self.icmp_type); - if let Some(sequence_number) = self.sequence_number { - icmp_packet.set_sequence_number(sequence_number); - } else { - icmp_packet.set_sequence_number(rand::random::()); - } - if let Some(identifier) = self.identifier { - icmp_packet.set_identifier(identifier); - } else { - icmp_packet.set_identifier(rand::random::()); - } - let icmp_check_sum = nex_packet::util::checksum(&icmp_packet.packet(), 1); - icmp_packet.set_checksum(icmp_check_sum); - icmp_packet.packet().to_vec() - } -} diff --git a/nex-packet-builder/src/icmpv6.rs b/nex-packet-builder/src/icmpv6.rs deleted file mode 100644 index 91edf59..0000000 --- a/nex-packet-builder/src/icmpv6.rs +++ /dev/null @@ -1,62 +0,0 @@ -use nex_packet::icmpv6::echo_request::MutableEchoRequestPacket; -use nex_packet::icmpv6::Icmpv6Packet; -use nex_packet::icmpv6::Icmpv6Type; -use nex_packet::icmpv6::ICMPV6_HEADER_LEN; -use nex_packet::Packet; -use std::net::Ipv6Addr; - -/// Build ICMPv6 packet. -pub(crate) fn build_icmpv6_echo_packet( - icmp_packet: &mut MutableEchoRequestPacket, - src_ip: Ipv6Addr, - dst_ip: Ipv6Addr, -) { - icmp_packet.set_icmpv6_type(Icmpv6Type::EchoRequest); - icmp_packet.set_identifier(rand::random::()); - icmp_packet.set_sequence_number(rand::random::()); - let icmpv6_packet = Icmpv6Packet::new(icmp_packet.packet()).unwrap(); - let icmpv6_checksum = nex_packet::icmpv6::checksum(&icmpv6_packet, &src_ip, &dst_ip); - //let icmp_check_sum = pnet::packet::util::checksum(&icmp_packet.packet(), 1); - icmp_packet.set_checksum(icmpv6_checksum); -} - -/// ICMPv6 Packet Builder. -#[derive(Clone, Debug)] -pub struct Icmpv6PacketBuilder { - /// Source IPv6 address. - pub src_ip: Ipv6Addr, - /// Destination IPv6 address. - pub dst_ip: Ipv6Addr, - /// ICMPv6 type. - pub icmpv6_type: Icmpv6Type, - /// ICMPv6 sequence number. - pub sequence_number: Option, - /// ICMPv6 identifier. - pub identifier: Option, -} - -impl Icmpv6PacketBuilder { - /// Constructs a new Icmpv6PacketBuilder. - pub fn new(src_ip: Ipv6Addr, dst_ip: Ipv6Addr) -> Icmpv6PacketBuilder { - Icmpv6PacketBuilder { - src_ip, - dst_ip, - icmpv6_type: Icmpv6Type::EchoRequest, - sequence_number: None, - identifier: None, - } - } - /// Build ICMPv6 packet and return bytes. - pub fn build(&self) -> Vec { - let buffer: &mut [u8] = &mut [0u8; ICMPV6_HEADER_LEN]; - let mut icmp_packet = MutableEchoRequestPacket::new(buffer).unwrap(); - icmp_packet.set_icmpv6_type(self.icmpv6_type); - icmp_packet.set_identifier(self.identifier.unwrap_or(rand::random::())); - icmp_packet.set_sequence_number(self.sequence_number.unwrap_or(rand::random::())); - let icmpv6_packet = Icmpv6Packet::new(icmp_packet.packet()).unwrap(); - let icmpv6_checksum = - nex_packet::icmpv6::checksum(&icmpv6_packet, &self.src_ip, &self.dst_ip); - icmp_packet.set_checksum(icmpv6_checksum); - icmp_packet.packet().to_vec() - } -} diff --git a/nex-packet-builder/src/ipv4.rs b/nex-packet-builder/src/ipv4.rs deleted file mode 100644 index 4f08dd1..0000000 --- a/nex-packet-builder/src/ipv4.rs +++ /dev/null @@ -1,119 +0,0 @@ -use nex_packet::ip::IpNextLevelProtocol; -use nex_packet::ipv4::Ipv4Flags; -use nex_packet::ipv4::MutableIpv4Packet; -use nex_packet::ipv4::IPV4_HEADER_LEN; -use nex_packet::ipv4::IPV4_HEADER_LENGTH_BYTE_UNITS; -use nex_packet::Packet; -use std::net::Ipv4Addr; - -/// Build IPv4 packet. -pub(crate) fn build_ipv4_packet( - ipv4_packet: &mut MutableIpv4Packet, - src_ip: Ipv4Addr, - dst_ip: Ipv4Addr, - next_protocol: IpNextLevelProtocol, -) { - ipv4_packet.set_header_length((IPV4_HEADER_LEN / IPV4_HEADER_LENGTH_BYTE_UNITS) as u8); - ipv4_packet.set_source(src_ip); - ipv4_packet.set_destination(dst_ip); - ipv4_packet.set_identification(rand::random::()); - ipv4_packet.set_ttl(64); - ipv4_packet.set_version(4); - ipv4_packet.set_flags(Ipv4Flags::DontFragment); - match next_protocol { - IpNextLevelProtocol::Tcp => { - ipv4_packet.set_total_length(52); - ipv4_packet.set_next_level_protocol(IpNextLevelProtocol::Tcp); - } - IpNextLevelProtocol::Udp => { - ipv4_packet.set_total_length(28); - ipv4_packet.set_next_level_protocol(IpNextLevelProtocol::Udp); - } - IpNextLevelProtocol::Icmp => { - ipv4_packet.set_total_length(28); - ipv4_packet.set_next_level_protocol(IpNextLevelProtocol::Icmp); - } - _ => {} - } - let checksum = nex_packet::ipv4::checksum(&ipv4_packet.to_immutable()); - ipv4_packet.set_checksum(checksum); -} - -/// IPv4 Packet Builder. -#[derive(Clone, Debug)] -pub struct Ipv4PacketBuilder { - pub src_ip: Ipv4Addr, - pub dst_ip: Ipv4Addr, - pub next_protocol: IpNextLevelProtocol, - pub total_length: Option, - pub identification: Option, - pub ttl: Option, - pub flags: Option, -} - -impl Ipv4PacketBuilder { - /// Constructs a new Ipv4PacketBuilder. - pub fn new(src_ip: Ipv4Addr, dst_ip: Ipv4Addr, next_protocol: IpNextLevelProtocol) -> Self { - Ipv4PacketBuilder { - src_ip, - dst_ip, - next_protocol, - total_length: None, - identification: None, - ttl: None, - flags: None, - } - } - /// Builds IPv4 packet and return bytes - pub fn build(&self) -> Vec { - let mut buffer = vec![0; IPV4_HEADER_LEN]; - let mut ipv4_packet = MutableIpv4Packet::new(&mut buffer).unwrap(); - ipv4_packet.set_header_length((IPV4_HEADER_LEN / IPV4_HEADER_LENGTH_BYTE_UNITS) as u8); - ipv4_packet.set_source(self.src_ip); - ipv4_packet.set_destination(self.dst_ip); - ipv4_packet.set_identification(self.identification.unwrap_or(rand::random::())); - ipv4_packet.set_ttl(self.ttl.unwrap_or(64)); - ipv4_packet.set_version(4); - ipv4_packet.set_next_level_protocol(self.next_protocol); - if let Some(flags) = self.flags { - match flags { - Ipv4Flags::DontFragment => { - ipv4_packet.set_flags(Ipv4Flags::DontFragment); - } - Ipv4Flags::MoreFragments => { - ipv4_packet.set_flags(Ipv4Flags::MoreFragments); - } - _ => {} - } - } else { - ipv4_packet.set_flags(Ipv4Flags::DontFragment); - } - match self.next_protocol { - IpNextLevelProtocol::Tcp => { - if let Some(total_length) = self.total_length { - ipv4_packet.set_total_length(total_length); - } else { - ipv4_packet.set_total_length(52); - } - } - IpNextLevelProtocol::Udp => { - if let Some(total_length) = self.total_length { - ipv4_packet.set_total_length(total_length); - } else { - ipv4_packet.set_total_length(28); - } - } - IpNextLevelProtocol::Icmp => { - if let Some(total_length) = self.total_length { - ipv4_packet.set_total_length(total_length); - } else { - ipv4_packet.set_total_length(28); - } - } - _ => {} - } - let checksum = nex_packet::ipv4::checksum(&ipv4_packet.to_immutable()); - ipv4_packet.set_checksum(checksum); - ipv4_packet.packet().to_vec() - } -} diff --git a/nex-packet-builder/src/ipv6.rs b/nex-packet-builder/src/ipv6.rs deleted file mode 100644 index 1cf63e6..0000000 --- a/nex-packet-builder/src/ipv6.rs +++ /dev/null @@ -1,92 +0,0 @@ -use nex_packet::ip::IpNextLevelProtocol; -use nex_packet::ipv6::MutableIpv6Packet; -use nex_packet::ipv6::IPV6_HEADER_LEN; -use nex_packet::Packet; -use std::net::Ipv6Addr; - -pub(crate) fn build_ipv6_packet( - ipv6_packet: &mut MutableIpv6Packet, - src_ip: Ipv6Addr, - dst_ip: Ipv6Addr, - next_protocol: IpNextLevelProtocol, -) { - ipv6_packet.set_source(src_ip); - ipv6_packet.set_destination(dst_ip); - ipv6_packet.set_version(6); - ipv6_packet.set_hop_limit(64); - match next_protocol { - IpNextLevelProtocol::Tcp => { - ipv6_packet.set_next_header(IpNextLevelProtocol::Tcp); - ipv6_packet.set_payload_length(32); - } - IpNextLevelProtocol::Udp => { - ipv6_packet.set_next_header(IpNextLevelProtocol::Udp); - ipv6_packet.set_payload_length(8); - } - IpNextLevelProtocol::Icmpv6 => { - ipv6_packet.set_next_header(IpNextLevelProtocol::Icmpv6); - ipv6_packet.set_payload_length(8); - } - _ => {} - } -} - -/// IPv6 Packet Builder. -#[derive(Clone, Debug)] -pub struct Ipv6PacketBuilder { - /// Source IPv6 address. - pub src_ip: Ipv6Addr, - /// Destination IPv6 address. - pub dst_ip: Ipv6Addr, - /// Next level protocol. - pub next_protocol: IpNextLevelProtocol, - /// Payload Length. - pub payload_length: Option, - /// Hop Limit. - pub hop_limit: Option, -} - -impl Ipv6PacketBuilder { - /// Constructs a new Ipv6PacketBuilder. - pub fn new(src_ip: Ipv6Addr, dst_ip: Ipv6Addr, next_protocol: IpNextLevelProtocol) -> Self { - Ipv6PacketBuilder { - src_ip, - dst_ip, - next_protocol, - payload_length: None, - hop_limit: None, - } - } - /// Buid IPv6 packet and return bytes. - pub fn build(&self) -> Vec { - let mut buffer: Vec = vec![0; IPV6_HEADER_LEN]; - let mut ipv6_packet = MutableIpv6Packet::new(&mut buffer).unwrap(); - ipv6_packet.set_source(self.src_ip); - ipv6_packet.set_destination(self.dst_ip); - ipv6_packet.set_version(6); - if let Some(hop_limit) = self.hop_limit { - ipv6_packet.set_hop_limit(hop_limit); - } else { - ipv6_packet.set_hop_limit(64); - } - match self.next_protocol { - IpNextLevelProtocol::Tcp => { - ipv6_packet.set_next_header(IpNextLevelProtocol::Tcp); - ipv6_packet.set_payload_length(32); - } - IpNextLevelProtocol::Udp => { - ipv6_packet.set_next_header(IpNextLevelProtocol::Udp); - ipv6_packet.set_payload_length(8); - } - IpNextLevelProtocol::Icmpv6 => { - ipv6_packet.set_next_header(IpNextLevelProtocol::Icmpv6); - ipv6_packet.set_payload_length(8); - } - _ => {} - } - if let Some(payload_length) = self.payload_length { - ipv6_packet.set_payload_length(payload_length); - } - ipv6_packet.packet().to_vec() - } -} diff --git a/nex-packet-builder/src/ndp.rs b/nex-packet-builder/src/ndp.rs deleted file mode 100644 index 7750d3a..0000000 --- a/nex-packet-builder/src/ndp.rs +++ /dev/null @@ -1,63 +0,0 @@ -use nex_core::mac::MacAddr; -use nex_packet::ethernet::MAC_ADDR_LEN; -use nex_packet::icmpv6::ndp::{ - MutableNdpOptionPacket, MutableNeighborSolicitPacket, NdpOptionTypes, -}; -use nex_packet::icmpv6::ndp::{NDP_OPT_PACKET_LEN, NDP_SOL_PACKET_LEN}; -use nex_packet::icmpv6::{self, Icmpv6Type, MutableIcmpv6Packet}; -use std::net::Ipv6Addr; -//use nex_packet::Packet; - -/// Length in octets (8bytes) -fn octets_len(len: usize) -> u8 { - // 3 = log2(8) - (len.next_power_of_two() >> 3).try_into().unwrap() -} - -/// NDP Packet Builder. -#[derive(Clone, Debug)] -pub struct NdpPacketBuilder { - /// Source MAC address. - pub src_mac: MacAddr, - /// Destination MAC address. - pub dst_mac: MacAddr, - /// Source IPv6 address. - pub src_ip: Ipv6Addr, - /// Destination IPv6 address. - pub dst_ip: Ipv6Addr, -} - -impl NdpPacketBuilder { - /// Constructs a new NdpPacketBuilder. - pub fn new(src_mac: MacAddr, src_ip: Ipv6Addr, dst_ip: Ipv6Addr) -> NdpPacketBuilder { - NdpPacketBuilder { - src_mac: src_mac, - dst_mac: MacAddr::broadcast(), - src_ip: src_ip, - dst_ip: dst_ip, - } - } - /// Build ICMPv6 packet and return bytes. - pub fn build(&self) -> Vec { - let mut buffer = [0u8; NDP_SOL_PACKET_LEN + NDP_OPT_PACKET_LEN + MAC_ADDR_LEN]; - // Build the NDP packet - let mut ndp_packet = MutableNeighborSolicitPacket::new(&mut buffer).unwrap(); - ndp_packet.set_target_addr(self.dst_ip); - ndp_packet.set_icmpv6_type(Icmpv6Type::NeighborSolicitation); - ndp_packet.set_checksum(0x3131); - - let mut opt_packet = MutableNdpOptionPacket::new(ndp_packet.get_options_raw_mut()).unwrap(); - opt_packet.set_option_type(NdpOptionTypes::SourceLLAddr); - opt_packet.set_length(octets_len(MAC_ADDR_LEN)); - opt_packet.set_data(&self.src_mac.octets()); - - // Set the checksum (part of the NDP packet) - let mut icmpv6_packet = MutableIcmpv6Packet::new(&mut buffer).unwrap(); - icmpv6_packet.set_checksum(icmpv6::checksum( - &icmpv6_packet.to_immutable(), - &self.src_ip, - &self.dst_ip, - )); - buffer.to_vec() - } -} diff --git a/nex-packet-builder/src/tcp.rs b/nex-packet-builder/src/tcp.rs deleted file mode 100644 index a6a4633..0000000 --- a/nex-packet-builder/src/tcp.rs +++ /dev/null @@ -1,181 +0,0 @@ -use nex_packet::ethernet::ETHERNET_HEADER_LEN; -use nex_packet::ipv4::IPV4_HEADER_LEN; -use nex_packet::ipv6::IPV6_HEADER_LEN; -use nex_packet::tcp::TCP_MIN_DATA_OFFSET; -use nex_packet::tcp::{MutableTcpPacket, TcpFlags, TcpOption, TCP_HEADER_LEN}; -use nex_packet::Packet; -use std::net::{IpAddr, SocketAddr}; - -/// Default TCP Option Length. -pub const TCP_DEFAULT_OPTION_LEN: usize = 12; -/// Default TCP Source Port. -pub const DEFAULT_SRC_PORT: u16 = 53443; -/// TCP (IPv4) Minimum Packet Length. -pub const TCPV4_MINIMUM_PACKET_LEN: usize = ETHERNET_HEADER_LEN + IPV4_HEADER_LEN + TCP_HEADER_LEN; -/// TCP (IPv4) Default Packet Length. -pub const TCPV4_DEFAULT_PACKET_LEN: usize = - ETHERNET_HEADER_LEN + IPV4_HEADER_LEN + TCP_HEADER_LEN + TCP_DEFAULT_OPTION_LEN; -/// TCP (IPv4) Minimum IP Packet Length. -pub const TCPV4_MINIMUM_IP_PACKET_LEN: usize = IPV4_HEADER_LEN + TCP_HEADER_LEN; -/// TCP (IPv4) Default IP Packet Length. -pub const TCPV4_DEFAULT_IP_PACKET_LEN: usize = - IPV4_HEADER_LEN + TCP_HEADER_LEN + TCP_DEFAULT_OPTION_LEN; -/// TCP (IPv6) Minimum Packet Length. -pub const TCPV6_MINIMUM_PACKET_LEN: usize = ETHERNET_HEADER_LEN + IPV6_HEADER_LEN + TCP_HEADER_LEN; -/// TCP (IPv6) Default Packet Length. -pub const TCPV6_DEFAULT_PACKET_LEN: usize = - ETHERNET_HEADER_LEN + IPV6_HEADER_LEN + TCP_HEADER_LEN + TCP_DEFAULT_OPTION_LEN; -/// TCP (IPv6) Minimum IP Packet Length. -pub const TCPV6_MINIMUM_IP_PACKET_LEN: usize = IPV6_HEADER_LEN + TCP_HEADER_LEN; -/// TCP (IPv6) Default IP Packet Length. -pub const TCPV6_DEFAULT_IP_PACKET_LEN: usize = - IPV6_HEADER_LEN + TCP_HEADER_LEN + TCP_DEFAULT_OPTION_LEN; - -/// Get the length of TCP options from TCP data offset. -pub fn get_tcp_options_length(data_offset: u8) -> usize { - if data_offset > 5 { - data_offset as usize * 4 - TCP_HEADER_LEN - } else { - 0 - } -} - -/// Get the TCP data offset from TCP options. -pub fn get_tcp_data_offset(opstions: Vec) -> u8 { - let mut total_size: u8 = 0; - for opt in opstions { - total_size += opt.kind().size() as u8; - } - if total_size % 4 == 0 { - total_size / 4 + TCP_MIN_DATA_OFFSET - } else { - total_size / 4 + TCP_MIN_DATA_OFFSET + 1 - } -} - -pub(crate) fn build_tcp_packet( - tcp_packet: &mut MutableTcpPacket, - src_ip: IpAddr, - src_port: u16, - dst_ip: IpAddr, - dst_port: u16, -) { - tcp_packet.set_source(src_port); - tcp_packet.set_destination(dst_port); - tcp_packet.set_window(64240); - tcp_packet.set_data_offset(8); - tcp_packet.set_urgent_ptr(0); - tcp_packet.set_sequence(0); - tcp_packet.set_options(&[ - TcpOption::mss(1460), - TcpOption::sack_perm(), - TcpOption::nop(), - TcpOption::nop(), - TcpOption::wscale(7), - ]); - tcp_packet.set_flags(TcpFlags::SYN); - match src_ip { - IpAddr::V4(src_ip) => match dst_ip { - IpAddr::V4(dst_ip) => { - let checksum = - nex_packet::tcp::ipv4_checksum(&tcp_packet.to_immutable(), &src_ip, &dst_ip); - tcp_packet.set_checksum(checksum); - } - IpAddr::V6(_) => {} - }, - IpAddr::V6(src_ip) => match dst_ip { - IpAddr::V4(_) => {} - IpAddr::V6(dst_ip) => { - let checksum = - nex_packet::tcp::ipv6_checksum(&tcp_packet.to_immutable(), &src_ip, &dst_ip); - tcp_packet.set_checksum(checksum); - } - }, - } -} - -/// TCP Packet Builder. -#[derive(Clone, Debug)] -pub struct TcpPacketBuilder { - /// Source IP address. - pub src_ip: IpAddr, - /// Source port. - pub src_port: u16, - /// Destination IP address. - pub dst_ip: IpAddr, - /// Destination port. - pub dst_port: u16, - /// Window size. - pub window: u16, - /// TCP flags. - pub flags: u8, - /// TCP options. - pub options: Vec, - /// TCP payload. - pub payload: Vec, -} - -impl TcpPacketBuilder { - /// Constructs a new TcpPacketBuilder from Source SocketAddr and Destination SocketAddr with default options. - pub fn new(src_addr: SocketAddr, dst_addr: SocketAddr) -> TcpPacketBuilder { - TcpPacketBuilder { - src_ip: src_addr.ip(), - src_port: src_addr.port(), - dst_ip: dst_addr.ip(), - dst_port: dst_addr.port(), - window: 64240, - flags: TcpFlags::SYN, - options: vec![ - TcpOption::mss(1460), - TcpOption::sack_perm(), - TcpOption::nop(), - TcpOption::nop(), - TcpOption::wscale(7), - ], - payload: vec![], - } - } - /// Build a TCP packet and return bytes. - pub fn build(&self) -> Vec { - let data_offset = get_tcp_data_offset(self.options.clone()); - let tcp_options_len = get_tcp_options_length(data_offset); - let mut buffer: Vec = vec![0; TCP_HEADER_LEN + tcp_options_len + self.payload.len()]; - let mut tcp_packet = MutableTcpPacket::new(&mut buffer).unwrap(); - tcp_packet.set_source(self.src_port); - tcp_packet.set_destination(self.dst_port); - tcp_packet.set_window(self.window); - tcp_packet.set_data_offset(data_offset); - tcp_packet.set_urgent_ptr(0); - tcp_packet.set_sequence(0); - tcp_packet.set_flags(self.flags); - tcp_packet.set_options(&self.options); - if self.payload.len() > 0 { - tcp_packet.set_payload(&self.payload); - } - match self.src_ip { - IpAddr::V4(src_ip) => match self.dst_ip { - IpAddr::V4(dst_ip) => { - let checksum = nex_packet::tcp::ipv4_checksum( - &tcp_packet.to_immutable(), - &src_ip, - &dst_ip, - ); - tcp_packet.set_checksum(checksum); - } - IpAddr::V6(_) => {} - }, - IpAddr::V6(src_ip) => match self.dst_ip { - IpAddr::V4(_) => {} - IpAddr::V6(dst_ip) => { - let checksum = nex_packet::tcp::ipv6_checksum( - &tcp_packet.to_immutable(), - &src_ip, - &dst_ip, - ); - tcp_packet.set_checksum(checksum); - } - }, - } - tcp_packet.packet().to_vec() - } -} diff --git a/nex-packet-builder/src/udp.rs b/nex-packet-builder/src/udp.rs deleted file mode 100644 index 8e4cec1..0000000 --- a/nex-packet-builder/src/udp.rs +++ /dev/null @@ -1,113 +0,0 @@ -use nex_packet::ethernet::ETHERNET_HEADER_LEN; -use nex_packet::ipv4::IPV4_HEADER_LEN; -use nex_packet::ipv6::IPV6_HEADER_LEN; -use nex_packet::udp::MutableUdpPacket; -use nex_packet::udp::UDP_HEADER_LEN; -use nex_packet::Packet; -use std::net::{IpAddr, SocketAddr}; - -/// UDP BASE Destination Port. Usually used for traceroute. -pub const UDP_BASE_DST_PORT: u16 = 33435; - -/// UDP (IPv4) Minimum Packet Length. -pub const UDPV4_PACKET_LEN: usize = ETHERNET_HEADER_LEN + IPV4_HEADER_LEN + UDP_HEADER_LEN; -/// UDP (IPv4) Minimum IP Packet Length. -pub const UDPV4_IP_PACKET_LEN: usize = IPV4_HEADER_LEN + UDP_HEADER_LEN; -/// UDP (IPv6) Minimum Packet Length. -pub const UDPV6_PACKET_LEN: usize = ETHERNET_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN; -/// UDP (IPv6) Minimum IP Packet Length. -pub const UDPV6_IP_PACKET_LEN: usize = IPV6_HEADER_LEN + UDP_HEADER_LEN; - -pub(crate) fn build_udp_packet( - udp_packet: &mut nex_packet::udp::MutableUdpPacket, - src_ip: IpAddr, - src_port: u16, - dst_ip: IpAddr, - dst_port: u16, -) { - udp_packet.set_length(8); - udp_packet.set_source(src_port); - udp_packet.set_destination(dst_port); - match src_ip { - IpAddr::V4(src_ip) => match dst_ip { - IpAddr::V4(dst_ip) => { - let checksum = - nex_packet::udp::ipv4_checksum(&udp_packet.to_immutable(), &src_ip, &dst_ip); - udp_packet.set_checksum(checksum); - } - IpAddr::V6(_) => {} - }, - IpAddr::V6(src_ip) => match dst_ip { - IpAddr::V4(_) => {} - IpAddr::V6(dst_ip) => { - let checksum = - nex_packet::udp::ipv6_checksum(&udp_packet.to_immutable(), &src_ip, &dst_ip); - udp_packet.set_checksum(checksum); - } - }, - } -} - -/// UDP Packet Builder. -#[derive(Clone, Debug)] -pub struct UdpPacketBuilder { - /// Source IP address. - pub src_ip: IpAddr, - /// Source Port. - pub src_port: u16, - /// Destination IP address. - pub dst_ip: IpAddr, - /// Destination Port. - pub dst_port: u16, - /// Payload. - pub payload: Vec, -} - -impl UdpPacketBuilder { - /// Constructs a new UdpPacketBuilder. - pub fn new(src_addr: SocketAddr, dst_addr: SocketAddr) -> Self { - UdpPacketBuilder { - src_ip: src_addr.ip(), - src_port: src_addr.port(), - dst_ip: dst_addr.ip(), - dst_port: dst_addr.port(), - payload: Vec::new(), - } - } - /// Builds a new UdpPacket and return bytes. - pub fn build(&self) -> Vec { - let mut buffer: Vec = vec![0; UDP_HEADER_LEN + self.payload.len()]; - let mut udp_packet = MutableUdpPacket::new(&mut buffer).unwrap(); - udp_packet.set_source(self.src_port); - udp_packet.set_destination(self.dst_port); - if self.payload.len() > 0 { - udp_packet.set_payload(&self.payload); - } - udp_packet.set_length(UDP_HEADER_LEN as u16 + self.payload.len() as u16); - match self.src_ip { - IpAddr::V4(src_ip) => match self.dst_ip { - IpAddr::V4(dst_ip) => { - let checksum = nex_packet::udp::ipv4_checksum( - &udp_packet.to_immutable(), - &src_ip, - &dst_ip, - ); - udp_packet.set_checksum(checksum); - } - IpAddr::V6(_) => {} - }, - IpAddr::V6(src_ip) => match self.dst_ip { - IpAddr::V4(_) => {} - IpAddr::V6(dst_ip) => { - let checksum = nex_packet::udp::ipv6_checksum( - &udp_packet.to_immutable(), - &src_ip, - &dst_ip, - ); - udp_packet.set_checksum(checksum); - } - }, - } - udp_packet.packet().to_vec() - } -} diff --git a/nex-packet-builder/src/util.rs b/nex-packet-builder/src/util.rs deleted file mode 100644 index d6c9e8b..0000000 --- a/nex-packet-builder/src/util.rs +++ /dev/null @@ -1,393 +0,0 @@ -use crate::icmp::build_icmp_echo_packet; -use crate::icmpv6::build_icmpv6_echo_packet; -use crate::ipv6::build_ipv6_packet; -use crate::tcp::{build_tcp_packet, TCP_DEFAULT_OPTION_LEN}; -use crate::udp::build_udp_packet; -use nex_core::mac::MacAddr; -use nex_packet::arp::{MutableArpPacket, ARP_HEADER_LEN}; -use nex_packet::ethernet::ETHERNET_HEADER_LEN; -use nex_packet::ethernet::{EtherType, MutableEthernetPacket}; -use nex_packet::icmp::ICMPV4_HEADER_LEN; -use nex_packet::icmpv6::ICMPV6_HEADER_LEN; -use nex_packet::ip::IpNextLevelProtocol; -use nex_packet::ipv4::{MutableIpv4Packet, IPV4_HEADER_LEN}; -use nex_packet::ipv6::{MutableIpv6Packet, IPV6_HEADER_LEN}; -use nex_packet::tcp::{MutableTcpPacket, TCP_HEADER_LEN}; -use nex_packet::udp::{MutableUdpPacket, UDP_HEADER_LEN}; -use nex_packet::Packet; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; - -use crate::arp::build_arp_packet; -use crate::ethernet::{build_ethernet_arp_packet, build_ethernet_packet}; -use crate::ipv4::build_ipv4_packet; - -/// Higher level packet build option. -/// For building packet, use PacketBuilder or protocol specific packet builder. -#[derive(Clone, Debug)] -pub struct PacketBuildOption { - pub src_mac: MacAddr, - pub dst_mac: MacAddr, - pub ether_type: EtherType, - pub src_ip: IpAddr, - pub dst_ip: IpAddr, - pub src_port: Option, - pub dst_port: Option, - pub ip_protocol: Option, - pub payload: Vec, - pub use_tun: bool, -} - -impl PacketBuildOption { - /// Constructs a new PacketBuildOption. - pub fn new() -> Self { - PacketBuildOption { - src_mac: MacAddr::zero(), - dst_mac: MacAddr::zero(), - ether_type: EtherType::Ipv4, - src_ip: IpAddr::V4(Ipv4Addr::UNSPECIFIED), - dst_ip: IpAddr::V4(Ipv4Addr::UNSPECIFIED), - src_port: None, - dst_port: None, - ip_protocol: None, - payload: Vec::new(), - use_tun: false, - } - } -} - -/// Build ARP Packet from PacketBuildOption. -pub fn build_full_arp_packet(packet_option: PacketBuildOption) -> Vec { - let src_ip: Ipv4Addr = match packet_option.src_ip { - IpAddr::V4(ipv4_addr) => ipv4_addr, - _ => return Vec::new(), - }; - let dst_ip: Ipv4Addr = match packet_option.dst_ip { - IpAddr::V4(ipv4_addr) => ipv4_addr, - _ => return Vec::new(), - }; - let mut ethernet_buffer = [0u8; ETHERNET_HEADER_LEN + ARP_HEADER_LEN]; - let mut ethernet_packet: MutableEthernetPacket = - MutableEthernetPacket::new(&mut ethernet_buffer).unwrap(); - build_ethernet_arp_packet(&mut ethernet_packet, packet_option.src_mac.clone()); - let mut arp_buffer = [0u8; ARP_HEADER_LEN]; - let mut arp_packet = MutableArpPacket::new(&mut arp_buffer).unwrap(); - build_arp_packet( - &mut arp_packet, - packet_option.src_mac, - packet_option.dst_mac, - src_ip, - dst_ip, - ); - ethernet_packet.set_payload(arp_packet.packet()); - ethernet_packet.packet().to_vec() -} - -/// Build ICMP Packet from PacketBuildOption. Build full packet with ethernet and ipv4 header. -pub fn build_full_icmp_packet(packet_option: PacketBuildOption) -> Vec { - let src_ip: Ipv4Addr = match packet_option.src_ip { - IpAddr::V4(ipv4_addr) => ipv4_addr, - _ => return Vec::new(), - }; - let dst_ip: Ipv4Addr = match packet_option.dst_ip { - IpAddr::V4(ipv4_addr) => ipv4_addr, - _ => return Vec::new(), - }; - let mut ethernet_buffer = [0u8; ETHERNET_HEADER_LEN + IPV4_HEADER_LEN + ICMPV4_HEADER_LEN]; - let mut ethernet_packet: MutableEthernetPacket = - MutableEthernetPacket::new(&mut ethernet_buffer).unwrap(); - build_ethernet_packet( - &mut ethernet_packet, - packet_option.src_mac.clone(), - packet_option.dst_mac.clone(), - packet_option.ether_type, - ); - let mut ipv4_buffer = [0u8; IPV4_HEADER_LEN + ICMPV4_HEADER_LEN]; - let mut ipv4_packet = MutableIpv4Packet::new(&mut ipv4_buffer).unwrap(); - build_ipv4_packet( - &mut ipv4_packet, - src_ip, - dst_ip, - packet_option.ip_protocol.unwrap(), - ); - let mut icmp_buffer = [0u8; ICMPV4_HEADER_LEN]; - let mut icmp_packet = - nex_packet::icmp::echo_request::MutableEchoRequestPacket::new(&mut icmp_buffer).unwrap(); - build_icmp_echo_packet(&mut icmp_packet); - ipv4_packet.set_payload(icmp_packet.packet()); - ethernet_packet.set_payload(ipv4_packet.packet()); - if packet_option.use_tun { - ethernet_packet.packet()[ETHERNET_HEADER_LEN..].to_vec() - } else { - ethernet_packet.packet().to_vec() - } -} - -/// Build ICMP Packet. -pub fn build_min_icmp_packet() -> Vec { - let mut icmp_buffer = [0u8; ICMPV4_HEADER_LEN]; - let mut icmp_packet = - nex_packet::icmp::echo_request::MutableEchoRequestPacket::new(&mut icmp_buffer).unwrap(); - build_icmp_echo_packet(&mut icmp_packet); - icmp_packet.packet().to_vec() -} - -/// Build ICMPv6 Packet from PacketBuildOption. Build full packet with ethernet and ipv6 header. -pub fn build_full_icmpv6_packet(packet_option: PacketBuildOption) -> Vec { - let src_ip: Ipv6Addr = match packet_option.src_ip { - IpAddr::V6(ipv6_addr) => ipv6_addr, - _ => return Vec::new(), - }; - let dst_ip: Ipv6Addr = match packet_option.dst_ip { - IpAddr::V6(ipv6_addr) => ipv6_addr, - _ => return Vec::new(), - }; - let mut ethernet_buffer = [0u8; ETHERNET_HEADER_LEN + IPV6_HEADER_LEN + ICMPV6_HEADER_LEN]; - let mut ethernet_packet: MutableEthernetPacket = - MutableEthernetPacket::new(&mut ethernet_buffer).unwrap(); - build_ethernet_packet( - &mut ethernet_packet, - packet_option.src_mac.clone(), - packet_option.dst_mac.clone(), - packet_option.ether_type, - ); - let mut ipv6_buffer = [0u8; IPV6_HEADER_LEN + ICMPV6_HEADER_LEN]; - let mut ipv6_packet = MutableIpv6Packet::new(&mut ipv6_buffer).unwrap(); - build_ipv6_packet( - &mut ipv6_packet, - src_ip, - dst_ip, - packet_option.ip_protocol.unwrap(), - ); - let mut icmpv6_buffer = [0u8; ICMPV6_HEADER_LEN]; - let mut icmpv6_packet = - nex_packet::icmpv6::echo_request::MutableEchoRequestPacket::new(&mut icmpv6_buffer) - .unwrap(); - build_icmpv6_echo_packet(&mut icmpv6_packet, src_ip, dst_ip); - ipv6_packet.set_payload(icmpv6_packet.packet()); - ethernet_packet.set_payload(ipv6_packet.packet()); - if packet_option.use_tun { - ethernet_packet.packet()[ETHERNET_HEADER_LEN..].to_vec() - } else { - ethernet_packet.packet().to_vec() - } -} - -/// Build ICMPv6 Packet from PacketBuildOption. -pub fn build_min_icmpv6_packet(packet_option: PacketBuildOption) -> Vec { - let src_ip: Ipv6Addr = match packet_option.src_ip { - IpAddr::V6(ipv6_addr) => ipv6_addr, - _ => return Vec::new(), - }; - let dst_ip: Ipv6Addr = match packet_option.dst_ip { - IpAddr::V6(ipv6_addr) => ipv6_addr, - _ => return Vec::new(), - }; - let mut icmpv6_buffer = [0u8; ICMPV6_HEADER_LEN]; - let mut icmpv6_packet = - nex_packet::icmpv6::echo_request::MutableEchoRequestPacket::new(&mut icmpv6_buffer) - .unwrap(); - build_icmpv6_echo_packet(&mut icmpv6_packet, src_ip, dst_ip); - icmpv6_packet.packet().to_vec() -} - -/// Build TCP Packet from PacketBuildOption. Build full packet with Ethernet and IP header. -pub fn build_full_tcp_syn_packet(packet_option: PacketBuildOption) -> Vec { - match packet_option.src_ip { - IpAddr::V4(src_ip) => match packet_option.dst_ip { - IpAddr::V4(dst_ip) => { - let mut ethernet_buffer = [0u8; ETHERNET_HEADER_LEN - + IPV4_HEADER_LEN - + TCP_HEADER_LEN - + TCP_DEFAULT_OPTION_LEN]; - let mut ethernet_packet: MutableEthernetPacket = - MutableEthernetPacket::new(&mut ethernet_buffer).unwrap(); - build_ethernet_packet( - &mut ethernet_packet, - packet_option.src_mac.clone(), - packet_option.dst_mac.clone(), - packet_option.ether_type, - ); - let mut ipv4_buffer = - [0u8; IPV4_HEADER_LEN + TCP_HEADER_LEN + TCP_DEFAULT_OPTION_LEN]; - let mut ipv4_packet = MutableIpv4Packet::new(&mut ipv4_buffer).unwrap(); - build_ipv4_packet( - &mut ipv4_packet, - src_ip, - dst_ip, - packet_option.ip_protocol.unwrap(), - ); - let mut tcp_buffer = [0u8; TCP_HEADER_LEN + TCP_DEFAULT_OPTION_LEN]; - let mut tcp_packet = MutableTcpPacket::new(&mut tcp_buffer).unwrap(); - build_tcp_packet( - &mut tcp_packet, - packet_option.src_ip, - packet_option.src_port.unwrap(), - packet_option.dst_ip, - packet_option.dst_port.unwrap(), - ); - ipv4_packet.set_payload(tcp_packet.packet()); - ethernet_packet.set_payload(ipv4_packet.packet()); - if packet_option.use_tun { - ethernet_packet.packet()[ETHERNET_HEADER_LEN..].to_vec() - } else { - ethernet_packet.packet().to_vec() - } - } - IpAddr::V6(_) => return Vec::new(), - }, - IpAddr::V6(src_ip) => match packet_option.dst_ip { - IpAddr::V4(_) => return Vec::new(), - IpAddr::V6(dst_ip) => { - let mut ethernet_buffer = [0u8; ETHERNET_HEADER_LEN - + IPV6_HEADER_LEN - + TCP_HEADER_LEN - + TCP_DEFAULT_OPTION_LEN]; - let mut ethernet_packet: MutableEthernetPacket = - MutableEthernetPacket::new(&mut ethernet_buffer).unwrap(); - build_ethernet_packet( - &mut ethernet_packet, - packet_option.src_mac.clone(), - packet_option.dst_mac.clone(), - packet_option.ether_type, - ); - let mut ipv6_buffer = - [0u8; IPV6_HEADER_LEN + TCP_HEADER_LEN + TCP_DEFAULT_OPTION_LEN]; - let mut ipv6_packet = MutableIpv6Packet::new(&mut ipv6_buffer).unwrap(); - build_ipv6_packet( - &mut ipv6_packet, - src_ip, - dst_ip, - packet_option.ip_protocol.unwrap(), - ); - let mut tcp_buffer = [0u8; TCP_HEADER_LEN + TCP_DEFAULT_OPTION_LEN]; - let mut tcp_packet = MutableTcpPacket::new(&mut tcp_buffer).unwrap(); - build_tcp_packet( - &mut tcp_packet, - packet_option.src_ip, - packet_option.src_port.unwrap(), - packet_option.dst_ip, - packet_option.dst_port.unwrap(), - ); - ipv6_packet.set_payload(tcp_packet.packet()); - ethernet_packet.set_payload(ipv6_packet.packet()); - if packet_option.use_tun { - ethernet_packet.packet()[ETHERNET_HEADER_LEN..].to_vec() - } else { - ethernet_packet.packet().to_vec() - } - } - }, - } -} - -/// Build TCP Packet from PacketBuildOption. -pub fn build_min_tcp_syn_packet(packet_option: PacketBuildOption) -> Vec { - let mut tcp_buffer = [0u8; TCP_HEADER_LEN + TCP_DEFAULT_OPTION_LEN]; - let mut tcp_packet = MutableTcpPacket::new(&mut tcp_buffer).unwrap(); - build_tcp_packet( - &mut tcp_packet, - packet_option.src_ip, - packet_option.src_port.unwrap(), - packet_option.dst_ip, - packet_option.dst_port.unwrap(), - ); - tcp_packet.packet().to_vec() -} - -/// Build UDP Packet from PacketBuildOption. Build full packet with Ethernet and IP header. -pub fn build_full_udp_packet(packet_option: PacketBuildOption) -> Vec { - match packet_option.src_ip { - IpAddr::V4(src_ip) => match packet_option.dst_ip { - IpAddr::V4(dst_ip) => { - let mut ethernet_buffer = - [0u8; ETHERNET_HEADER_LEN + IPV4_HEADER_LEN + UDP_HEADER_LEN]; - let mut ethernet_packet: MutableEthernetPacket = - MutableEthernetPacket::new(&mut ethernet_buffer).unwrap(); - build_ethernet_packet( - &mut ethernet_packet, - packet_option.src_mac.clone(), - packet_option.dst_mac.clone(), - packet_option.ether_type, - ); - let mut ipv4_buffer = [0u8; IPV4_HEADER_LEN + UDP_HEADER_LEN]; - let mut ipv4_packet = MutableIpv4Packet::new(&mut ipv4_buffer).unwrap(); - build_ipv4_packet( - &mut ipv4_packet, - src_ip, - dst_ip, - packet_option.ip_protocol.unwrap(), - ); - let mut udp_buffer = [0u8; UDP_HEADER_LEN]; - let mut udp_packet = MutableUdpPacket::new(&mut udp_buffer).unwrap(); - build_udp_packet( - &mut udp_packet, - packet_option.src_ip, - packet_option.src_port.unwrap(), - packet_option.dst_ip, - packet_option.dst_port.unwrap(), - ); - ipv4_packet.set_payload(udp_packet.packet()); - ethernet_packet.set_payload(ipv4_packet.packet()); - if packet_option.use_tun { - ethernet_packet.packet()[ETHERNET_HEADER_LEN..].to_vec() - } else { - ethernet_packet.packet().to_vec() - } - } - IpAddr::V6(_) => return Vec::new(), - }, - IpAddr::V6(src_ip) => match packet_option.dst_ip { - IpAddr::V4(_) => return Vec::new(), - IpAddr::V6(dst_ip) => { - let mut ethernet_buffer = - [0u8; ETHERNET_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN]; - let mut ethernet_packet: MutableEthernetPacket = - MutableEthernetPacket::new(&mut ethernet_buffer).unwrap(); - build_ethernet_packet( - &mut ethernet_packet, - packet_option.src_mac.clone(), - packet_option.dst_mac.clone(), - packet_option.ether_type, - ); - let mut ipv6_buffer = [0u8; IPV6_HEADER_LEN + UDP_HEADER_LEN]; - let mut ipv6_packet = MutableIpv6Packet::new(&mut ipv6_buffer).unwrap(); - build_ipv6_packet( - &mut ipv6_packet, - src_ip, - dst_ip, - packet_option.ip_protocol.unwrap(), - ); - let mut udp_buffer = [0u8; UDP_HEADER_LEN]; - let mut udp_packet = MutableUdpPacket::new(&mut udp_buffer).unwrap(); - build_udp_packet( - &mut udp_packet, - packet_option.src_ip, - packet_option.src_port.unwrap(), - packet_option.dst_ip, - packet_option.dst_port.unwrap(), - ); - ipv6_packet.set_payload(udp_packet.packet()); - ethernet_packet.set_payload(ipv6_packet.packet()); - if packet_option.use_tun { - ethernet_packet.packet()[ETHERNET_HEADER_LEN..].to_vec() - } else { - ethernet_packet.packet().to_vec() - } - } - }, - } -} - -/// Build UDP Packet from PacketBuildOption. -pub fn build_min_udp_packet(packet_option: PacketBuildOption) -> Vec { - let mut udp_buffer = [0u8; UDP_HEADER_LEN]; - let mut udp_packet = MutableUdpPacket::new(&mut udp_buffer).unwrap(); - build_udp_packet( - &mut udp_packet, - packet_option.src_ip, - packet_option.src_port.unwrap(), - packet_option.dst_ip, - packet_option.dst_port.unwrap(), - ); - udp_packet.packet().to_vec() -} diff --git a/nex-packet/Cargo.toml b/nex-packet/Cargo.toml index f27f0af..5a141e1 100644 --- a/nex-packet/Cargo.toml +++ b/nex-packet/Cargo.toml @@ -11,12 +11,10 @@ categories = ["network-programming"] license = "MIT" [dependencies] -rand = { workspace = true } -serde = { workspace = true, features = ["derive"], optional = true } +bytes = { workspace = true } nex-core = { workspace = true } -nex-macro = { workspace = true } -nex-macro-helper = { workspace = true } +serde = { workspace = true, features = ["derive"], optional = true } +rand = { workspace = true } [features] -clippy = [] serde = ["dep:serde", "nex-core/serde"] diff --git a/nex-packet/src/arp.rs b/nex-packet/src/arp.rs index 60728f7..53e0c76 100644 --- a/nex-packet/src/arp.rs +++ b/nex-packet/src/arp.rs @@ -1,12 +1,10 @@ //! ARP packet abstraction. -use crate::ethernet::{EtherType, ETHERNET_HEADER_LEN}; -use crate::PrimitiveValues; - -use alloc::vec::Vec; +use crate::{ethernet::{EtherType, ETHERNET_HEADER_LEN}, packet::Packet}; +use bytes::{Bytes, BytesMut}; use nex_core::mac::MacAddr; -use nex_macro::packet; +use core::fmt; use std::net::Ipv4Addr; #[cfg(feature = "serde")] @@ -17,58 +15,6 @@ pub const ARP_HEADER_LEN: usize = 28; /// ARP Minimum Packet Length. pub const ARP_PACKET_LEN: usize = ETHERNET_HEADER_LEN + ARP_HEADER_LEN; -/// Represents the ARP header. -#[derive(Clone, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct ArpHeader { - pub hardware_type: ArpHardwareType, - pub protocol_type: EtherType, - pub hw_addr_len: u8, - pub proto_addr_len: u8, - pub operation: ArpOperation, - pub sender_hw_addr: MacAddr, - pub sender_proto_addr: Ipv4Addr, - pub target_hw_addr: MacAddr, - pub target_proto_addr: Ipv4Addr, -} - -impl ArpHeader { - /// Construct an ARP header from a byte slice. - pub fn from_bytes(packet: &[u8]) -> Result { - if packet.len() < ARP_HEADER_LEN { - return Err("Packet is too small for ARP header".to_string()); - } - match ArpPacket::new(packet) { - Some(arp_packet) => Ok(ArpHeader { - hardware_type: arp_packet.get_hardware_type(), - protocol_type: arp_packet.get_protocol_type(), - hw_addr_len: arp_packet.get_hw_addr_len(), - proto_addr_len: arp_packet.get_proto_addr_len(), - operation: arp_packet.get_operation(), - sender_hw_addr: arp_packet.get_sender_hw_addr(), - sender_proto_addr: arp_packet.get_sender_proto_addr(), - target_hw_addr: arp_packet.get_target_hw_addr(), - target_proto_addr: arp_packet.get_target_proto_addr(), - }), - None => Err("Failed to parse ARP packet".to_string()), - } - } - /// Construct an ARP header from an ArpPacket. - pub(crate) fn from_packet(packet: &ArpPacket) -> ArpHeader { - ArpHeader { - hardware_type: packet.get_hardware_type(), - protocol_type: packet.get_protocol_type(), - hw_addr_len: packet.get_hw_addr_len(), - proto_addr_len: packet.get_proto_addr_len(), - operation: packet.get_operation(), - sender_hw_addr: packet.get_sender_hw_addr(), - sender_proto_addr: packet.get_sender_proto_addr(), - target_hw_addr: packet.get_target_hw_addr(), - target_proto_addr: packet.get_target_proto_addr(), - } - } -} - /// Represents the ARP operation types. #[repr(u16)] #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -106,20 +52,30 @@ impl ArpOperation { _ => ArpOperation::Unknown(value), } } -} - -impl PrimitiveValues for ArpOperation { - type T = (u16,); - fn to_primitive_values(&self) -> (u16,) { - match *self { - ArpOperation::Request => (1,), - ArpOperation::Reply => (2,), - ArpOperation::RarpRequest => (3,), - ArpOperation::RarpReply => (4,), - ArpOperation::InRequest => (8,), - ArpOperation::InReply => (9,), - ArpOperation::Nak => (10,), - ArpOperation::Unknown(n) => (n,), + /// Return the name of the ArpOperation + pub fn name(&self) -> &str { + match self { + ArpOperation::Request => "ARP Request", + ArpOperation::Reply => "ARP Reply", + ArpOperation::RarpRequest => "RARP Request", + ArpOperation::RarpReply => "RARP Reply", + ArpOperation::InRequest => "InARP Request", + ArpOperation::InReply => "InARP Reply", + ArpOperation::Nak => "ARP NAK", + ArpOperation::Unknown(_) => "Unknown ARP Operation", + } + } + /// Return the value of the ArpOperation + pub fn value(&self) -> u16 { + match self { + ArpOperation::Request => 1, + ArpOperation::Reply => 2, + ArpOperation::RarpRequest => 3, + ArpOperation::RarpReply => 4, + ArpOperation::InRequest => 8, + ArpOperation::InReply => 9, + ArpOperation::Nak => 10, + ArpOperation::Unknown(value) => *value, } } } @@ -253,76 +209,321 @@ impl ArpHardwareType { _ => ArpHardwareType::Unknown(value), } } -} - -impl PrimitiveValues for ArpHardwareType { - type T = (u16,); - fn to_primitive_values(&self) -> (u16,) { - match *self { - ArpHardwareType::Ethernet => (1,), - ArpHardwareType::ExperimentalEthernet => (2,), - ArpHardwareType::AmateurRadioAX25 => (3,), - ArpHardwareType::ProteonProNETTokenRing => (4,), - ArpHardwareType::Chaos => (5,), - ArpHardwareType::IEEE802Networks => (6,), - ArpHardwareType::ARCNET => (7,), - ArpHardwareType::Hyperchannel => (8,), - ArpHardwareType::Lanstar => (9,), - ArpHardwareType::AutonetShortAddress => (10,), - ArpHardwareType::LocalTalk => (11,), - ArpHardwareType::LocalNet => (12,), - ArpHardwareType::UltraLink => (13,), - ArpHardwareType::SMDS => (14,), - ArpHardwareType::FrameRelay => (15,), - ArpHardwareType::AsynchronousTransmissionMode => (16,), - ArpHardwareType::HDLC => (17,), - ArpHardwareType::FibreChannel => (18,), - ArpHardwareType::AsynchronousTransmissionMode2 => (19,), - ArpHardwareType::SerialLine => (20,), - ArpHardwareType::AsynchronousTransmissionMode3 => (21,), - ArpHardwareType::MILSTD188220 => (22,), - ArpHardwareType::Metricom => (23,), - ArpHardwareType::IEEE13941995 => (24,), - ArpHardwareType::MAPOS => (25,), - ArpHardwareType::Twinaxial => (26,), - ArpHardwareType::EUI64 => (27,), - ArpHardwareType::HIPARP => (28,), - ArpHardwareType::IPandARPoverISO78163 => (29,), - ArpHardwareType::ARPSec => (30,), - ArpHardwareType::IPsecTunnel => (31,), - ArpHardwareType::InfiniBand => (32,), - ArpHardwareType::TIA102Project25CommonAirInterface => (16384,), - ArpHardwareType::WiegandInterface => (16385,), - ArpHardwareType::PureIP => (16386,), - ArpHardwareType::HWEXP1 => (65280,), - ArpHardwareType::HWEXP2 => (65281,), - ArpHardwareType::AEthernet => (65282,), - ArpHardwareType::Unknown(n) => (n,), + /// Return the name of the ARP hardware type + pub fn name(&self) -> &str { + match self { + ArpHardwareType::Ethernet => "Ethernet", + ArpHardwareType::ExperimentalEthernet => "Experimental Ethernet", + ArpHardwareType::AmateurRadioAX25 => "Amateur Radio AX.25", + ArpHardwareType::ProteonProNETTokenRing => "Proteon ProNET Token Ring", + ArpHardwareType::Chaos => "Chaos", + ArpHardwareType::IEEE802Networks => "IEEE 802 Networks", + ArpHardwareType::ARCNET => "ARCNET", + ArpHardwareType::Hyperchannel => "Hyperchannel", + ArpHardwareType::Lanstar => "Lanstar", + ArpHardwareType::AutonetShortAddress => "Autonet Short Address", + ArpHardwareType::LocalTalk => "LocalTalk", + ArpHardwareType::LocalNet => "LocalNet (IBM PCNet or SYTEK LocalNET)", + ArpHardwareType::UltraLink => "Ultra link", + ArpHardwareType::SMDS => "SMDS", + ArpHardwareType::FrameRelay => "Frame Relay", + ArpHardwareType::AsynchronousTransmissionMode => "Asynchronous Transmission Mode (ATM)", + ArpHardwareType::HDLC => "HDLC", + ArpHardwareType::FibreChannel => "Fibre Channel", + ArpHardwareType::AsynchronousTransmissionMode2 => "Asynchronous Transmission Mode (ATM) 2", + ArpHardwareType::SerialLine => "Serial Line", + ArpHardwareType::AsynchronousTransmissionMode3 => "Asynchronous Transmission Mode (ATM) 3", + ArpHardwareType::MILSTD188220 => "MIL-STD-188-220", + ArpHardwareType::Metricom => "Metricom", + ArpHardwareType::IEEE13941995 => "IEEE 1394.1995", + ArpHardwareType::MAPOS => "MAPOS", + ArpHardwareType::Twinaxial => "Twinaxial", + ArpHardwareType::EUI64 => "EUI-64", + ArpHardwareType::HIPARP => "HIPARP", + ArpHardwareType::IPandARPoverISO78163 => "IP and ARP over ISO 7816-3", + ArpHardwareType::ARPSec => "ARPSec", + ArpHardwareType::IPsecTunnel => "IPsec Tunnel", + ArpHardwareType::InfiniBand => "InfiniBand (TM)", + ArpHardwareType::TIA102Project25CommonAirInterface => "TIA-102 Project 25 Common Air Interface", + ArpHardwareType::WiegandInterface => "Wiegand Interface", + ArpHardwareType::PureIP => "Pure IP", + ArpHardwareType::HWEXP1 => "HW_EXP1", + ArpHardwareType::HWEXP2 => "HW_EXP2", + ArpHardwareType::AEthernet => "AEthernet", + ArpHardwareType::Unknown(_) => "Unknown ARP Hardware Type", + } + } + /// Return the value of the ARP hardware type + pub fn value(&self) -> u16 { + match self { + ArpHardwareType::Ethernet => 1, + ArpHardwareType::ExperimentalEthernet => 2, + ArpHardwareType::AmateurRadioAX25 => 3, + ArpHardwareType::ProteonProNETTokenRing => 4, + ArpHardwareType::Chaos => 5, + ArpHardwareType::IEEE802Networks => 6, + ArpHardwareType::ARCNET => 7, + ArpHardwareType::Hyperchannel => 8, + ArpHardwareType::Lanstar => 9, + ArpHardwareType::AutonetShortAddress => 10, + ArpHardwareType::LocalTalk => 11, + ArpHardwareType::LocalNet => 12, + ArpHardwareType::UltraLink => 13, + ArpHardwareType::SMDS => 14, + ArpHardwareType::FrameRelay => 15, + ArpHardwareType::AsynchronousTransmissionMode => 16, + ArpHardwareType::HDLC => 17, + ArpHardwareType::FibreChannel => 18, + ArpHardwareType::AsynchronousTransmissionMode2 => 19, + ArpHardwareType::SerialLine => 20, + ArpHardwareType::AsynchronousTransmissionMode3 => 21, + ArpHardwareType::MILSTD188220 => 22, + ArpHardwareType::Metricom => 23, + ArpHardwareType::IEEE13941995 => 24, + ArpHardwareType::MAPOS => 25, + ArpHardwareType::Twinaxial => 26, + ArpHardwareType::EUI64 => 27, + ArpHardwareType::HIPARP => 28, + ArpHardwareType::IPandARPoverISO78163 => 29, + ArpHardwareType::ARPSec => 30, + ArpHardwareType::IPsecTunnel => 31, + ArpHardwareType::InfiniBand => 32, + ArpHardwareType::TIA102Project25CommonAirInterface => 16384, + ArpHardwareType::WiegandInterface => 16385, + ArpHardwareType::PureIP => 16386, + ArpHardwareType::HWEXP1 => 65280, + ArpHardwareType::HWEXP2 => 65281, + ArpHardwareType::AEthernet => 65282, + ArpHardwareType::Unknown(value) => *value, } } } -/// Represents an ARP Packet. -#[packet] -#[allow(non_snake_case)] -pub struct Arp { - #[construct_with(u16)] +/// Represents the ARP header. +#[derive(Clone, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct ArpHeader { pub hardware_type: ArpHardwareType, - #[construct_with(u16)] pub protocol_type: EtherType, pub hw_addr_len: u8, pub proto_addr_len: u8, - #[construct_with(u16)] pub operation: ArpOperation, - #[construct_with(u8, u8, u8, u8, u8, u8)] pub sender_hw_addr: MacAddr, - #[construct_with(u8, u8, u8, u8)] pub sender_proto_addr: Ipv4Addr, - #[construct_with(u8, u8, u8, u8, u8, u8)] pub target_hw_addr: MacAddr, - #[construct_with(u8, u8, u8, u8)] pub target_proto_addr: Ipv4Addr, - #[payload] - #[length = "0"] - pub payload: Vec, } + +/// Represents an ARP Packet. +#[derive(Clone, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct ArpPacket { + /// The ARP header. + pub header: ArpHeader, + /// The payload of the ARP packet. + pub payload: Bytes, +} + +impl Packet for ArpPacket { + type Header = ArpHeader; + fn from_buf(bytes: &[u8]) -> Option { + if bytes.len() < ARP_HEADER_LEN { + return None; + } + let hardware_type = ArpHardwareType::new(u16::from_be_bytes([bytes[0], bytes[1]])); + let protocol_type = EtherType::new(u16::from_be_bytes([bytes[2], bytes[3]])); + let hw_addr_len = bytes[4]; + let proto_addr_len = bytes[5]; + let operation = ArpOperation::new(u16::from_be_bytes([bytes[6], bytes[7]])); + let sender_hw_addr = MacAddr::from_octets(bytes[8..14].try_into().unwrap()); + let sender_proto_addr = Ipv4Addr::new(bytes[14], bytes[15], bytes[16], bytes[17]); + let target_hw_addr = MacAddr::from_octets(bytes[18..24].try_into().unwrap()); + let target_proto_addr = Ipv4Addr::new(bytes[24], bytes[25], bytes[26], bytes[27]); + let payload = Bytes::copy_from_slice(&bytes[ARP_HEADER_LEN..]); + + Some(ArpPacket { + header: ArpHeader { + hardware_type, + protocol_type, + hw_addr_len, + proto_addr_len, + operation, + sender_hw_addr, + sender_proto_addr, + target_hw_addr, + target_proto_addr, + }, + payload, + }) + } + fn from_bytes(bytes: Bytes) -> Option { + Self::from_buf(&bytes) + } + + fn to_bytes(&self) -> Bytes { + let mut buf = Vec::with_capacity(ARP_HEADER_LEN + self.payload.len()); + buf.extend_from_slice(&self.header.hardware_type.value().to_be_bytes()); + buf.extend_from_slice(&self.header.protocol_type.value().to_be_bytes()); + buf.push(self.header.hw_addr_len); + buf.push(self.header.proto_addr_len); + buf.extend_from_slice(&self.header.operation.value().to_be_bytes()); + buf.extend_from_slice(&self.header.sender_hw_addr.octets()); + buf.extend_from_slice(&self.header.sender_proto_addr.octets()); + buf.extend_from_slice(&self.header.target_hw_addr.octets()); + buf.extend_from_slice(&self.header.target_proto_addr.octets()); + buf.extend_from_slice(&self.payload); + + Bytes::from(buf) + } + + fn header(&self) -> Bytes { + self.to_bytes() + } + + fn payload(&self) -> Bytes { + self.payload.clone() + } + + fn header_len (&self) -> usize { + ARP_HEADER_LEN + } + fn payload_len(&self) -> usize { + self.payload.len() + } + fn total_len(&self) -> usize { + ARP_HEADER_LEN + self.payload.len() + } + fn to_bytes_mut(&self) -> BytesMut { + let mut buf = BytesMut::with_capacity(self.total_len()); + buf.extend_from_slice(&self.to_bytes()); + buf + } + fn header_mut(&self) -> BytesMut { + let mut buf = BytesMut::with_capacity(self.header_len()); + buf.extend_from_slice(&self.header()); + buf + } + fn payload_mut(&self) -> BytesMut { + let mut buf = BytesMut::with_capacity(self.payload_len()); + buf.extend_from_slice(&self.payload()); + buf + } + + fn into_parts(self) -> (Self::Header, Bytes) { + (self.header, self.payload) + } +} + +impl ArpPacket { + /// Create a new ARP packet. + pub fn new(header: ArpHeader, payload: Bytes) -> Self { + ArpPacket { header, payload } + } +} + +impl fmt::Display for ArpPacket { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "ArpPacket {{ hardware_type: {}, protocol_type: {}, hw_addr_len: {}, proto_addr_len: {}, operation: {}, sender_hw_addr: {}, sender_proto_addr: {}, target_hw_addr: {}, target_proto_addr: {} }}", + self.header.hardware_type.name(), + self.header.protocol_type.name(), + self.header.hw_addr_len, + self.header.proto_addr_len, + self.header.operation.name(), + self.header.sender_hw_addr, + self.header.sender_proto_addr, + self.header.target_hw_addr, + self.header.target_proto_addr + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_parse_valid_arp_packet() { + let raw = [ + 0x00, 0x01, // Hardware Type: Ethernet + 0x08, 0x00, // Protocol Type: IPv4 + 0x06, // HW Addr Len + 0x04, // Proto Addr Len + 0x00, 0x01, // Operation: Request + 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, // Sender MAC + 192, 168, 1, 1, // Sender IP + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Target MAC + 192, 168, 1, 2 // Target IP + ]; + + let padded = [&raw[..], &[0xde, 0xad, 0xbe, 0xef]].concat(); + let packet = ArpPacket::from_bytes(Bytes::copy_from_slice(&padded)).unwrap(); + + assert_eq!(packet.header.hardware_type, ArpHardwareType::Ethernet); + assert_eq!(packet.header.protocol_type, EtherType::Ipv4); + assert_eq!(packet.header.hw_addr_len, 6); + assert_eq!(packet.header.proto_addr_len, 4); + assert_eq!(packet.header.operation, ArpOperation::Request); + assert_eq!(packet.header.sender_hw_addr, MacAddr::from_octets([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff])); + assert_eq!(packet.header.sender_proto_addr, Ipv4Addr::new(192, 168, 1, 1)); + assert_eq!(packet.header.target_hw_addr, MacAddr::from_octets([0, 0, 0, 0, 0, 0])); + assert_eq!(packet.header.target_proto_addr, Ipv4Addr::new(192, 168, 1, 2)); + assert_eq!(packet.payload, Bytes::from_static(&[0xde, 0xad, 0xbe, 0xef])); + } + + #[test] + fn test_serialize_roundtrip() { + let original = ArpPacket { + header: ArpHeader { + hardware_type: ArpHardwareType::Ethernet, + protocol_type: EtherType::Ipv4, + hw_addr_len: 6, + proto_addr_len: 4, + operation: ArpOperation::Reply, + sender_hw_addr: MacAddr::from_octets([1, 2, 3, 4, 5, 6]), + sender_proto_addr: Ipv4Addr::new(10, 0, 0, 1), + target_hw_addr: MacAddr::from_octets([10, 20, 30, 40, 50, 60]), + target_proto_addr: Ipv4Addr::new(10, 0, 0, 2), + }, + payload: Bytes::from_static(&[0xbe, 0xef]), + }; + + let bytes = original.to_bytes(); + let parsed = ArpPacket::from_bytes(bytes).unwrap(); + assert_eq!(original, parsed); + } + + #[test] + fn test_parse_invalid_short_packet() { + let short = Bytes::from_static(&[0u8; 10]); + assert!(ArpPacket::from_bytes(short).is_none()); + } + + #[test] + fn test_unknown_operation_and_hw_type() { + let raw = [ + 0x99, 0x99, // Hardware Type: unknown + 0x08, 0x00, // Protocol Type: IPv4 + 0x06, + 0x04, + 0x99, 0x99, // Operation: unknown + 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, + 192, 168, 1, 1, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 192, 168, 1, 2, + 0x00, 0x01, 0x02, 0x03 + ]; + + let packet = ArpPacket::from_bytes(Bytes::copy_from_slice(&raw)).unwrap(); + match packet.header.hardware_type { + ArpHardwareType::Unknown(v) => assert_eq!(v, 0x9999), + _ => panic!("Expected unknown hardware type"), + } + match packet.header.operation { + ArpOperation::Unknown(v) => assert_eq!(v, 0x9999), + _ => panic!("Expected unknown operation"), + } + } +} + + diff --git a/nex-packet/src/builder/arp.rs b/nex-packet/src/builder/arp.rs new file mode 100644 index 0000000..311e64f --- /dev/null +++ b/nex-packet/src/builder/arp.rs @@ -0,0 +1,110 @@ +use bytes::Bytes; +use crate::{ + ethernet::EtherType, + packet::Packet, + arp::{ArpHeader, ArpPacket, ArpHardwareType, ArpOperation}, +}; +use nex_core::mac::MacAddr; +use std::net::Ipv4Addr; + +/// Builder for constructing ARP packets +#[derive(Debug, Clone)] +pub struct ArpPacketBuilder { + packet: ArpPacket, +} + +impl ArpPacketBuilder { + /// Create a new builder + pub fn new(sender_mac: MacAddr, sender_ip: Ipv4Addr, target_ip: Ipv4Addr) -> Self { + let header = ArpHeader { + hardware_type: ArpHardwareType::Ethernet, + protocol_type: EtherType::Ipv4, + hw_addr_len: 6, + proto_addr_len: 4, + operation: ArpOperation::Request, + sender_hw_addr: sender_mac, + sender_proto_addr: sender_ip, + target_hw_addr: MacAddr::zero(), + target_proto_addr: target_ip, + }; + Self { + packet: ArpPacket { + header, + payload: Bytes::new(), + }, + } + } + + pub fn hardware_type(mut self, hw_type: ArpHardwareType) -> Self { + self.packet.header.hardware_type = hw_type; + self + } + + pub fn protocol_type(mut self, proto_type: EtherType) -> Self { + self.packet.header.protocol_type = proto_type; + self + } + + /// Set the length of the sender MAC address + pub fn sender_hw_addr_len(mut self, len: u8) -> Self { + self.packet.header.hw_addr_len = len; + self + } + + /// Set the length of the sender IP address + pub fn sender_proto_addr_len(mut self, len: u8) -> Self { + self.packet.header.proto_addr_len = len; + self + } + + /// Set the sender MAC address + pub fn sender_mac(mut self, mac: MacAddr) -> Self { + self.packet.header.sender_hw_addr = mac; + self + } + + /// Set the sender IP address + pub fn sender_ip(mut self, ip: Ipv4Addr) -> Self { + self.packet.header.sender_proto_addr = ip; + self + } + + /// Set the target MAC address + pub fn target_mac(mut self, mac: MacAddr) -> Self { + self.packet.header.target_hw_addr = mac; + self + } + + /// Set the target IP address + pub fn target_ip(mut self, ip: Ipv4Addr) -> Self { + self.packet.header.target_proto_addr = ip; + self + } + + /// Set the ARP operation + pub fn operation(mut self, operation: ArpOperation) -> Self { + self.packet.header.operation = operation; + self + } + + /// Set an optional payload + pub fn payload(mut self, payload: Bytes) -> Self { + self.packet.payload = payload; + self + } + + /// Return the finished `ArpPacket` + pub fn build(self) -> ArpPacket { + self.packet + } + + /// Return the serialized bytes + pub fn to_bytes(self) -> Bytes { + self.build().to_bytes() + } + + /// Return a reference to the internal `ArpPacket` + pub fn packet(&self) -> &ArpPacket { + &self.packet + } +} diff --git a/nex-packet/src/builder/dhcp.rs b/nex-packet/src/builder/dhcp.rs new file mode 100644 index 0000000..a71910e --- /dev/null +++ b/nex-packet/src/builder/dhcp.rs @@ -0,0 +1,67 @@ +use std::net::Ipv4Addr; + +use bytes::Bytes; +use nex_core::mac::MacAddr; + +use crate::{dhcp::{DhcpHardwareType, DhcpHeader, DhcpOperation, DhcpPacket}, packet::Packet}; + +/// Builder for constructing DHCP packets +#[derive(Debug, Clone)] +pub struct DhcpPacketBuilder { + packet: DhcpPacket, +} + +impl DhcpPacketBuilder { + /// Create an initial builder for DHCP Discover (can be adapted for Request, Offer, etc.) + pub fn new_discover(xid: u32, chaddr: MacAddr) -> Self { + let header = DhcpHeader { + op: DhcpOperation::Request, + htype: DhcpHardwareType::Ethernet, + hlen: 6, + hops: 0, + xid, + secs: 0, + flags: 0x8000, // broadcast flag + ciaddr: Ipv4Addr::UNSPECIFIED, + yiaddr: Ipv4Addr::UNSPECIFIED, + siaddr: Ipv4Addr::UNSPECIFIED, + giaddr: Ipv4Addr::UNSPECIFIED, + chaddr, + chaddr_pad: [0u8; 10], + sname: [0u8; 64], + file: [0u8; 128], + }; + Self { + packet: DhcpPacket { + header, + payload: Bytes::new(), + }, + } + } + + /// Set the payload including options + pub fn payload(mut self, payload: Bytes) -> Self { + self.packet.payload = payload; + self + } + + /// Mutably access the header + pub fn header_mut(&mut self) -> &mut DhcpHeader { + &mut self.packet.header + } + + /// Build and return a `DhcpPacket` + pub fn build(self) -> DhcpPacket { + self.packet + } + + /// Build and return the packet bytes + pub fn to_bytes(self) -> Bytes { + self.packet.to_bytes() + } + + /// Get a reference to the packet + pub fn packet(&self) -> &DhcpPacket { + &self.packet + } +} diff --git a/nex-packet/src/builder/ethernet.rs b/nex-packet/src/builder/ethernet.rs new file mode 100644 index 0000000..c68771f --- /dev/null +++ b/nex-packet/src/builder/ethernet.rs @@ -0,0 +1,64 @@ +use crate::{ethernet::{EtherType, EthernetHeader, EthernetPacket}, packet::Packet}; +use nex_core::mac::MacAddr; +use bytes::Bytes; + +/// Builder for constructing Ethernet packets. +#[derive(Debug, Clone)] +pub struct EthernetPacketBuilder { + packet: EthernetPacket, +} + +impl EthernetPacketBuilder { + /// Create a new builder instance. + pub fn new() -> Self { + Self { + packet: EthernetPacket { + header: EthernetHeader { + destination: MacAddr::zero(), + source: MacAddr::zero(), + ethertype: EtherType::Ipv4, + }, + payload: Bytes::new(), + }, + } + } + + /// Set the destination MAC address. + pub fn destination(mut self, mac: MacAddr) -> Self { + self.packet.header.destination = mac; + self + } + + /// Set the source MAC address. + pub fn source(mut self, mac: MacAddr) -> Self { + self.packet.header.source = mac; + self + } + + /// Set the EtherType (IPv4, ARP, IPv6, etc.). + pub fn ethertype(mut self, ether_type: EtherType) -> Self { + self.packet.header.ethertype = ether_type; + self + } + + /// Set the payload bytes. + pub fn payload(mut self, payload: Bytes) -> Self { + self.packet.payload = payload; + self + } + + /// Consume the builder and produce an `EthernetPacket`. + pub fn build(self) -> EthernetPacket { + self.packet + } + + /// Serialize the packet into raw bytes. + pub fn to_bytes(self) -> Bytes { + self.packet.to_bytes() + } + + /// Retrieve only the header bytes. + pub fn header_bytes(&self) -> Bytes { + self.packet.header.to_bytes() + } +} diff --git a/nex-packet/src/builder/icmp.rs b/nex-packet/src/builder/icmp.rs new file mode 100644 index 0000000..7e18991 --- /dev/null +++ b/nex-packet/src/builder/icmp.rs @@ -0,0 +1,83 @@ +use std::net::Ipv4Addr; + +use bytes::{Bytes, BytesMut, BufMut}; +use crate::{icmp::{self, checksum, IcmpCode, IcmpHeader, IcmpPacket, IcmpType}, packet::Packet}; + +/// Builder for constructing ICMP packets +#[derive(Debug, Clone)] +pub struct IcmpPacketBuilder { + #[allow(unused)] + source: Ipv4Addr, + #[allow(unused)] + destination: Ipv4Addr, + packet: IcmpPacket, +} + +impl IcmpPacketBuilder { + /// Create a new builder with an initial ICMP Type and Code + pub fn new(source: Ipv4Addr, destination: Ipv4Addr) -> Self { + let header = IcmpHeader { + icmp_type: IcmpType::EchoRequest, + icmp_code: icmp::echo_request::IcmpCodes::NoCode, + checksum: 0, + }; + Self { + source, + destination, + packet: IcmpPacket { + header, + payload: Bytes::new(), + }, + } + } + + /// Set the ICMP Type + pub fn icmp_type(mut self, icmp_type: IcmpType) -> Self { + self.packet.header.icmp_type = icmp_type; + self + } + + /// Set the ICMP Code + pub fn icmp_code(mut self, icmp_code: IcmpCode) -> Self { + self.packet.header.icmp_code = icmp_code; + self + } + + /// Set an arbitrary payload + pub fn payload(mut self, payload: Bytes) -> Self { + self.packet.payload = payload; + self + } + + /// For Echo Request/Reply: place identifier and sequence number at the start of the payload + pub fn echo_fields(mut self, identifier: u16, sequence_number: u16) -> Self { + let mut buf = BytesMut::with_capacity(4 + self.packet.payload.len()); + buf.put_u16(identifier); + buf.put_u16(sequence_number); + buf.extend_from_slice(&self.packet.payload); + self.packet.payload = buf.freeze(); + self + } + + pub fn culculate_checksum(mut self) -> Self { + // Calculate the checksum and set it in the header + self.packet.header.checksum = checksum(&self.packet); + self + } + + /// Return an `IcmpPacket` with checksum computed + pub fn build(mut self) -> IcmpPacket { + self.packet.header.checksum = checksum(&self.packet); + self.packet + } + + /// Return the packet bytes with checksum computed + pub fn to_bytes(self) -> Bytes { + self.build().to_bytes() + } + + /// Access the intermediate `IcmpPacket` if needed + pub fn packet(&self) -> &IcmpPacket { + &self.packet + } +} diff --git a/nex-packet/src/builder/icmpv6.rs b/nex-packet/src/builder/icmpv6.rs new file mode 100644 index 0000000..a62f0ca --- /dev/null +++ b/nex-packet/src/builder/icmpv6.rs @@ -0,0 +1,82 @@ +use std::net::Ipv6Addr; + +use bytes::{Bytes, BytesMut, BufMut}; +use crate::{ + icmpv6::{self, checksum, Icmpv6Code, Icmpv6Header, Icmpv6Packet, Icmpv6Type}, + packet::Packet, +}; + +/// Builder for constructing ICMPv6 packets +#[derive(Debug, Clone)] +pub struct Icmpv6PacketBuilder { + source: Ipv6Addr, + destination: Ipv6Addr, + packet: Icmpv6Packet, +} + +impl Icmpv6PacketBuilder { + /// Create a new builder with an initial ICMPv6 Type and Code + pub fn new(source: Ipv6Addr, destination: Ipv6Addr) -> Self { + let header = Icmpv6Header { + icmpv6_type: Icmpv6Type::EchoRequest, + icmpv6_code: icmpv6::echo_request::Icmpv6Codes::NoCode, + checksum: 0, + }; + Self { + source, + destination, + packet: Icmpv6Packet { + header, + payload: Bytes::new(), + }, + } + } + + pub fn icmpv6_type(mut self, icmpv6_type: Icmpv6Type) -> Self { + self.packet.header.icmpv6_type = icmpv6_type; + self + } + + pub fn icmpv6_code(mut self, icmpv6_code: Icmpv6Code) -> Self { + self.packet.header.icmpv6_code = icmpv6_code; + self + } + + /// Set an arbitrary payload + pub fn payload(mut self, payload: Bytes) -> Self { + self.packet.payload = payload; + self + } + + /// For Echo Request/Reply: place identifier and sequence number at the start of the payload + pub fn echo_fields(mut self, identifier: u16, sequence_number: u16) -> Self { + let mut buf = BytesMut::with_capacity(4 + self.packet.payload.len()); + buf.put_u16(identifier); + buf.put_u16(sequence_number); + buf.extend_from_slice(&self.packet.payload); + self.packet.payload = buf.freeze(); + self + } + + pub fn culculate_checksum(mut self) -> Self { + // Calculate the checksum and set it in the header + self.packet.header.checksum = checksum(&self.packet, &self.source, &self.destination); + self + } + + /// Return an `Icmpv6Packet` with checksum computed + pub fn build(mut self) -> Icmpv6Packet { + self.packet.header.checksum = checksum(&self.packet, &self.source, &self.destination); + self.packet + } + + /// Return the packet bytes with checksum computed + pub fn to_bytes(self) -> Bytes { + self.build().to_bytes() + } + + /// Access the intermediate `Icmpv6Packet` if needed + pub fn packet(&self) -> &Icmpv6Packet { + &self.packet + } +} diff --git a/nex-packet/src/builder/ipv4.rs b/nex-packet/src/builder/ipv4.rs new file mode 100644 index 0000000..6bbaaf6 --- /dev/null +++ b/nex-packet/src/builder/ipv4.rs @@ -0,0 +1,122 @@ +use crate::{ipv4::{Ipv4Header, Ipv4Packet, Ipv4OptionPacket, Ipv4OptionType}, ip::IpNextProtocol, packet::Packet}; +use bytes::Bytes; +use nex_core::bitfield::*; +use std::net::Ipv4Addr; + +/// Builder for constructing IPv4 packets. +#[derive(Debug, Clone)] +pub struct Ipv4PacketBuilder { + packet: Ipv4Packet, +} + +impl Ipv4PacketBuilder { + /// Create a new builder. + pub fn new() -> Self { + Self { + packet: Ipv4Packet { + header: Ipv4Header { + version: 4, + header_length: 5, + dscp: 0, + ecn: 0, + total_length: 0, // automatically set during build + identification: rand::random::(), + flags: 0, + fragment_offset: 0, + ttl: 64, + next_level_protocol: IpNextProtocol::new(0), + checksum: 0, + source: Ipv4Addr::UNSPECIFIED, + destination: Ipv4Addr::UNSPECIFIED, + options: vec![], + }, + payload: Bytes::new(), + }, + } + } + + pub fn source(mut self, addr: Ipv4Addr) -> Self { + self.packet.header.source = addr; + self + } + + pub fn destination(mut self, addr: Ipv4Addr) -> Self { + self.packet.header.destination = addr; + self + } + + pub fn ttl(mut self, ttl: u8) -> Self { + self.packet.header.ttl = ttl; + self + } + + pub fn protocol(mut self, proto: IpNextProtocol) -> Self { + self.packet.header.next_level_protocol = proto; + self + } + + pub fn identification(mut self, id: u16) -> Self { + self.packet.header.identification = id; + self + } + + pub fn flags(mut self, flags: u3) -> Self { + self.packet.header.flags = flags; + self + } + + pub fn fragment_offset(mut self, offset: u13be) -> Self { + self.packet.header.fragment_offset = offset; + self + } + + pub fn options(mut self, options: Vec) -> Self { + self.packet.header.options = options; + self.packet.header.header_length = ((20 + self.packet.header.options.iter().map(|opt| { + match opt.header.number { + Ipv4OptionType::EOL | Ipv4OptionType::NOP => 1, + _ => 2 + opt.data.len(), + } + }).sum::() + 3) / 4) as u4; // includes padding + self + } + + pub fn payload(mut self, payload: Bytes) -> Self { + self.packet.payload = payload; + self + } + + pub fn build(mut self) -> Ipv4Packet { + let total_length = self.packet.header_len() + self.packet.payload_len(); + self.packet.header.total_length = total_length as u16be; + self.packet.header.checksum = 0; + self.packet.header.checksum = crate::ipv4::checksum(&self.packet); + self.packet + } + + pub fn to_bytes(self) -> Bytes { + self.build().to_bytes() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ip::IpNextProtocol; + use bytes::Bytes; + use std::net::Ipv4Addr; + + #[test] + fn ipv4_builder_total_length() { + let payload = Bytes::from_static(&[1,2]); + let pkt = Ipv4PacketBuilder::new() + .source(Ipv4Addr::new(1,1,1,1)) + .destination(Ipv4Addr::new(2,2,2,2)) + .protocol(IpNextProtocol::Udp) + .payload(payload.clone()) + .build(); + assert_eq!(pkt.header.total_length, (pkt.header_len() + payload.len()) as u16); + assert_eq!(pkt.payload, payload); + } +} + diff --git a/nex-packet/src/builder/ipv6.rs b/nex-packet/src/builder/ipv6.rs new file mode 100644 index 0000000..78c099b --- /dev/null +++ b/nex-packet/src/builder/ipv6.rs @@ -0,0 +1,119 @@ +use std::net::Ipv6Addr; +use bytes::Bytes; +use crate::{ + ip::IpNextProtocol, + ipv6::{Ipv6Header, Ipv6Packet, Ipv6ExtensionHeader}, + packet::Packet, +}; + +/// Builder for constructing IPv6 packets +#[derive(Debug, Clone)] +pub struct Ipv6PacketBuilder { + packet: Ipv6Packet, +} + +impl Ipv6PacketBuilder { + /// Create a new builder + pub fn new() -> Self { + Self { + packet: Ipv6Packet { + header: Ipv6Header { + version: 6, + traffic_class: 0, + flow_label: 0, + payload_length: 0, + next_header: IpNextProtocol::Reserved, + hop_limit: 64, + source: Ipv6Addr::UNSPECIFIED, + destination: Ipv6Addr::UNSPECIFIED, + }, + extensions: Vec::new(), + payload: Bytes::new(), + }, + } + } + + pub fn source(mut self, addr: Ipv6Addr) -> Self { + self.packet.header.source = addr; + self + } + + pub fn destination(mut self, addr: Ipv6Addr) -> Self { + self.packet.header.destination = addr; + self + } + + pub fn traffic_class(mut self, tc: u8) -> Self { + self.packet.header.traffic_class = tc; + self + } + + pub fn flow_label(mut self, label: u32) -> Self { + self.packet.header.flow_label = label & 0x000FFFFF; + self + } + + pub fn hop_limit(mut self, limit: u8) -> Self { + self.packet.header.hop_limit = limit; + self + } + + pub fn next_header(mut self, proto: IpNextProtocol) -> Self { + self.packet.header.next_header = proto; + self + } + + pub fn extension(mut self, ext: Ipv6ExtensionHeader) -> Self { + self.packet.extensions.push(ext); + self + } + + pub fn extensions(mut self, list: Vec) -> Self { + self.packet.extensions = list; + self + } + + pub fn payload(mut self, payload: Bytes) -> Self { + self.packet.payload = payload; + self + } + + /// Build the packet and return it + pub fn build(mut self) -> Ipv6Packet { + let ext_len: usize = self.packet.extensions.iter().map(|e| e.len()).sum(); + self.packet.header.payload_length = (ext_len + self.packet.payload.len()) as u16; + self.packet + } + + /// Serialize the packet into bytes + pub fn to_bytes(self) -> Bytes { + self.build().to_bytes() + } + + /// Get only the header bytes + pub fn header_bytes(&self) -> Bytes { + self.packet.header().slice(..) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ip::IpNextProtocol; + use bytes::Bytes; + use std::net::Ipv6Addr; + + #[test] + fn ipv6_builder_payload_len() { + let payload = Bytes::from_static(&[1,2,3,4]); + let pkt = Ipv6PacketBuilder::new() + .source(Ipv6Addr::LOCALHOST) + .destination(Ipv6Addr::LOCALHOST) + .next_header(IpNextProtocol::Tcp) + .payload(payload.clone()) + .build(); + assert_eq!(pkt.header.payload_length, payload.len() as u16); + assert_eq!(pkt.payload, payload); + } +} + diff --git a/nex-packet-builder/src/lib.rs b/nex-packet/src/builder/mod.rs similarity index 58% rename from nex-packet-builder/src/lib.rs rename to nex-packet/src/builder/mod.rs index a196dc0..8f70f68 100644 --- a/nex-packet-builder/src/lib.rs +++ b/nex-packet/src/builder/mod.rs @@ -1,14 +1,10 @@ -//! Utilities designed to work with packets through high-level APIs. - pub mod arp; -pub mod builder; pub mod dhcp; pub mod ethernet; -pub mod icmp; -pub mod icmpv6; pub mod ipv4; pub mod ipv6; -pub mod ndp; +pub mod icmp; +pub mod icmpv6; pub mod tcp; pub mod udp; -pub mod util; +pub mod ndp; diff --git a/nex-packet/src/builder/ndp.rs b/nex-packet/src/builder/ndp.rs new file mode 100644 index 0000000..cecb80a --- /dev/null +++ b/nex-packet/src/builder/ndp.rs @@ -0,0 +1,84 @@ +use bytes::Bytes; +use std::net::Ipv6Addr; +use nex_core::mac::MacAddr; +use crate::icmpv6::{ + self, checksum, Icmpv6Header, Icmpv6Packet, Icmpv6Type +}; +use crate::icmpv6::ndp::{ + NdpOptionPacket, NdpOptionTypes, NeighborSolicitPacket +}; +use crate::packet::Packet; + +/// Length rounded up to an 8-byte multiple (for option length) +fn octets_len(len: usize) -> u8 { + ((len + 7) / 8) as u8 +} + +/// Builder for ICMPv6 Neighbor Solicitation packets +#[derive(Clone, Debug)] +pub struct NdpPacketBuilder { + /// Source MAC address + pub src_mac: MacAddr, + /// Destination MAC address + pub dst_mac: MacAddr, + /// Source IPv6 address + pub src_ip: Ipv6Addr, + /// Target (destination) IPv6 address + pub dst_ip: Ipv6Addr, +} + +impl NdpPacketBuilder { + /// Create a new builder + pub fn new(src_mac: MacAddr, src_ip: Ipv6Addr, dst_ip: Ipv6Addr) -> Self { + Self { + src_mac, + dst_mac: MacAddr::broadcast(), + src_ip, + dst_ip, + } + } + + /// Override the destination MAC + pub fn dst_mac(mut self, dst_mac: MacAddr) -> Self { + self.dst_mac = dst_mac; + self + } + + /// Build the Neighbor Solicitation packet + pub fn build(&self) -> Icmpv6Packet { + // Build the MAC address option + let mac_bytes = self.src_mac.octets(); + let opt_payload = Bytes::copy_from_slice(&mac_bytes); + let opt_len = octets_len(mac_bytes.len()); + + let options = vec![NdpOptionPacket { + option_type: NdpOptionTypes::SourceLLAddr, + length: opt_len, + payload: opt_payload, + }]; + + let packet = NeighborSolicitPacket { + header: Icmpv6Header { + icmpv6_type: Icmpv6Type::NeighborSolicitation, + icmpv6_code: icmpv6::ndp::Icmpv6Codes::NoCode, + checksum: 0, + }, + reserved: 0, + target_addr: self.dst_ip, + options, + payload: Bytes::new(), + }; + + // Build an Icmpv6Packet and calculate the checksum + let mut icmp_packet = Icmpv6Packet::from_bytes(packet.to_bytes()) + .expect("Failed to create Icmpv6Packet from NeighborSolicitPacket"); + + icmp_packet.header.checksum = checksum(&icmp_packet, &self.src_ip, &self.dst_ip); + icmp_packet + } + + /// Get the packet as bytes + pub fn to_bytes(&self) -> Bytes { + self.build().to_bytes() + } +} diff --git a/nex-packet/src/builder/tcp.rs b/nex-packet/src/builder/tcp.rs new file mode 100644 index 0000000..639e2c9 --- /dev/null +++ b/nex-packet/src/builder/tcp.rs @@ -0,0 +1,126 @@ +use std::net::IpAddr; + +use crate::tcp::{TcpPacket, TcpHeader, TcpOptionPacket}; +use bytes::Bytes; +use crate::packet::Packet; + +/// Builder for constructing TCP packets +#[derive(Debug, Clone)] +pub struct TcpPacketBuilder { + packet: TcpPacket, +} + +impl TcpPacketBuilder { + /// Create a new builder + pub fn new() -> Self { + Self { + packet: TcpPacket { + header: TcpHeader { + source: 0, + destination: 0, + sequence: 0, + acknowledgement: 0, + data_offset: 5.into(), // default: header 20 bytes (5 * 4) + reserved: 0.into(), + flags: 0, + window: 0xffff, + checksum: 0, + urgent_ptr: 0, + options: Vec::new(), + }, + payload: Bytes::new(), + }, + } + } + + pub fn source(mut self, port: u16) -> Self { + self.packet.header.source = port.into(); + self + } + + pub fn destination(mut self, port: u16) -> Self { + self.packet.header.destination = port.into(); + self + } + + pub fn sequence(mut self, seq: u32) -> Self { + self.packet.header.sequence = seq.into(); + self + } + + pub fn acknowledgement(mut self, ack: u32) -> Self { + self.packet.header.acknowledgement = ack.into(); + self + } + + pub fn flags(mut self, flags: u8) -> Self { + self.packet.header.flags = flags; + self + } + + pub fn window(mut self, size: u16) -> Self { + self.packet.header.window = size.into(); + self + } + + pub fn urgent_ptr(mut self, ptr: u16) -> Self { + self.packet.header.urgent_ptr = ptr.into(); + self + } + + pub fn options(mut self, options: Vec) -> Self { + self.packet.header.options = options; + // Recalculate data offset (header length is in 4-byte units) + let base_len = 20; // base header + let opt_len: usize = self.packet.header.options.iter().map(|opt| opt.length() as usize).sum(); + let total = base_len + opt_len; + self.packet.header.data_offset = ((total + 3) / 4) as u8; // round up + self + } + + pub fn payload(mut self, data: Bytes) -> Self { + self.packet.payload = data; + self + } + + pub fn culculate_checksum(mut self, src_ip: &IpAddr, dst_ip: &IpAddr) -> Self { + // Calculate the checksum and set it in the header + self.packet.header.checksum = crate::tcp::checksum(&self.packet, src_ip, dst_ip); + self + } + pub fn build(self) -> TcpPacket { + self.packet + } + + pub fn to_bytes(self) -> Bytes { + self.packet.to_bytes() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tcp::TcpFlags; + use bytes::Bytes; + + #[test] + fn tcp_builder_basic() { + let pkt = TcpPacketBuilder::new() + .source(1234) + .destination(80) + .sequence(1) + .acknowledgement(2) + .flags(TcpFlags::SYN) + .window(1024) + .urgent_ptr(0) + .payload(Bytes::from_static(b"abc")) + .build(); + assert_eq!(pkt.header.source, 1234); + assert_eq!(pkt.header.destination, 80); + assert_eq!(pkt.header.sequence, 1); + assert_eq!(pkt.header.acknowledgement, 2); + assert_eq!(pkt.header.flags, TcpFlags::SYN); + assert_eq!(pkt.payload, Bytes::from_static(b"abc")); + } +} + diff --git a/nex-packet/src/builder/udp.rs b/nex-packet/src/builder/udp.rs new file mode 100644 index 0000000..dddb1e2 --- /dev/null +++ b/nex-packet/src/builder/udp.rs @@ -0,0 +1,96 @@ +use std::net::IpAddr; + +use crate::udp::{UdpPacket, UdpHeader, UDP_HEADER_LEN}; +use crate::packet::Packet; +use bytes::Bytes; + +/// Builder for constructing UDP packets +#[derive(Debug, Clone)] +pub struct UdpPacketBuilder { + packet: UdpPacket, +} + +impl UdpPacketBuilder { + /// Create a new builder + pub fn new() -> Self { + Self { + packet: UdpPacket { + header: UdpHeader { + source: 0, + destination: 0, + length: 0, // automatically set during build + checksum: 0, + }, + payload: Bytes::new(), + }, + } + } + + /// Set the source port + pub fn source(mut self, port: u16) -> Self { + self.packet.header.source = port.into(); + self + } + + /// Set the destination port + pub fn destination(mut self, port: u16) -> Self { + self.packet.header.destination = port.into(); + self + } + + /// Set the checksum (optional) + pub fn checksum(mut self, checksum: u16) -> Self { + self.packet.header.checksum = checksum.into(); + self + } + + /// Set the payload + pub fn payload(mut self, data: Bytes) -> Self { + self.packet.payload = data; + self + } + + pub fn culculate_checksum(mut self, src_ip: &IpAddr, dst_ip: &IpAddr) -> Self { + // Calculate the checksum and set it in the header + self.packet.header.checksum = crate::udp::checksum(&self.packet, src_ip, dst_ip); + self + } + + /// Build the packet + pub fn build(mut self) -> UdpPacket { + // Automatically compute the length + let total_len = UDP_HEADER_LEN + self.packet.payload.len(); + self.packet.header.length = (total_len as u16).into(); + self.packet + } + + /// Serialize the packet into bytes + pub fn to_bytes(self) -> Bytes { + self.build().to_bytes() + } + + /// Retrieve only the header bytes + pub fn header_bytes(&self) -> Bytes { + let mut pkt = self.clone().packet; + pkt.header.length = (UDP_HEADER_LEN + pkt.payload.len()) as u16; + pkt.header().clone() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + + #[test] + fn udp_builder_sets_length() { + let pkt = UdpPacketBuilder::new() + .source(1) + .destination(2) + .payload(Bytes::from_static(&[1,2,3])) + .build(); + assert_eq!(pkt.header.length, (UDP_HEADER_LEN + 3) as u16); + assert_eq!(pkt.payload, Bytes::from_static(&[1,2,3])); + } +} + diff --git a/nex-packet/src/dhcp.rs b/nex-packet/src/dhcp.rs index 2b596ab..ea004ef 100644 --- a/nex-packet/src/dhcp.rs +++ b/nex-packet/src/dhcp.rs @@ -1,17 +1,14 @@ -use crate::PrimitiveValues; - -use alloc::vec::Vec; - +use bytes::{Buf, BufMut, Bytes, BytesMut}; use nex_core::mac::MacAddr; -use nex_macro::packet; -use nex_macro_helper::types::*; use std::net::Ipv4Addr; +use crate::packet::Packet; + /// Minimum size of an DHCP packet. /// Options field is not included in this size. pub const DHCP_MIN_PACKET_SIZE: usize = 236; -/// Represents an DHCP operation. +// DHCP Operation Codes #[repr(u8)] #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum DhcpOperation { @@ -21,29 +18,24 @@ pub enum DhcpOperation { } impl DhcpOperation { - /// Constructs a new DhcpOperation from u8. - pub fn new(value: u8) -> DhcpOperation { + pub fn new(value: u8) -> Self { match value { - 1 => DhcpOperation::Request, - 2 => DhcpOperation::Reply, - _ => DhcpOperation::Unknown(value), + 1 => Self::Request, + 2 => Self::Reply, + other => Self::Unknown(other), } } -} -impl PrimitiveValues for DhcpOperation { - type T = (u8,); - fn to_primitive_values(&self) -> (u8,) { + pub fn value(&self) -> u8 { match self { - &DhcpOperation::Request => (1,), - &DhcpOperation::Reply => (2,), - &DhcpOperation::Unknown(n) => (n,), + Self::Request => 1, + Self::Reply => 2, + Self::Unknown(v) => *v, } } } -/// Represents the Dhcp hardware types. -#[allow(non_snake_case)] +// DHCP Hardware Types #[repr(u8)] #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum DhcpHardwareType { @@ -62,10 +54,10 @@ pub enum DhcpHardwareType { UltraLink = 13, SMDS = 14, FrameRelay = 15, - AsynchronousTransmissionMode = 16, + ATM = 16, HDLC = 17, FibreChannel = 18, - AsynchronousTransmissionMode1 = 19, + ATM1 = 19, PropPointToPointSerial = 20, PPP = 21, SoftwareLoopback = 24, @@ -77,7 +69,7 @@ pub enum DhcpHardwareType { DS3 = 30, SIP = 31, FrameRelayInterconnect = 32, - AsynchronousTransmissionMode2 = 33, + ATM2 = 33, MILSTD188220 = 34, Metricom = 35, IEEE1394 = 37, @@ -89,7 +81,7 @@ pub enum DhcpHardwareType { ARPSec = 44, IPsecTunnel = 45, InfiniBand = 47, - TIA102Project25CommonAirInterface = 48, + TIA102CAI = 48, WiegandInterface = 49, PureIP = 50, HWExp1 = 51, @@ -112,174 +104,336 @@ pub enum DhcpHardwareType { } impl DhcpHardwareType { - /// Constructs a new DhcpHardwareType from u8 - pub fn new(n: u8) -> DhcpHardwareType { - match n { - 1 => DhcpHardwareType::Ethernet, - 2 => DhcpHardwareType::ExperimentalEthernet, - 3 => DhcpHardwareType::AmateurRadioAX25, - 4 => DhcpHardwareType::ProteonProNETTokenRing, - 5 => DhcpHardwareType::Chaos, - 6 => DhcpHardwareType::IEEE802Networks, - 7 => DhcpHardwareType::ARCNET, - 8 => DhcpHardwareType::Hyperchannel, - 9 => DhcpHardwareType::Lanstar, - 10 => DhcpHardwareType::AutonetShortAddress, - 11 => DhcpHardwareType::LocalTalk, - 12 => DhcpHardwareType::LocalNet, - 13 => DhcpHardwareType::UltraLink, - 14 => DhcpHardwareType::SMDS, - 15 => DhcpHardwareType::FrameRelay, - 16 => DhcpHardwareType::AsynchronousTransmissionMode, - 17 => DhcpHardwareType::HDLC, - 18 => DhcpHardwareType::FibreChannel, - 19 => DhcpHardwareType::AsynchronousTransmissionMode1, - 20 => DhcpHardwareType::PropPointToPointSerial, - 21 => DhcpHardwareType::PPP, - 24 => DhcpHardwareType::SoftwareLoopback, - 25 => DhcpHardwareType::EON, - 26 => DhcpHardwareType::Ethernet3MB, - 27 => DhcpHardwareType::NSIP, - 28 => DhcpHardwareType::Slip, - 29 => DhcpHardwareType::ULTRALink, - 30 => DhcpHardwareType::DS3, - 31 => DhcpHardwareType::SIP, - 32 => DhcpHardwareType::FrameRelayInterconnect, - 33 => DhcpHardwareType::AsynchronousTransmissionMode2, - 34 => DhcpHardwareType::MILSTD188220, - 35 => DhcpHardwareType::Metricom, - 37 => DhcpHardwareType::IEEE1394, - 39 => DhcpHardwareType::MAPOS, - 40 => DhcpHardwareType::Twinaxial, - 41 => DhcpHardwareType::EUI64, - 42 => DhcpHardwareType::HIPARP, - 43 => DhcpHardwareType::IPandARPoverISO7816_3, - 44 => DhcpHardwareType::ARPSec, - 45 => DhcpHardwareType::IPsecTunnel, - 47 => DhcpHardwareType::InfiniBand, - 48 => DhcpHardwareType::TIA102Project25CommonAirInterface, - 49 => DhcpHardwareType::WiegandInterface, - 50 => DhcpHardwareType::PureIP, - 51 => DhcpHardwareType::HWExp1, - 52 => DhcpHardwareType::HFI, - 53 => DhcpHardwareType::HWExp2, - 54 => DhcpHardwareType::AEthernet, - 55 => DhcpHardwareType::HWExp3, - 56 => DhcpHardwareType::IPsecTransport, - 57 => DhcpHardwareType::SDLCRadio, - 58 => DhcpHardwareType::SDLCMultipoint, - 59 => DhcpHardwareType::IWARP, - 61 => DhcpHardwareType::SixLoWPAN, - 62 => DhcpHardwareType::VLAN, - 63 => DhcpHardwareType::ProviderBridging, - 64 => DhcpHardwareType::IEEE802154, - 65 => DhcpHardwareType::MAPOSinIPv4, - 66 => DhcpHardwareType::MAPOSinIPv6, - 70 => DhcpHardwareType::IEEE802154NonASKPHY, - _ => DhcpHardwareType::Unknown(n), + pub fn new(value: u8) -> Self { + use DhcpHardwareType::*; + match value { + 1 => Ethernet, + 2 => ExperimentalEthernet, + 3 => AmateurRadioAX25, + 4 => ProteonProNETTokenRing, + 5 => Chaos, + 6 => IEEE802Networks, + 7 => ARCNET, + 8 => Hyperchannel, + 9 => Lanstar, + 10 => AutonetShortAddress, + 11 => LocalTalk, + 12 => LocalNet, + 13 => UltraLink, + 14 => SMDS, + 15 => FrameRelay, + 16 => ATM, + 17 => HDLC, + 18 => FibreChannel, + 19 => ATM1, + 20 => PropPointToPointSerial, + 21 => PPP, + 24 => SoftwareLoopback, + 25 => EON, + 26 => Ethernet3MB, + 27 => NSIP, + 28 => Slip, + 29 => ULTRALink, + 30 => DS3, + 31 => SIP, + 32 => FrameRelayInterconnect, + 33 => ATM2, + 34 => MILSTD188220, + 35 => Metricom, + 37 => IEEE1394, + 39 => MAPOS, + 40 => Twinaxial, + 41 => EUI64, + 42 => HIPARP, + 43 => IPandARPoverISO7816_3, + 44 => ARPSec, + 45 => IPsecTunnel, + 47 => InfiniBand, + 48 => TIA102CAI, + 49 => WiegandInterface, + 50 => PureIP, + 51 => HWExp1, + 52 => HFI, + 53 => HWExp2, + 54 => AEthernet, + 55 => HWExp3, + 56 => IPsecTransport, + 57 => SDLCRadio, + 58 => SDLCMultipoint, + 59 => IWARP, + 61 => SixLoWPAN, + 62 => VLAN, + 63 => ProviderBridging, + 64 => IEEE802154, + 65 => MAPOSinIPv4, + 66 => MAPOSinIPv6, + 70 => IEEE802154NonASKPHY, + other => Unknown(other), } } -} -impl PrimitiveValues for DhcpHardwareType { - type T = (u8,); - fn to_primitive_values(&self) -> (u8,) { + pub fn value(&self) -> u8 { match self { - &DhcpHardwareType::Ethernet => (1,), - &DhcpHardwareType::ExperimentalEthernet => (2,), - &DhcpHardwareType::AmateurRadioAX25 => (3,), - &DhcpHardwareType::ProteonProNETTokenRing => (4,), - &DhcpHardwareType::Chaos => (5,), - &DhcpHardwareType::IEEE802Networks => (6,), - &DhcpHardwareType::ARCNET => (7,), - &DhcpHardwareType::Hyperchannel => (8,), - &DhcpHardwareType::Lanstar => (9,), - &DhcpHardwareType::AutonetShortAddress => (10,), - &DhcpHardwareType::LocalTalk => (11,), - &DhcpHardwareType::LocalNet => (12,), - &DhcpHardwareType::UltraLink => (13,), - &DhcpHardwareType::SMDS => (14,), - &DhcpHardwareType::FrameRelay => (15,), - &DhcpHardwareType::AsynchronousTransmissionMode => (16,), - &DhcpHardwareType::HDLC => (17,), - &DhcpHardwareType::FibreChannel => (18,), - &DhcpHardwareType::AsynchronousTransmissionMode1 => (19,), - &DhcpHardwareType::PropPointToPointSerial => (20,), - &DhcpHardwareType::PPP => (21,), - &DhcpHardwareType::SoftwareLoopback => (24,), - &DhcpHardwareType::EON => (25,), - &DhcpHardwareType::Ethernet3MB => (26,), - &DhcpHardwareType::NSIP => (27,), - &DhcpHardwareType::Slip => (28,), - &DhcpHardwareType::ULTRALink => (29,), - &DhcpHardwareType::DS3 => (30,), - &DhcpHardwareType::SIP => (31,), - &DhcpHardwareType::FrameRelayInterconnect => (32,), - &DhcpHardwareType::AsynchronousTransmissionMode2 => (33,), - &DhcpHardwareType::MILSTD188220 => (34,), - &DhcpHardwareType::Metricom => (35,), - &DhcpHardwareType::IEEE1394 => (37,), - &DhcpHardwareType::MAPOS => (39,), - &DhcpHardwareType::Twinaxial => (40,), - &DhcpHardwareType::EUI64 => (41,), - &DhcpHardwareType::HIPARP => (42,), - &DhcpHardwareType::IPandARPoverISO7816_3 => (43,), - &DhcpHardwareType::ARPSec => (44,), - &DhcpHardwareType::IPsecTunnel => (45,), - &DhcpHardwareType::InfiniBand => (47,), - &DhcpHardwareType::TIA102Project25CommonAirInterface => (48,), - &DhcpHardwareType::WiegandInterface => (49,), - &DhcpHardwareType::PureIP => (50,), - &DhcpHardwareType::HWExp1 => (51,), - &DhcpHardwareType::HFI => (52,), - &DhcpHardwareType::HWExp2 => (53,), - &DhcpHardwareType::AEthernet => (54,), - &DhcpHardwareType::HWExp3 => (55,), - &DhcpHardwareType::IPsecTransport => (56,), - &DhcpHardwareType::SDLCRadio => (57,), - &DhcpHardwareType::SDLCMultipoint => (58,), - &DhcpHardwareType::IWARP => (59,), - &DhcpHardwareType::SixLoWPAN => (61,), - &DhcpHardwareType::VLAN => (62,), - &DhcpHardwareType::ProviderBridging => (63,), - &DhcpHardwareType::IEEE802154 => (64,), - &DhcpHardwareType::MAPOSinIPv4 => (65,), - &DhcpHardwareType::MAPOSinIPv6 => (66,), - &DhcpHardwareType::IEEE802154NonASKPHY => (70,), - &DhcpHardwareType::Unknown(n) => (n,), + DhcpHardwareType::Ethernet => 1, + DhcpHardwareType::ExperimentalEthernet => 2, + DhcpHardwareType::AmateurRadioAX25 => 3, + DhcpHardwareType::ProteonProNETTokenRing => 4, + DhcpHardwareType::Chaos => 5, + DhcpHardwareType::IEEE802Networks => 6, + DhcpHardwareType::ARCNET => 7, + DhcpHardwareType::Hyperchannel => 8, + DhcpHardwareType::Lanstar => 9, + DhcpHardwareType::AutonetShortAddress => 10, + DhcpHardwareType::LocalTalk => 11, + DhcpHardwareType::LocalNet => 12, + DhcpHardwareType::UltraLink => 13, + DhcpHardwareType::SMDS => 14, + DhcpHardwareType::FrameRelay => 15, + DhcpHardwareType::ATM => 16, + DhcpHardwareType::HDLC => 17, + DhcpHardwareType::FibreChannel => 18, + DhcpHardwareType::ATM1 => 19, + DhcpHardwareType::PropPointToPointSerial => 20, + DhcpHardwareType::PPP => 21, + DhcpHardwareType::SoftwareLoopback => 24, + DhcpHardwareType::EON => 25, + DhcpHardwareType::Ethernet3MB => 26, + DhcpHardwareType::NSIP => 27, + DhcpHardwareType::Slip => 28, + DhcpHardwareType::ULTRALink => 29, + DhcpHardwareType::DS3 => 30, + DhcpHardwareType::SIP => 31, + DhcpHardwareType::FrameRelayInterconnect => 32, + DhcpHardwareType::ATM2 => 33, + DhcpHardwareType::MILSTD188220 => 34, + DhcpHardwareType::Metricom => 35, + DhcpHardwareType::IEEE1394 => 37, + DhcpHardwareType::MAPOS => 39, + DhcpHardwareType::Twinaxial => 40, + DhcpHardwareType::EUI64 => 41, + DhcpHardwareType::HIPARP => 42, + DhcpHardwareType::IPandARPoverISO7816_3 => 43, + DhcpHardwareType::ARPSec => 44, + DhcpHardwareType::IPsecTunnel => 45, + DhcpHardwareType::InfiniBand => 47, + DhcpHardwareType::TIA102CAI => 48, + DhcpHardwareType::WiegandInterface => 49, + DhcpHardwareType::PureIP => 50, + DhcpHardwareType::HWExp1 => 51, + DhcpHardwareType::HFI => 52, + DhcpHardwareType::HWExp2 => 53, + DhcpHardwareType::AEthernet => 54, + DhcpHardwareType::HWExp3 => 55, + DhcpHardwareType::IPsecTransport => 56, + DhcpHardwareType::SDLCRadio => 57, + DhcpHardwareType::SDLCMultipoint => 58, + DhcpHardwareType::IWARP => 59, + DhcpHardwareType::SixLoWPAN => 61, + DhcpHardwareType::VLAN => 62, + DhcpHardwareType::ProviderBridging => 63, + DhcpHardwareType::IEEE802154 => 64, + DhcpHardwareType::MAPOSinIPv4 => 65, + DhcpHardwareType::MAPOSinIPv6 => 66, + DhcpHardwareType::IEEE802154NonASKPHY => 70, + DhcpHardwareType::Unknown(n) => *n, } } } -/// Represents an DHCP Packet. -#[packet] -#[allow(non_snake_case)] -pub struct Dhcp { - #[construct_with(u8)] +#[derive(Clone, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct DhcpHeader { pub op: DhcpOperation, - #[construct_with(u8)] pub htype: DhcpHardwareType, pub hlen: u8, pub hops: u8, - pub xid: u32be, - pub secs: u16be, - pub flags: u16be, - #[construct_with(u8, u8, u8, u8)] + pub xid: u32, + pub secs: u16, + pub flags: u16, pub ciaddr: Ipv4Addr, - #[construct_with(u8, u8, u8, u8)] pub yiaddr: Ipv4Addr, - #[construct_with(u8, u8, u8, u8)] pub siaddr: Ipv4Addr, - #[construct_with(u8, u8, u8, u8)] pub giaddr: Ipv4Addr, - #[construct_with(u8, u8, u8, u8, u8, u8)] pub chaddr: MacAddr, - #[length = "10"] - pub chaddr_pad: Vec, - #[length = "64"] - pub sname: Vec, - #[length = "128"] - pub file: Vec, - #[payload] - pub options: Vec, + pub chaddr_pad: [u8; 10], + pub sname: [u8; 64], + pub file: [u8; 128], +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct DhcpPacket { + pub header: DhcpHeader, + pub payload: Bytes, +} + +impl Packet for DhcpPacket { + type Header = DhcpHeader; + + fn from_buf(mut bytes: &[u8]) -> Option { + if bytes.len() < DHCP_MIN_PACKET_SIZE { + return None; + } + + let op = DhcpOperation::new(bytes.get_u8()); + let htype = DhcpHardwareType::new(bytes.get_u8()); + let hlen = bytes.get_u8(); + let hops = bytes.get_u8(); + let xid = bytes.get_u32(); + let secs = bytes.get_u16(); + let flags = bytes.get_u16(); + + let ciaddr = Ipv4Addr::from(bytes.get_u32()); + let yiaddr = Ipv4Addr::from(bytes.get_u32()); + let siaddr = Ipv4Addr::from(bytes.get_u32()); + let giaddr = Ipv4Addr::from(bytes.get_u32()); + + let mut chaddr = [0u8; 6]; + bytes.copy_to_slice(&mut chaddr); + let chaddr = MacAddr::from_octets(chaddr); + + let mut chaddr_pad = [0u8; 10]; + bytes.copy_to_slice(&mut chaddr_pad); + + let mut sname = [0u8; 64]; + bytes.copy_to_slice(&mut sname); + + let mut file = [0u8; 128]; + bytes.copy_to_slice(&mut file); + + let header = DhcpHeader { + op, + htype, + hlen, + hops, + xid, + secs, + flags, + ciaddr, + yiaddr, + siaddr, + giaddr, + chaddr, + chaddr_pad, + sname, + file, + }; + + Some(Self { + header, + payload: Bytes::copy_from_slice(bytes), + }) + } + + fn from_bytes(bytes: Bytes) -> Option { + Self::from_buf(&bytes) + } + + fn to_bytes(&self) -> Bytes { + let mut buf = BytesMut::with_capacity(DHCP_MIN_PACKET_SIZE + self.payload.len()); + + buf.put_u8(self.header.op.value()); + buf.put_u8(self.header.htype.value()); + buf.put_u8(self.header.hlen); + buf.put_u8(self.header.hops); + buf.put_u32(self.header.xid); + buf.put_u16(self.header.secs); + buf.put_u16(self.header.flags); + + buf.put_slice(&self.header.ciaddr.octets()); + buf.put_slice(&self.header.yiaddr.octets()); + buf.put_slice(&self.header.siaddr.octets()); + buf.put_slice(&self.header.giaddr.octets()); + + buf.put_slice(&self.header.chaddr.octets()); + buf.put_slice(&self.header.chaddr_pad); + buf.put_slice(&self.header.sname); + buf.put_slice(&self.header.file); + + buf.extend_from_slice(&self.payload); + + buf.freeze() + } + fn header(&self) -> Bytes { + let mut buf = BytesMut::with_capacity(DHCP_MIN_PACKET_SIZE); + + buf.put_u8(self.header.op.value()); + buf.put_u8(self.header.htype.value()); + buf.put_u8(self.header.hlen); + buf.put_u8(self.header.hops); + buf.put_u32(self.header.xid); + buf.put_u16(self.header.secs); + buf.put_u16(self.header.flags); + + buf.put_slice(&self.header.ciaddr.octets()); + buf.put_slice(&self.header.yiaddr.octets()); + buf.put_slice(&self.header.siaddr.octets()); + buf.put_slice(&self.header.giaddr.octets()); + + buf.put_slice(&self.header.chaddr.octets()); + buf.put_slice(&self.header.chaddr_pad); + buf.put_slice(&self.header.sname); + buf.put_slice(&self.header.file); + + buf.freeze() + } + + fn payload(&self) -> Bytes { + self.payload.clone() + } + + fn header_len(&self) -> usize { + DHCP_MIN_PACKET_SIZE + } + + fn payload_len(&self) -> usize { + self.payload.len() + } + + fn total_len(&self) -> usize { + self.header_len() + self.payload_len() + } + + fn into_parts(self) -> (Self::Header, Bytes) { + (self.header, self.payload) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use nex_core::mac::MacAddr; + + #[test] + fn test_dhcp_packet_from_bytes_and_to_bytes() { + let raw = { + let mut buf = BytesMut::with_capacity(DHCP_MIN_PACKET_SIZE); + buf.put_u8(1); // op: Request + buf.put_u8(1); // htype: Ethernet + buf.put_u8(6); // hlen + buf.put_u8(0); // hops + buf.put_u32(0x12345678); // xid + buf.put_u16(0); // secs + buf.put_u16(0); // flags + buf.put_slice(&[0, 0, 0, 0]); // ciaddr + buf.put_slice(&[0, 0, 0, 0]); // yiaddr + buf.put_slice(&[0, 0, 0, 0]); // siaddr + buf.put_slice(&[0, 0, 0, 0]); // giaddr + buf.put_slice(&[0x00, 0x11, 0x22, 0x33, 0x44, 0x55]); // chaddr + buf.extend_from_slice(&[0u8; 10]); // chaddr_pad + buf.extend_from_slice(&[0u8; 64]); // sname + buf.extend_from_slice(&[0u8; 128]); // file + buf.freeze() + }; + + let packet = DhcpPacket::from_bytes(raw.clone()).expect("Failed to parse DHCP packet"); + + assert_eq!(packet.header.op, DhcpOperation::Request); + assert_eq!(packet.header.htype, DhcpHardwareType::Ethernet); + assert_eq!(packet.header.hlen, 6); + assert_eq!(packet.header.xid, 0x12345678); + assert_eq!(packet.header.chaddr, MacAddr::new(0x00, 0x11, 0x22, 0x33, 0x44, 0x55)); + + let rebuilt = packet.to_bytes(); + assert_eq!(rebuilt, raw); + } } diff --git a/nex-packet/src/dns.rs b/nex-packet/src/dns.rs index 5bf74ea..4c06910 100644 --- a/nex-packet/src/dns.rs +++ b/nex-packet/src/dns.rs @@ -1,360 +1,435 @@ -use alloc::string::String; -use alloc::vec::Vec; -use core::{fmt, str}; -use nex_macro::packet; -use nex_macro_helper::packet::{Packet, PacketSize, PrimitiveValues}; -use nex_macro_helper::types::{u1, u16be, u32be, u4}; +use core::str; use std::str::Utf8Error; +use bytes::{BufMut, Bytes, BytesMut}; +use nex_core::bitfield::{u1, u16be, u32be}; +use crate::packet::Packet; /// Represents an DNS operation. /// These identifiers correspond to DNS resource record classes. /// -#[allow(non_snake_case)] -#[allow(non_upper_case_globals)] -pub mod DnsClasses { - use super::DnsClass; - - /// Internet - pub const IN: DnsClass = DnsClass(1); - /// CSNET (Unassigned) - pub const CS: DnsClass = DnsClass(2); - /// Chaos - pub const CH: DnsClass = DnsClass(3); - /// Hesiod - pub const HS: DnsClass = DnsClass(4); -} - -/// Represents a DNS class. +#[repr(u16)] #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct DnsClass(pub u16); +pub enum DnsClass { + IN = 1, // Internet + CS = 2, // CSNET (Obsolete) + CH = 3, // Chaos + HS = 4, // Hesiod + Unknown(u16), +} impl DnsClass { pub fn new(value: u16) -> Self { - Self(value) + match value { + 1 => DnsClass::IN, + 2 => DnsClass::CS, + 3 => DnsClass::CH, + 4 => DnsClass::HS, + v => DnsClass::Unknown(v), + } } -} -impl PrimitiveValues for DnsClass { - type T = (u16,); - - fn to_primitive_values(&self) -> (u16,) { - (self.0,) + pub fn value(&self) -> u16 { + match self { + DnsClass::IN => 1, + DnsClass::CS => 2, + DnsClass::CH => 3, + DnsClass::HS => 4, + DnsClass::Unknown(v) => *v, + } } -} -impl fmt::Display for DnsClass { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "{}", - match self { - &DnsClasses::IN => "IN", // 1 - &DnsClasses::CS => "CS", // 2 - &DnsClasses::CH => "CH", // 3 - &DnsClasses::HS => "HS", // 4 - _ => "unknown", - } - ) + pub fn name(&self) -> &'static str { + match self { + DnsClass::IN => "IN", + DnsClass::CS => "CS", + DnsClass::CH => "CH", + DnsClass::HS => "HS", + DnsClass::Unknown(_) => "Unknown", + } } } -/// Represents an DNS types. -/// These identifiers are used to specify the type of DNS query or response. -/// -#[allow(non_snake_case)] -#[allow(non_upper_case_globals)] -pub mod DnsTypes { - use super::DnsType; - - pub const A: DnsType = DnsType(1); - pub const NS: DnsType = DnsType(2); - pub const MD: DnsType = DnsType(3); - pub const MF: DnsType = DnsType(4); - pub const CNAME: DnsType = DnsType(5); - pub const SOA: DnsType = DnsType(6); - pub const MB: DnsType = DnsType(7); - pub const MG: DnsType = DnsType(8); - pub const MR: DnsType = DnsType(9); - pub const NULL: DnsType = DnsType(10); - pub const WKS: DnsType = DnsType(11); - pub const PTR: DnsType = DnsType(12); - pub const HINFO: DnsType = DnsType(13); - pub const MINFO: DnsType = DnsType(14); - pub const MX: DnsType = DnsType(15); - pub const TXT: DnsType = DnsType(16); - pub const RP: DnsType = DnsType(17); - pub const AFSDB: DnsType = DnsType(18); - pub const X25: DnsType = DnsType(19); - pub const ISDN: DnsType = DnsType(20); - pub const RT: DnsType = DnsType(21); - pub const NSAP: DnsType = DnsType(22); - pub const NSAP_PTR: DnsType = DnsType(23); - pub const SIG: DnsType = DnsType(24); - pub const KEY: DnsType = DnsType(25); - pub const PX: DnsType = DnsType(26); - pub const GPOS: DnsType = DnsType(27); - pub const AAAA: DnsType = DnsType(28); - pub const LOC: DnsType = DnsType(29); - pub const NXT: DnsType = DnsType(30); - pub const EID: DnsType = DnsType(31); - pub const NIMLOC: DnsType = DnsType(32); - pub const SRV: DnsType = DnsType(33); - pub const ATMA: DnsType = DnsType(34); - pub const NAPTR: DnsType = DnsType(35); - pub const KX: DnsType = DnsType(36); - pub const CERT: DnsType = DnsType(37); - pub const A6: DnsType = DnsType(38); - pub const DNAME: DnsType = DnsType(39); - pub const SINK: DnsType = DnsType(40); - pub const OPT: DnsType = DnsType(41); - pub const APL: DnsType = DnsType(42); - pub const DS: DnsType = DnsType(43); - pub const SSHFP: DnsType = DnsType(44); - pub const IPSECKEY: DnsType = DnsType(45); - pub const RRSIG: DnsType = DnsType(46); - pub const NSEC: DnsType = DnsType(47); - pub const DNSKEY: DnsType = DnsType(48); - pub const DHCID: DnsType = DnsType(49); - pub const NSEC3: DnsType = DnsType(50); - pub const NSEC3PARAM: DnsType = DnsType(51); - pub const TLSA: DnsType = DnsType(52); - pub const SMIMEA: DnsType = DnsType(53); - pub const HIP: DnsType = DnsType(55); - pub const NINFO: DnsType = DnsType(56); - pub const RKEY: DnsType = DnsType(57); - pub const TALINK: DnsType = DnsType(58); - pub const CDS: DnsType = DnsType(59); - pub const CDNSKEY: DnsType = DnsType(60); - pub const OPENPGPKEY: DnsType = DnsType(61); - pub const CSYNC: DnsType = DnsType(62); - pub const ZONEMD: DnsType = DnsType(63); - pub const SVCB: DnsType = DnsType(64); - pub const HTTPS: DnsType = DnsType(65); - pub const SPF: DnsType = DnsType(99); - pub const UINFO: DnsType = DnsType(100); - pub const UID: DnsType = DnsType(101); - pub const GID: DnsType = DnsType(102); - pub const UNSPEC: DnsType = DnsType(103); - pub const NID: DnsType = DnsType(104); - pub const L32: DnsType = DnsType(105); - pub const L64: DnsType = DnsType(106); - pub const LP: DnsType = DnsType(107); - pub const EUI48: DnsType = DnsType(108); - pub const EUI64: DnsType = DnsType(109); - pub const TKEY: DnsType = DnsType(249); - pub const TSIG: DnsType = DnsType(250); - pub const IXFR: DnsType = DnsType(251); - pub const AXFR: DnsType = DnsType(252); - pub const MAILB: DnsType = DnsType(253); - pub const MAILA: DnsType = DnsType(254); - pub const ANY: DnsType = DnsType(255); - pub const URI: DnsType = DnsType(256); - pub const CAA: DnsType = DnsType(257); - pub const AVC: DnsType = DnsType(258); - pub const DOA: DnsType = DnsType(259); - pub const AMTRELAY: DnsType = DnsType(260); - pub const TA: DnsType = DnsType(32768); - pub const DLV: DnsType = DnsType(32769); -} - -/// Represents a DNS type. +#[allow(non_camel_case_types)] +#[repr(u16)] #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct DnsType(pub u16); +pub enum DnsType { + A = 1, + NS = 2, + MD = 3, + MF = 4, + CNAME = 5, + SOA = 6, + MB = 7, + MG = 8, + MR = 9, + NULL = 10, + WKS = 11, + PTR = 12, + HINFO = 13, + MINFO = 14, + MX = 15, + TXT = 16, + RP = 17, + AFSDB = 18, + X25 = 19, + ISDN = 20, + RT = 21, + NSAP = 22, + NSAP_PTR = 23, + SIG = 24, + KEY = 25, + PX = 26, + GPOS = 27, + AAAA = 28, + LOC = 29, + NXT = 30, + EID = 31, + NIMLOC = 32, + SRV = 33, + ATMA = 34, + NAPTR = 35, + KX = 36, + CERT = 37, + A6 = 38, + DNAME = 39, + SINK = 40, + OPT = 41, + APL = 42, + DS = 43, + SSHFP = 44, + IPSECKEY = 45, + RRSIG = 46, + NSEC = 47, + DNSKEY = 48, + DHCID = 49, + NSEC3 = 50, + NSEC3PARAM = 51, + TLSA = 52, + SMIMEA = 53, + HIP = 55, + NINFO = 56, + RKEY = 57, + TALINK = 58, + CDS = 59, + CDNSKEY = 60, + OPENPGPKEY = 61, + CSYNC = 62, + ZONEMD = 63, + SVCB = 64, + HTTPS = 65, + SPF = 99, + UINFO = 100, + UID = 101, + GID = 102, + UNSPEC = 103, + NID = 104, + L32 = 105, + L64 = 106, + LP = 107, + EUI48 = 108, + EUI64 = 109, + TKEY = 249, + TSIG = 250, + IXFR = 251, + AXFR = 252, + MAILB = 253, + MAILA = 254, + ANY = 255, + URI = 256, + CAA = 257, + AVC = 258, + DOA = 259, + AMTRELAY = 260, + TA = 32768, + DLV = 32769, + Unknown(u16), +} impl DnsType { pub fn new(value: u16) -> Self { - Self(value) - } -} - -impl PrimitiveValues for DnsType { - type T = (u16,); - - fn to_primitive_values(&self) -> (u16,) { - (self.0,) - } -} - -impl fmt::Display for DnsType { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "{}", - match self { - &DnsTypes::A => "A", // 1 - &DnsTypes::NS => "NS", // 2 - &DnsTypes::MD => "MD", // 3 - &DnsTypes::MF => "MF", // 4 - &DnsTypes::CNAME => "CNAME", // 5 - &DnsTypes::SOA => "SOA", // 6 - &DnsTypes::MB => "MB", // 7 - &DnsTypes::MG => "MG", // 8 - &DnsTypes::MR => "MR", // 9 - &DnsTypes::NULL => "NULL", // 10 - &DnsTypes::WKS => "WKS", // 11 - &DnsTypes::PTR => "PTR", // 12 - &DnsTypes::HINFO => "HINFO", // 13 - &DnsTypes::MINFO => "MINFO", // 14 - &DnsTypes::MX => "MX", // 15 - &DnsTypes::TXT => "TXT", // 16 - &DnsTypes::RP => "RP", // 17 - &DnsTypes::AFSDB => "AFSDB", // 18 - &DnsTypes::X25 => "X25", // 19 - &DnsTypes::ISDN => "ISDN", // 20 - &DnsTypes::RT => "RT", // 21 - &DnsTypes::NSAP => "NSAP", // 22 - &DnsTypes::NSAP_PTR => "NSAP_PTR", // 23 - &DnsTypes::SIG => "SIG", // 24 - &DnsTypes::KEY => "KEY", // 25 - &DnsTypes::PX => "PX", // 26 - &DnsTypes::GPOS => "GPOS", // 27 - &DnsTypes::AAAA => "AAAA", // 28 - &DnsTypes::LOC => "LOC", // 29 - &DnsTypes::NXT => "NXT", // 30 - &DnsTypes::EID => "EID", // 31 - &DnsTypes::NIMLOC => "NIMLOC", // 32 - &DnsTypes::SRV => "SRV", // 33 - &DnsTypes::ATMA => "ATMA", // 34 - &DnsTypes::NAPTR => "NAPTR", // 35 - &DnsTypes::KX => "KX", // 36 - &DnsTypes::CERT => "CERT", // 37 - &DnsTypes::A6 => "A6", // 38 - &DnsTypes::DNAME => "DNAME", // 39 - &DnsTypes::SINK => "SINK", // 40 - &DnsTypes::OPT => "OPT", // 41 - &DnsTypes::APL => "APL", // 42 - &DnsTypes::DS => "DS", // 43 - &DnsTypes::SSHFP => "SSHFP", // 44 - &DnsTypes::IPSECKEY => "IPSECKEY", // 45 - &DnsTypes::RRSIG => "RRSIG", // 46 - &DnsTypes::NSEC => "NSEC", // 47 - &DnsTypes::DNSKEY => "DNSKEY", // 48 - &DnsTypes::DHCID => "DHCID", // 49 - &DnsTypes::NSEC3 => "NSEC3", // 50 - &DnsTypes::NSEC3PARAM => "NSEC3PARAM", // 51 - &DnsTypes::TLSA => "TLSA", // 52 - &DnsTypes::SMIMEA => "SMIMEA", // 53 - &DnsTypes::HIP => "HIP", // 55 - &DnsTypes::NINFO => "NINFO", // 56 - &DnsTypes::RKEY => "RKEY", // 57 - &DnsTypes::TALINK => "TALINK", // 58 - &DnsTypes::CDS => "CDS", // 59 - &DnsTypes::CDNSKEY => "CDNSKEY", // 60 - &DnsTypes::OPENPGPKEY => "OPENPGPKEY", // 61 - &DnsTypes::CSYNC => "CSYNC", // 62 - &DnsTypes::ZONEMD => "ZONEMD", // 63 - &DnsTypes::SVCB => "SVCB", // 64 - &DnsTypes::HTTPS => "HTTPS", // 65 - &DnsTypes::SPF => "SPF", // 99 - &DnsTypes::UINFO => "UINFO", // 100 - &DnsTypes::UID => "UID", // 101 - &DnsTypes::GID => "GID", // 102 - &DnsTypes::UNSPEC => "UNSPEC", // 103 - &DnsTypes::NID => "NID", // 104 - &DnsTypes::L32 => "L32", // 105 - &DnsTypes::L64 => "L64", // 106 - &DnsTypes::LP => "LP", // 107 - &DnsTypes::EUI48 => "EUI48", // 108 - &DnsTypes::EUI64 => "EUI64", // 109 - &DnsTypes::TKEY => "TKEY", // 249 - &DnsTypes::TSIG => "TSIG", // 250 - &DnsTypes::IXFR => "IXFR", // 251 - &DnsTypes::AXFR => "AXFR", // 252 - &DnsTypes::MAILB => "MAILB", // 253 - &DnsTypes::MAILA => "MAILA", // 254 - &DnsTypes::ANY => "ANY", // 255 - &DnsTypes::URI => "URI", // 256 - &DnsTypes::CAA => "CAA", // 257 - &DnsTypes::AVC => "AVC", // 258 - &DnsTypes::DOA => "DOA", // 259 - &DnsTypes::AMTRELAY => "AMTRELAY", // 260 - &DnsTypes::TA => "TA", // 32768 - &DnsTypes::DLV => "DLV", // 32769 - _ => "unknown", - } - ) - } -} - -/// Represents a DNS packet. -/// Including its header and all the associated records. -#[packet] -pub struct Dns { - pub id: u16be, - pub is_response: u1, - #[construct_with(u4)] - pub opcode: OpCode, - pub is_authoriative: u1, - pub is_truncated: u1, - pub is_recursion_desirable: u1, - pub is_recursion_available: u1, - pub zero_reserved: u1, - pub is_answer_authenticated: u1, - pub is_non_authenticated_data: u1, - #[construct_with(u4)] - pub rcode: RetCode, - pub query_count: u16be, - pub response_count: u16be, - pub authority_rr_count: u16be, - pub additional_rr_count: u16be, - #[length_fn = "queries_length"] - pub queries: Vec, - #[length_fn = "responses_length"] - pub responses: Vec, - #[length_fn = "authority_length"] - pub authorities: Vec, - #[length_fn = "additional_length"] - pub additionals: Vec, - #[payload] - pub payload: Vec, -} - -fn queries_length(packet: &DnsPacket) -> usize { - let base = 12; - let mut length = 0; - for _ in 0..packet.get_query_count() { - match DnsQueryPacket::new(&packet.packet()[base + length..]) { - Some(query) => length += query.packet_size(), - None => break, - } - } - length -} - -fn responses_length(packet: &DnsPacket) -> usize { - let base = 12 + queries_length(packet); - let mut length = 0; - for _ in 0..packet.get_response_count() { - match DnsResponsePacket::new(&packet.packet()[base + length..]) { - Some(query) => length += query.packet_size(), - None => break, + match value { + 1 => DnsType::A, + 2 => DnsType::NS, + 3 => DnsType::MD, + 4 => DnsType::MF, + 5 => DnsType::CNAME, + 6 => DnsType::SOA, + 7 => DnsType::MB, + 8 => DnsType::MG, + 9 => DnsType::MR, + 10 => DnsType::NULL, + 11 => DnsType::WKS, + 12 => DnsType::PTR, + 13 => DnsType::HINFO, + 14 => DnsType::MINFO, + 15 => DnsType::MX, + 16 => DnsType::TXT, + 17 => DnsType::RP, + 18 => DnsType::AFSDB, + 19 => DnsType::X25, + 20 => DnsType::ISDN, + 21 => DnsType::RT, + 22 => DnsType::NSAP, + 23 => DnsType::NSAP_PTR, + 24 => DnsType::SIG, + 25 => DnsType::KEY, + 26 => DnsType::PX, + 27 => DnsType::GPOS, + 28 => DnsType::AAAA, + 29 => DnsType::LOC, + 30 => DnsType::NXT, + 31 => DnsType::EID, + 32 => DnsType::NIMLOC, + 33 => DnsType::SRV, + 34 => DnsType::ATMA, + 35 => DnsType::NAPTR, + 36 => DnsType::KX, + 37 => DnsType::CERT, + 38 => DnsType::A6, + 39 => DnsType::DNAME, + 40 => DnsType::SINK, + 41 => DnsType::OPT, + 42 => DnsType::APL, + 43 => DnsType::DS, + 44 => DnsType::SSHFP, + 45 => DnsType::IPSECKEY, + 46 => DnsType::RRSIG, + 47 => DnsType::NSEC, + 48 => DnsType::DNSKEY, + 49 => DnsType::DHCID, + 50 => DnsType::NSEC3, + 51 => DnsType::NSEC3PARAM, + 52 => DnsType::TLSA, + 53 => DnsType::SMIMEA, + 55 => DnsType::HIP, + 56 => DnsType::NINFO, + 57 => DnsType::RKEY, + 58 => DnsType::TALINK, + 59 => DnsType::CDS, + 60 => DnsType::CDNSKEY, + 61 => DnsType::OPENPGPKEY, + 62 => DnsType::CSYNC, + 63 => DnsType::ZONEMD, + 64 => DnsType::SVCB, + 65 => DnsType::HTTPS, + 99 => DnsType::SPF, + 100 => DnsType::UINFO, + 101 => DnsType::UID, + 102 => DnsType::GID, + 103 => DnsType::UNSPEC, + 104 => DnsType::NID, + 105 => DnsType::L32, + 106 => DnsType::L64, + 107 => DnsType::LP, + 108 => DnsType::EUI48, + 109 => DnsType::EUI64, + 249 => DnsType::TKEY, + 250 => DnsType::TSIG, + 251 => DnsType::IXFR, + 252 => DnsType::AXFR, + 253 => DnsType::MAILB, + 254 => DnsType::MAILA, + 255 => DnsType::ANY, + 256 => DnsType::URI, + 257 => DnsType::CAA, + 258 => DnsType::AVC, + 259 => DnsType::DOA, + 260 => DnsType::AMTRELAY, + 32768 => DnsType::TA, + 32769 => DnsType::DLV, + v => DnsType::Unknown(v), } } - length -} -fn authority_length(packet: &DnsPacket) -> usize { - let base = 12 + queries_length(packet) + responses_length(packet); - let mut length = 0; - for _ in 0..packet.get_authority_rr_count() { - match DnsResponsePacket::new(&packet.packet()[base + length..]) { - Some(query) => length += query.packet_size(), - None => break, + pub fn value(&self) -> u16 { + match self { + DnsType::A => 1, + DnsType::NS => 2, + DnsType::MD => 3, + DnsType::MF => 4, + DnsType::CNAME => 5, + DnsType::SOA => 6, + DnsType::MB => 7, + DnsType::MG => 8, + DnsType::MR => 9, + DnsType::NULL => 10, + DnsType::WKS => 11, + DnsType::PTR => 12, + DnsType::HINFO => 13, + DnsType::MINFO => 14, + DnsType::MX => 15, + DnsType::TXT => 16, + DnsType::RP => 17, + DnsType::AFSDB => 18, + DnsType::X25 => 19, + DnsType::ISDN => 20, + DnsType::RT => 21, + DnsType::NSAP => 22, + DnsType::NSAP_PTR => 23, + DnsType::SIG => 24, + DnsType::KEY => 25, + DnsType::PX => 26, + DnsType::GPOS => 27, + DnsType::AAAA => 28, + DnsType::LOC => 29, + DnsType::NXT => 30, + DnsType::EID => 31, + DnsType::NIMLOC => 32, + DnsType::SRV => 33, + DnsType::ATMA => 34, + DnsType::NAPTR => 35, + DnsType::KX => 36, + DnsType::CERT => 37, + DnsType::A6 => 38, + DnsType::DNAME => 39, + DnsType::SINK => 40, + DnsType::OPT => 41, + DnsType::APL => 42, + DnsType::DS => 43, + DnsType::SSHFP => 44, + DnsType::IPSECKEY => 45, + DnsType::RRSIG => 46, + DnsType::NSEC => 47, + DnsType::DNSKEY => 48, + DnsType::DHCID => 49, + DnsType::NSEC3 => 50, + DnsType::NSEC3PARAM => 51, + DnsType::TLSA => 52, + DnsType::SMIMEA => 53, + DnsType::HIP => 55, + DnsType::NINFO => 56, + DnsType::RKEY => 57, + DnsType::TALINK => 58, + DnsType::CDS => 59, + DnsType::CDNSKEY => 60, + DnsType::OPENPGPKEY => 61, + DnsType::CSYNC => 62, + DnsType::ZONEMD => 63, + DnsType::SVCB => 64, + DnsType::HTTPS => 65, + DnsType::SPF => 99, + DnsType::UINFO => 100, + DnsType::UID => 101, + DnsType::GID => 102, + DnsType::UNSPEC => 103, + DnsType::NID => 104, + DnsType::L32 => 105, + DnsType::L64 => 106, + DnsType::LP => 107, + DnsType::EUI48 => 108, + DnsType::EUI64 => 109, + DnsType::TKEY => 249, + DnsType::TSIG => 250, + DnsType::IXFR => 251, + DnsType::AXFR => 252, + DnsType::MAILB => 253, + DnsType::MAILA => 254, + DnsType::ANY => 255, + DnsType::URI => 256, + DnsType::CAA => 257, + DnsType::AVC => 258, + DnsType::DOA => 259, + DnsType::AMTRELAY => 260, + DnsType::TA => 32768, + DnsType::DLV => 32769, + DnsType::Unknown(v) => *v, } } - length -} -fn additional_length(packet: &DnsPacket) -> usize { - let base = 12 + queries_length(packet) + responses_length(packet) + authority_length(packet); - let mut length = 0; - for _ in 0..packet.get_additional_rr_count() { - match DnsResponsePacket::new(&packet.packet()[base + length..]) { - Some(query) => length += query.packet_size(), - None => break, - } + pub fn name(&self) -> &'static str { + match self { + DnsType::A => "A", // 1 + DnsType::NS => "NS", // 2 + DnsType::MD => "MD", // 3 + DnsType::MF => "MF", // 4 + DnsType::CNAME => "CNAME", // 5 + DnsType::SOA => "SOA", // 6 + DnsType::MB => "MB", // 7 + DnsType::MG => "MG", // 8 + DnsType::MR => "MR", // 9 + DnsType::NULL => "NULL", // 10 + DnsType::WKS => "WKS", // 11 + DnsType::PTR => "PTR", // 12 + DnsType::HINFO => "HINFO", // 13 + DnsType::MINFO => "MINFO", // 14 + DnsType::MX => "MX", // 15 + DnsType::TXT => "TXT", // 16 + DnsType::RP => "RP", // 17 + DnsType::AFSDB => "AFSDB", // 18 + DnsType::X25 => "X25", // 19 + DnsType::ISDN => "ISDN", // 20 + DnsType::RT => "RT", // 21 + DnsType::NSAP => "NSAP", // 22 + DnsType::NSAP_PTR => "NSAP_PTR", // 23 + DnsType::SIG => "SIG", // 24 + DnsType::KEY => "KEY", // 25 + DnsType::PX => "PX", // 26 + DnsType::GPOS => "GPOS", // 27 + DnsType::AAAA => "AAAA", // 28 + DnsType::LOC => "LOC", // 29 + DnsType::NXT => "NXT", // 30 + DnsType::EID => "EID", // 31 + DnsType::NIMLOC => "NIMLOC", // 32 + DnsType::SRV => "SRV", // 33 + DnsType::ATMA => "ATMA", // 34 + DnsType::NAPTR => "NAPTR", // 35 + DnsType::KX => "KX", // 36 + DnsType::CERT => "CERT", // 37 + DnsType::A6 => "A6", // 38 + DnsType::DNAME => "DNAME", // 39 + DnsType::SINK => "SINK", // 40 + DnsType::OPT => "OPT", // 41 + DnsType::APL => "APL", // 42 + DnsType::DS => "DS", // 43 + DnsType::SSHFP => "SSHFP", // 44 + DnsType::IPSECKEY => "IPSECKEY", // 45 + DnsType::RRSIG => "RRSIG", // 46 + DnsType::NSEC => "NSEC", // 47 + DnsType::DNSKEY => "DNSKEY", // 48 + DnsType::DHCID => "DHCID", // 49 + DnsType::NSEC3 => "NSEC3", // 50 + DnsType::NSEC3PARAM => "NSEC3PARAM", // 51 + DnsType::TLSA => "TLSA", // 52 + DnsType::SMIMEA => "SMIMEA", // 53 + DnsType::HIP => "HIP", // 55 + DnsType::NINFO => "NINFO", // 56 + DnsType::RKEY => "RKEY", // 57 + DnsType::TALINK => "TALINK", // 58 + DnsType::CDS => "CDS", // 59 + DnsType::CDNSKEY => "CDNSKEY", // 60 + DnsType::OPENPGPKEY => "OPENPGPKEY", // 61 + DnsType::CSYNC => "CSYNC", // 62 + DnsType::ZONEMD => "ZONEMD", // 63 + DnsType::SVCB => "SVCB", // 64 + DnsType::HTTPS => "HTTPS", // 65 + DnsType::SPF => "SPF", // 99 + DnsType::UINFO => "UINFO", // 100 + DnsType::UID => "UID", // 101 + DnsType::GID => "GID", // 102 + DnsType::UNSPEC => "UNSPEC", // 103 + DnsType::NID => "NID", // 104 + DnsType::L32 => "L32", // 105 + DnsType::L64 => "L64", // 106 + DnsType::LP => "LP", // 107 + DnsType::EUI48 => "EUI48", // 108 + DnsType::EUI64 => "EUI64", // 109 + DnsType::TKEY => "TKEY", // 249 + DnsType::TSIG => "TSIG", // 250 + DnsType::IXFR => "IXFR", // 251 + DnsType::AXFR => "AXFR", // 252 + DnsType::MAILB => "MAILB", // 253 + DnsType::MAILA => "MAILA", // 254 + DnsType::ANY => "ANY", // 255 + DnsType::URI => "URI", // 256 + DnsType::CAA => "CAA", // 257 + DnsType::AVC => "AVC", // 258 + DnsType::DOA => "DOA", // 259 + DnsType::AMTRELAY => "AMTRELAY", // 260 + DnsType::TA => "TA", // 32768 + DnsType::DLV => "DLV", // 32769 + _ => "unknown", + } } - length } /// Represents an DNS operation code. @@ -371,22 +446,6 @@ pub enum OpCode { Unassigned(u8), } -impl PrimitiveValues for OpCode { - type T = (u8,); - fn to_primitive_values(&self) -> (u8,) { - match self { - Self::Query => (0,), - Self::InverseQuery => (1,), - Self::Status => (2,), - Self::Reserved => (3,), - Self::Notify => (4,), - Self::Update => (5,), - Self::Dso => (6,), - Self::Unassigned(n) => (*n,), - } - } -} - impl OpCode { pub fn new(value: u8) -> Self { match value { @@ -400,6 +459,30 @@ impl OpCode { _ => Self::Unassigned(value), } } + pub fn value(&self) -> u8 { + match self { + Self::Query => 0, + Self::InverseQuery => 1, + Self::Status => 2, + Self::Reserved => 3, + Self::Notify => 4, + Self::Update => 5, + Self::Dso => 6, + Self::Unassigned(v) => *v, + } + } + pub fn name(&self) -> &'static str { + match self { + Self::Query => "Query", + Self::InverseQuery => "Inverse Query", + Self::Status => "Status", + Self::Reserved => "Reserved", + Self::Notify => "Notify", + Self::Update => "Update", + Self::Dso => "DSO", + Self::Unassigned(_) => "Unassigned", + } + } } /// Represents an DNS return code. @@ -429,35 +512,6 @@ pub enum RetCode { Unassigned(u8), } -impl PrimitiveValues for RetCode { - type T = (u8,); - fn to_primitive_values(&self) -> (u8,) { - match self { - Self::NoError => (0,), - Self::FormErr => (1,), - Self::ServFail => (2,), - Self::NXDomain => (3,), - Self::NotImp => (4,), - Self::Refused => (5,), - Self::YXDomain => (6,), - Self::YXRRSet => (7,), - Self::NXRRSet => (8,), - Self::NotAuth => (9,), - Self::NotZone => (10,), - Self::Dsotypeni => (11,), - Self::BadVers => (16,), - Self::BadKey => (17,), - Self::BadTime => (18,), - Self::BadMode => (19,), - Self::BadName => (20,), - Self::BadAlg => (21,), - Self::BadTrunc => (22,), - Self::BadCookie => (23,), - Self::Unassigned(n) => (*n,), - } - } -} - impl RetCode { pub fn new(value: u8) -> Self { match value { @@ -473,37 +527,169 @@ impl RetCode { 9 => Self::NotAuth, 10 => Self::NotZone, 11 => Self::Dsotypeni, - 16 => Self::BadVers, - 17 => Self::BadKey, - 18 => Self::BadTime, - 19 => Self::BadMode, - 20 => Self::BadName, - 21 => Self::BadAlg, - 22 => Self::BadTrunc, - 23 => Self::BadCookie, + 12 => Self::BadVers, + 13 => Self::BadKey, + 14 => Self::BadTime, + 15 => Self::BadMode, + 16 => Self::BadName, + 17 => Self::BadAlg, + 18 => Self::BadTrunc, + 19 => Self::BadCookie, _ => Self::Unassigned(value), } } + + pub fn value(&self) -> u8 { + match self { + Self::NoError => 0, + Self::FormErr => 1, + Self::ServFail => 2, + Self::NXDomain => 3, + Self::NotImp => 4, + Self::Refused => 5, + Self::YXDomain => 6, + Self::YXRRSet => 7, + Self::NXRRSet => 8, + Self::NotAuth => 9, + Self::NotZone => 10, + Self::Dsotypeni => 11, + Self::BadVers => 12, + Self::BadKey => 13, + Self::BadTime => 14, + Self::BadMode => 15, + Self::BadName => 16, + Self::BadAlg => 17, + Self::BadTrunc => 18, + Self::BadCookie => 19, + Self::Unassigned(v) => *v + } + } + + pub fn name(&self) -> &'static str { + match self { + RetCode::NoError => "No Error", + RetCode::FormErr => "Format Error", + RetCode::ServFail => "Server Failure", + RetCode::NXDomain => "Non-Existent Domain", + RetCode::NotImp => "Not Implemented", + RetCode::Refused => "Query Refused", + RetCode::YXDomain => "Name Exists When It Shouldn't", + RetCode::YXRRSet => "RR Set Exists When It Shouldn't", + RetCode::NXRRSet => "RR Set Doesn't Exist When It Should", + RetCode::NotAuth => "Not Authorized", + RetCode::NotZone => "Name Not Zone", + RetCode::Dsotypeni => "DSO Type NI", + RetCode::BadVers => "Bad Version", + RetCode::BadKey => "Bad Key", + RetCode::BadTime => "Bad Time", + RetCode::BadMode => "Bad Mode", + RetCode::BadName => "Bad Name", + RetCode::BadAlg => "Bad Algorithm", + RetCode::BadTrunc => "Bad Truncation", + RetCode::BadCookie => "Bad Cookie", + RetCode::Unassigned(_) => "Unassigned", + } + } } /// DNS query packet structure. -#[packet] -pub struct DnsQuery { - #[length_fn = "qname_length"] +#[derive(Clone, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct DnsQueryPacket { pub qname: Vec, - #[construct_with(u16be)] pub qtype: DnsType, - #[construct_with(u16be)] pub qclass: DnsClass, - #[payload] - pub payload: Vec, + pub payload: Bytes, } -fn qname_length(packet: &DnsQueryPacket) -> usize { - packet.packet().iter().take_while(|w| *w != &0).count() + 1 +impl Packet for DnsQueryPacket { + type Header = (); + fn from_buf(buf: &[u8]) -> Option { + let mut pos = 0; + let mut qname = Vec::new(); + + // Parse the QNAME field + loop { + if pos >= buf.len() { + return None; + } + + let len = buf[pos]; + pos += 1; + qname.push(len); + + if len == 0 { + break; + } + + if pos + len as usize > buf.len() { + return None; + } + + qname.extend_from_slice(&buf[pos..pos + len as usize]); + pos += len as usize; + } + + // Read QTYPE and QCLASS + if pos + 4 > buf.len() { + return None; + } + + let qtype = DnsType::new(u16::from_be_bytes([buf[pos], buf[pos + 1]])); + let qclass = DnsClass::new(u16::from_be_bytes([buf[pos + 2], buf[pos + 3]])); + pos += 4; + + // The rest is stored as payload + let payload = Bytes::copy_from_slice(&buf[pos..]); + + Some(Self { + qname, + qtype, + qclass, + payload, + }) + } + + fn from_bytes(mut bytes: Bytes) -> Option { + Self::from_buf(&mut bytes) + } + + fn to_bytes(&self) -> Bytes { + let mut buf = BytesMut::with_capacity(self.qname.len() + 4); + buf.extend_from_slice(&self.qname); + buf.put_u16(self.qtype.value()); + buf.put_u16(self.qclass.value()); + buf.freeze() + } + + fn header(&self) -> Bytes { + self.to_bytes().slice(0..self.header_len()) + } + + fn payload(&self) -> Bytes { + self.payload.clone() + } + + fn header_len(&self) -> usize { + self.qname.len() + 4 + } + + fn payload_len(&self) -> usize { + self.payload.len() + } + + fn total_len(&self) -> usize { + self.header_len() + self.payload_len() + } + + fn into_parts(self) -> (Self::Header, Bytes) { + let header = (); + let payload = self.payload; + (header, payload) + } } -impl DnsQuery { +impl DnsQueryPacket { pub fn get_qname_parsed(&self) -> Result { let name = &self.qname; let mut qname = String::new(); @@ -524,308 +710,473 @@ impl DnsQuery { } Ok(qname) } + pub fn qname_length(&self) -> usize { + self.to_bytes().iter().take_while(|w| *w != &0).count() + 1 + } + pub fn from_buf_mut(buf: &mut &[u8]) -> Option { + let mut qname = Vec::new(); + + loop { + if buf.is_empty() { + return None; + } + let len = buf[0]; + *buf = &buf[1..]; + qname.push(len); + if len == 0 { + break; + } + if buf.len() < len as usize { + return None; + } + qname.extend_from_slice(&buf[..len as usize]); + *buf = &buf[len as usize..]; + } + + if buf.len() < 4 { + return None; + } + + let qtype = DnsType::new(u16::from_be_bytes([buf[0], buf[1]])); + *buf = &buf[2..]; + + let qclass = DnsClass::new(u16::from_be_bytes([buf[0], buf[1]])); + *buf = &buf[2..]; + + let payload = Bytes::copy_from_slice(buf); + + Some(Self { + qname, + qtype, + qclass, + payload, + }) + } } /// DNS response packet structure. -#[packet] -pub struct DnsResponse { - #[length_fn = "rname_length"] - pub rname: Vec, - #[construct_with(u16be)] +#[derive(Clone, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct DnsResponsePacket { + pub name_tag: u16be, pub rtype: DnsType, - #[construct_with(u16be)] pub rclass: DnsClass, pub ttl: u32be, pub data_len: u16be, - #[length = "data_len"] pub data: Vec, - #[payload] - pub payload: Vec, + pub payload: Bytes, } -/// Parses and Returns the length of the rname field. -fn rname_length(packet: &DnsResponsePacket) -> usize { - let mut offset = 0; - let mut size = 0; - loop { - let label_len = packet.packet()[offset] as usize; - if label_len == 0 { - size += 1; - break; +impl Packet for DnsResponsePacket { + type Header = (); + fn from_buf(buf: &[u8]) -> Option { + if buf.len() < 12 { + return None; } - if label_len & 0xC0 == 0xC0 { - size += 2; - break; + + let mut pos = 0; + + let name_tag = u16::from_be_bytes([buf[pos], buf[pos + 1]]).into(); + pos += 2; + + let rtype = DnsType::new(u16::from_be_bytes([buf[pos], buf[pos + 1]])); + pos += 2; + + let rclass = DnsClass::new(u16::from_be_bytes([buf[pos], buf[pos + 1]])); + pos += 2; + + let ttl = u32::from_be_bytes([buf[pos], buf[pos + 1], buf[pos + 2], buf[pos + 3]]).into(); + pos += 4; + + let data_len = u16::from_be_bytes([buf[pos], buf[pos + 1]]).into(); + pos += 2; + + let data_len_usize = data_len as usize; + + if buf.len() < pos + data_len_usize { + return None; } - size += label_len + 1; - offset += label_len + 1; + + let data = buf[pos..pos + data_len_usize].to_vec(); + pos += data_len_usize; + + let payload = Bytes::copy_from_slice(&buf[pos..]); + + Some(Self { + name_tag, + rtype, + rclass, + ttl, + data_len, + data, + payload, + }) + } + fn from_bytes(mut bytes: Bytes) -> Option { + Self::from_buf(&mut bytes) + } + + fn to_bytes(&self) -> Bytes { + let mut buf = bytes::BytesMut::with_capacity(self.total_len()); + + buf.put_u16(self.name_tag.into()); + buf.put_u16(self.rtype.value()); + buf.put_u16(self.rclass.value()); + buf.put_u32(self.ttl.into()); + buf.put_u16(self.data_len.into()); + buf.put_slice(&self.data); + + buf.freeze() + } + + fn header(&self) -> Bytes { + self.to_bytes().slice(0..self.total_len()) + } + + fn payload(&self) -> Bytes { + self.payload.clone() + } + + fn header_len(&self) -> usize { + 12 + } + + fn payload_len(&self) -> usize { + self.payload.len() + } + + fn total_len(&self) -> usize { + self.header_len() + self.payload_len() + } + + fn into_parts(self) -> (Self::Header, Bytes) { + let header = (); + let payload = self.payload; + (header, payload) } - size } -/// Parses the rname field of a DNS packet. -pub fn parse_name(packet: &DnsPacket, coded_name: &Vec) -> Result { - // First follow the path in the rname, except if it starts with a C0 - // then move to using the offsets from the start - let start = packet.packet(); - let mut name = coded_name.as_slice(); - let mut rname = String::new(); - let mut offset: usize = 0; - - loop { - let label_len: u16 = name[offset] as u16; - if label_len == 0 { - break; +impl DnsResponsePacket { + pub fn from_buf_mut(buf: &mut &[u8]) -> Option { + if buf.len() < 12 { + return None; } - if (label_len & 0xC0) == 0xC0 { - let offset1 = ((label_len & 0x3F) as usize) << 8; - let offset2 = name[offset + 1] as usize; - offset = offset1 + offset2; - // now change name - name = start; - continue; - } - if !rname.is_empty() { - rname.push('.'); - } - match str::from_utf8(&name[offset + 1..offset + 1 + label_len as usize]) { - Ok(label) => rname.push_str(label), - Err(e) => return Err(e), + + // name_tag (2) + let name_tag = u16::from_be_bytes([buf[0], buf[1]]).into(); + *buf = &buf[2..]; + + // rtype (2) + let rtype = DnsType::new(u16::from_be_bytes([buf[0], buf[1]])); + *buf = &buf[2..]; + + // rclass (2) + let rclass = DnsClass::new(u16::from_be_bytes([buf[0], buf[1]])); + *buf = &buf[2..]; + + // ttl (4) + let ttl = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]).into(); + *buf = &buf[4..]; + + // data_len (2) + let data_len = u16::from_be_bytes([buf[0], buf[1]]); + *buf = &buf[2..]; + + if buf.len() < data_len as usize { + return None; } - offset += label_len as usize + 1; + + // data (data_len) + let data = buf[..data_len as usize].to_vec(); + *buf = &buf[data_len as usize..]; + + // Remaining bytes are stored as payload + let payload = Bytes::copy_from_slice(buf); + + Some(Self { + name_tag, + rtype, + rclass, + ttl, + data_len: data_len.into(), + data, + payload, + }) } - Ok(rname) } -/// Represents a DNS TXT record. -/// -/// TXT records hold descriptive text. The actual text is stored in the `text` field. -#[packet] -pub struct DnsRrTXT { - pub data_len: u8, - #[length = "data_len"] - pub text: Vec, - #[payload] - pub payload: Vec, +#[derive(Clone, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct DnsHeader { + pub id: u16be, + pub is_response: u1, + pub opcode: OpCode, + pub is_authoriative: u1, + pub is_truncated: u1, + pub is_recursion_desirable: u1, + pub is_recursion_available: u1, + pub zero_reserved: u1, + pub is_answer_authenticated: u1, + pub is_non_authenticated_data: u1, + pub rcode: RetCode, + pub query_count: u16be, + pub response_count: u16be, + pub authority_rr_count: u16be, + pub additional_rr_count: u16be, } -/// Represents a DNS SRV record. -/// -/// SRV records are used to specify the location of services by providing a hostname and port number. -#[packet] -pub struct DnsRrSrv { - pub priority: u16be, - pub weight: u16be, - pub port: u16be, - #[length_fn = "target_length"] - pub target: Vec, - #[payload] - pub payload: Vec, +/// Represents a DNS packet. +/// Including its header and all the associated records. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct DnsPacket { + pub header: DnsHeader, + pub queries: Vec, + pub responses: Vec, + pub authorities: Vec, + pub additionals: Vec, + pub payload: Bytes, } -fn target_length(packet: &DnsRrSrvPacket) -> usize { - let mut offset = 6; - let mut size = 0; - loop { - let label_len = packet.packet()[offset] as usize; - if label_len == 0 { - size += 1; - break; +impl Packet for DnsPacket { + type Header = (); + fn from_buf(buf: &[u8]) -> Option { + if buf.len() < 12 { + return None; } - if label_len & 0xC0 == 0xC0 { - size += 2; - break; + + let mut cursor = buf; + + // Read DNS header + let id = u16::from_be_bytes([cursor[0], cursor[1]]); + let flags = u16::from_be_bytes([cursor[2], cursor[3]]); + let query_count = u16::from_be_bytes([cursor[4], cursor[5]]); + let response_count = u16::from_be_bytes([cursor[6], cursor[7]]); + let authority_rr_count = u16::from_be_bytes([cursor[8], cursor[9]]); + let additional_rr_count = u16::from_be_bytes([cursor[10], cursor[11]]); + cursor = &cursor[12..]; + + let header = DnsHeader { + id: id.into(), + is_response: ((flags >> 15) & 0x1) as u8, + opcode: OpCode::new(((flags >> 11) & 0xF) as u8), + is_authoriative: ((flags >> 10) & 0x1) as u8, + is_truncated: ((flags >> 9) & 0x1) as u8, + is_recursion_desirable: ((flags >> 8) & 0x1) as u8, + is_recursion_available: ((flags >> 7) & 0x1) as u8, + zero_reserved: ((flags >> 6) & 0x1) as u8, + is_answer_authenticated: ((flags >> 5) & 0x1) as u8, + is_non_authenticated_data: ((flags >> 4) & 0x1) as u8, + rcode: RetCode::new((flags & 0xF) as u8), + query_count: query_count.into(), + response_count: response_count.into(), + authority_rr_count: authority_rr_count.into(), + additional_rr_count: additional_rr_count.into(), + }; + + // Parse each section, passing mutable slices + fn parse_queries(count: usize, buf: &mut &[u8]) -> Option> { + (0..count).map(|_| DnsQueryPacket::from_buf_mut(buf)).collect() + } + + fn parse_responses(count: usize, buf: &mut &[u8]) -> Option> { + (0..count).map(|_| DnsResponsePacket::from_buf_mut(buf)).collect() } - size += label_len + 1; - offset += label_len + 1; + + let mut working_buf = cursor; + + let queries = parse_queries(query_count as usize, &mut working_buf)?; + let responses = parse_responses(response_count as usize, &mut working_buf)?; + let authorities = parse_responses(authority_rr_count as usize, &mut working_buf)?; + let additionals = parse_responses(additional_rr_count as usize, &mut working_buf)?; + + // Remaining data becomes the payload + let payload = Bytes::copy_from_slice(working_buf); + + Some(Self { + header, + queries, + responses, + authorities, + additionals, + payload, + }) } - size -} -/// A structured representation of a Service Name (SRV record content). -/// -/// Parses and holds components of an SRV record's target domain, which includes service instance, service type, protocol, and domain name. -/// SRV record name -#[derive(Debug)] -pub struct SrvName { - pub instance: Option, - pub service: Option, - pub protocol: Option, - pub domain: Option, -} + fn from_bytes(mut bytes: Bytes) -> Option { + Self::from_buf(&mut bytes) + } -impl SrvName { - pub fn new(name: &str) -> Self { - let parts: Vec<&str> = name.split('.').collect(); - let (instance, service, protocol, domain) = match parts.as_slice() { - [instance, service, protocol, domain @ ..] - if service.starts_with('_') && protocol.starts_with('_') => - { - ( - Some(String::from(*instance)), - Some(String::from(*service)), - Some(String::from(*protocol)), - Some(String::from(domain.join("."))), - ) - } - [service, protocol, domain @ ..] - if service.starts_with('_') && protocol.starts_with('_') => - { - ( - None, - Some(String::from(*service)), - Some(String::from(*protocol)), - Some(String::from(domain.join("."))), - ) - } - [instance, service, protocol, domain @ ..] => ( - Some(String::from(*instance)), - Some(String::from(*service)), - Some(String::from(*protocol)), - Some(String::from(domain.join("."))), - ), - _ => (None, None, None, None), - }; + fn to_bytes(&self) -> Bytes { + use bytes::{BufMut, BytesMut}; + + let mut buf = BytesMut::with_capacity(self.header_len() + self.payload.len()); - SrvName { - instance, - service, - protocol, - domain, + // DNS Header + let mut flags = 0u16; + flags |= (self.header.is_response as u16) << 15; + flags |= (self.header.opcode.value() as u16) << 11; + flags |= (self.header.is_authoriative as u16) << 10; + flags |= (self.header.is_truncated as u16) << 9; + flags |= (self.header.is_recursion_desirable as u16) << 8; + flags |= (self.header.is_recursion_available as u16) << 7; + flags |= (self.header.zero_reserved as u16) << 6; + flags |= (self.header.is_answer_authenticated as u16) << 5; + flags |= (self.header.is_non_authenticated_data as u16) << 4; + flags |= self.header.rcode.value() as u16; + + buf.put_u16(self.header.id.into()); + buf.put_u16(flags); + buf.put_u16(self.header.query_count.into()); + buf.put_u16(self.header.response_count.into()); + buf.put_u16(self.header.authority_rr_count.into()); + buf.put_u16(self.header.additional_rr_count.into()); + + // Write all queries + for query in &self.queries { + buf.extend_from_slice(&query.to_bytes()); + } + + // Write all responses + for response in &self.responses { + buf.extend_from_slice(&response.to_bytes()); + } + + // Write authorities + for auth in &self.authorities { + buf.extend_from_slice(&auth.to_bytes()); + } + + // Write additionals + for add in &self.additionals { + buf.extend_from_slice(&add.to_bytes()); } + + Bytes::from(buf) } -} -#[test] -fn test_dns_query_packet() { - let packet = DnsPacket::new(b"\x1e\xcb\x01\x20\x00\x01\x00\x00\x00\x00\x00\x01\x0a\x63\x6c\x6f\x75\x64\x66\x6c\x61\x72\x65\x03\x63\x6f\x6d\x00\x00\x01\x00\x01\x00\x00\x29\x10\x00\x00\x00\x00\x00\x00\x00").unwrap(); - assert_eq!(packet.get_id(), 7883); - assert_eq!(packet.get_is_response(), 0); - assert_eq!(packet.get_opcode(), OpCode::Query); - assert_eq!(packet.get_is_authoriative(), 0); - assert_eq!(packet.get_is_truncated(), 0); - assert_eq!(packet.get_is_recursion_desirable(), 1); - assert_eq!(packet.get_is_recursion_available(), 0); - assert_eq!(packet.get_zero_reserved(), 0); - assert_eq!(packet.get_rcode(), RetCode::NoError); - assert_eq!(packet.get_query_count(), 1); - assert_eq!(packet.get_response_count(), 0); - assert_eq!(packet.get_authority_rr_count(), 0); - assert_eq!(packet.get_additional_rr_count(), 1); - assert_eq!(packet.get_queries().len(), 1); - assert_eq!( - packet.get_queries()[0] - .get_qname_parsed() - .unwrap_or(String::new()), - "cloudflare.com" - ); - assert_eq!(packet.get_queries()[0].qtype, DnsTypes::A); - assert_eq!(packet.get_queries()[0].qclass, DnsClasses::IN); - assert_eq!(packet.get_responses().len(), 0); - assert_eq!(packet.get_authorities().len(), 0); - assert_eq!(packet.get_additionals().len(), 1); -} + fn header(&self) -> Bytes { + self.to_bytes().slice(0..12) + } -#[test] -fn test_dns_reponse_packet() { - let packet = DnsPacket::new(b"\x1e\xcb\x81\xa0\x00\x01\x00\x02\x00\x00\x00\x01\x0a\x63\x6c\x6f\x75\x64\x66\x6c\x61\x72\x65\x03\x63\x6f\x6d\x00\x00\x01\x00\x01\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\xc4\x00\x04h\x10\x85\xe5\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\xc4\x00\x04h\x10\x84\xe5\x00\x00)\x04\xd0\x00\x00\x00\x00\x00\x00").unwrap(); - assert_eq!(packet.get_id(), 7883); - assert_eq!(packet.get_is_response(), 1); - assert_eq!(packet.get_opcode(), OpCode::Query); - assert_eq!(packet.get_is_authoriative(), 0); - assert_eq!(packet.get_is_truncated(), 0); - assert_eq!(packet.get_is_recursion_desirable(), 1); - assert_eq!(packet.get_is_recursion_available(), 1); - assert_eq!(packet.get_zero_reserved(), 0); - assert_eq!(packet.get_rcode(), RetCode::NoError); - assert_eq!(packet.get_query_count(), 1); - assert_eq!(packet.get_response_count(), 2); - assert_eq!(packet.get_authority_rr_count(), 0); - assert_eq!(packet.get_additional_rr_count(), 1); - assert_eq!(packet.get_queries().len(), 1); - assert_eq!( - packet.get_queries()[0] - .get_qname_parsed() - .unwrap_or(String::new()), - "cloudflare.com" - ); - assert_eq!(packet.get_queries()[0].qtype, DnsTypes::A); - assert_eq!(packet.get_queries()[0].qclass, DnsClasses::IN); - assert_eq!(packet.get_responses().len(), 2); - assert_eq!(packet.get_responses()[0].rtype, DnsTypes::A); - assert_eq!(packet.get_responses()[0].rclass, DnsClasses::IN); - assert_eq!(packet.get_responses()[0].ttl, 196); - assert_eq!(packet.get_responses()[0].data_len, 4); - assert_eq!( - packet.get_responses()[0].data.as_slice(), - [104, 16, 133, 229] - ); - assert_eq!(packet.get_authorities().len(), 0); - assert_eq!(packet.get_additionals().len(), 1); -} + fn payload(&self) -> Bytes { + self.payload.clone() + } -#[test] -fn test_mdns_response() { - let data = b"\x00\x00\x84\x00\x00\x00\x00\x04\x00\x00\x00\x00\x0b\x5f\x61\x6d\x7a\x6e\x2d\x61\x6c\x65\x78\x61\x04\x5f\x74\x63\x70\x05\x6c\x6f\x63\x61\x6c\x00\x00\x0c\x00\x01\x00\x00\x11\x94\x00\x0b\x08\x5f\x73\x65\x72\x76\x69\x63\x65\xc0\x0c\xc0\x2e\x00\x10\x80\x01\x00\x00\x11\x94\x00\x0a\x09\x76\x65\x72\x73\x69\x6f\x6e\x3d\x31\xc0\x2e\x00\x21\x80\x01\x00\x00\x00\x78\x00\x1d\x00\x00\x00\x00\x19\x8f\x14\x61\x76\x73\x2d\x66\x66\x72\x65\x67\x2d\x31\x36\x35\x34\x34\x37\x35\x36\x38\x33\xc0\x1d\xc0\x61\x00\x01\x80\x01\x00\x00\x00\x78\x00\x04\xc0\xa8\x01\x06"; - let packet = DnsPacket::new(data).expect("Failed to parse dns response"); - assert_eq!(packet.get_id(), 0); - assert_eq!(packet.get_is_response(), 1); - assert_eq!(packet.get_opcode(), OpCode::Query); - assert_eq!(packet.get_is_authoriative(), 1); - assert_eq!(packet.get_is_truncated(), 0); - assert_eq!(packet.get_is_recursion_desirable(), 0); - assert_eq!(packet.get_is_recursion_available(), 0); - assert_eq!(packet.get_zero_reserved(), 0); - assert_eq!(packet.get_rcode(), RetCode::NoError); - assert_eq!(packet.get_query_count(), 0); - assert_eq!(packet.get_response_count(), 4); - assert_eq!(packet.get_authority_rr_count(), 0); - assert_eq!(packet.get_additional_rr_count(), 0); - assert_eq!(packet.get_responses().len(), 4); - let responses = packet.get_responses(); - // RR #1 - assert_eq!( - parse_name(&packet, &responses[0].rname).unwrap_or(String::new()), - "_amzn-alexa._tcp.local" - ); - assert_eq!(responses[0].rtype, DnsTypes::PTR); - assert_eq!(responses[0].rclass, DnsClasses::IN); - assert_eq!(responses[0].ttl, 4500); - assert_eq!(responses[0].data_len, 11); - assert_eq!( - parse_name(&packet, &responses[0].data).unwrap_or(String::new()), - "_service._amzn-alexa._tcp.local" - ); - // RR #2 - assert_eq!( - parse_name(&packet, &responses[1].rname).unwrap_or(String::new()), - "_service._amzn-alexa._tcp.local" - ); - assert_eq!(responses[1].rtype, DnsTypes::TXT); - assert_eq!(responses[1].ttl, 4500); - assert_eq!(responses[1].data_len, 10); - let text_rr = DnsRrTXTPacket::new(&responses[1].data).unwrap(); - assert_eq!(text_rr.get_data_len(), 9); - assert_eq!(String::from_utf8(text_rr.get_text()).unwrap(), "version=1"); - // RR #3 - let srv_name = parse_name(&packet, &responses[2].rname).unwrap_or(String::new()); - assert_eq!(srv_name, "_service._amzn-alexa._tcp.local"); - assert_eq!(responses[2].rtype, DnsTypes::SRV); - assert_eq!(responses[2].data_len, 29); - let srv_rr = DnsRrSrvPacket::new(&responses[2].data).unwrap(); - assert_eq!(srv_rr.get_priority(), 0); - assert_eq!(srv_rr.get_weight(), 0); - assert_eq!(srv_rr.get_port(), 6543); - assert_eq!( - parse_name(&packet, &srv_rr.get_target()).unwrap_or(String::new()), - "avs-ffreg-1654475683.local" - ); - let srv = SrvName::new(&srv_name); - assert_eq!(srv.instance, Some(String::from("_service"))); - assert_eq!(srv.service, Some(String::from("_amzn-alexa"))); - assert_eq!(srv.protocol, Some(String::from("_tcp"))); - assert_eq!(srv.domain, Some(String::from("local"))); - // RR #4 - assert_eq!(responses[3].rtype, DnsTypes::A); - assert_eq!(responses[3].data.as_slice(), [192, 168, 1, 6]); + fn header_len(&self) -> usize { + 12 + } + + fn payload_len(&self) -> usize { + self.payload.len() + } + + fn total_len(&self) -> usize { + self.header_len() + self.payload_len() + } + + fn into_parts(self) -> (Self::Header, Bytes) { + let header = (); + let payload = self.payload; + (header, payload) + } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_dns_query() { + let bytes = Bytes::from_static(&[ + 0x07, b'b', b'e', b'a', b'c', b'o', b'n', b's', + 0x04, b'g', b'v', b't', b'2', + 0x03, b'c', b'o', b'm', + 0x00, 0x00, 0x41, 0x00, 0x01, // type: HTTPS, class: IN + ]); + let packet = DnsQueryPacket::from_bytes(bytes).unwrap(); + assert_eq!( + packet.qname, + vec![ + 0x07, b'b', b'e', b'a', b'c', b'o', b'n', b's', + 0x04, b'g', b'v', b't', b'2', + 0x03, b'c', b'o', b'm', + 0x00 + ] + ); + assert_eq!(packet.qtype, DnsType::HTTPS); + assert_eq!(packet.qclass, DnsClass::IN); + } + + #[test] + fn test_dns_response() { + let bytes = Bytes::from_static(&[ + 0xc0, 0x0c, // name_tag + 0x00, 0x01, // type = A + 0x00, 0x01, // class = IN + 0x00, 0x00, 0x00, 0x3c, // TTL = 60 + 0x00, 0x04, // data_len = 4 + 0x0d, 0xe2, 0x02, 0x12, // data + ]); + let packet = DnsResponsePacket::from_bytes(bytes).unwrap(); + assert_eq!(packet.rtype, DnsType::A); + assert_eq!(packet.rclass, DnsClass::IN); + assert_eq!(packet.ttl, 60); + assert_eq!(packet.data_len, 4); + assert_eq!(packet.data, vec![13, 226, 2, 18]); + } + + #[test] + fn test_dns_query_packet() { + let bytes = Bytes::from_static(&[ + 0x9b, 0xa0, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x05, b'_', b'l', b'd', b'a', b'p', + 0x04, b'_', b't', b'c', b'p', + 0x02, b'd', b'c', + 0x06, b'_', b'm', b's', b'd', b'c', b's', + 0x05, b'S', b'4', b'D', b'O', b'M', + 0x07, b'P', b'R', b'I', b'V', b'A', b'T', b'E', + 0x00, 0x00, 0x21, 0x00, 0x01, + ]); + let packet = DnsPacket::from_bytes(bytes).unwrap(); + assert_eq!(packet.header.id, 0x9ba0); + assert_eq!(packet.header.is_response, 0); + assert_eq!(packet.header.query_count, 1); + assert_eq!(packet.queries.len(), 1); + assert_eq!( + packet.queries[0].get_qname_parsed().unwrap(), + "_ldap._tcp.dc._msdcs.S4DOM.PRIVATE" + ); + assert_eq!(packet.queries[0].qtype, DnsType::SRV); + assert_eq!(packet.queries[0].qclass, DnsClass::IN); + } + #[test] + fn test_dns_response_packet() { + let bytes = Bytes::from_static(&[ + 0xbc, 0x12, 0x85, 0x80, 0x00, 0x01, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x00, 0x05, b's', b'4', b'd', b'c', b'1', + 0x05, b's', b'a', b'm', b'b', b'a', + 0x08, b'w', b'i', b'n', b'd', b'o', b'w', b's', b'8', + 0x07, b'p', b'r', b'i', b'v', b'a', b't', b'e', + 0x00, 0x00, 0x01, 0x00, 0x01, + 0xc0, 0x0c, 0x00, 0x01, 0x00, 0x01, + 0x00, 0x00, 0x03, 0x84, + 0x00, 0x04, 0xc0, 0xa8, 0x7a, 0xbd, + ]); + let packet = DnsPacket::from_bytes(bytes).unwrap(); + assert_eq!(packet.header.id, 0xbc12); + assert_eq!(packet.header.is_response, 1); + assert_eq!(packet.header.query_count, 1); + assert_eq!(packet.header.response_count, 1); + assert_eq!(packet.queries.len(), 1); + assert_eq!( + packet.queries[0].get_qname_parsed().unwrap(), + "s4dc1.samba.windows8.private" + ); + assert_eq!(packet.queries[0].qtype, DnsType::A); + assert_eq!(packet.responses[0].rtype, DnsType::A); + assert_eq!(packet.responses[0].rclass, DnsClass::IN); + assert_eq!(packet.responses[0].ttl, 900); + assert_eq!(packet.responses[0].data_len, 4); + assert_eq!(packet.responses[0].data, vec![192, 168, 122, 189]); + } +} \ No newline at end of file diff --git a/nex-packet/src/ethernet.rs b/nex-packet/src/ethernet.rs index eeaf2f7..a4fde73 100644 --- a/nex-packet/src/ethernet.rs +++ b/nex-packet/src/ethernet.rs @@ -1,99 +1,22 @@ //! An ethernet packet abstraction. -use crate::PrimitiveValues; - -use alloc::vec::Vec; use core::fmt; - +use bytes::Bytes; use nex_core::mac::MacAddr; -use nex_macro::packet; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use crate::packet::Packet; + /// Represents the Ethernet header length. pub const ETHERNET_HEADER_LEN: usize = 14; /// Represents the MAC address length. pub const MAC_ADDR_LEN: usize = 6; -/// Represents the Ethernet Header. -#[derive(Clone, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct EthernetHeader { - /// Destination MAC address - pub destination: MacAddr, - /// Source MAC address - pub source: MacAddr, - /// EtherType - pub ethertype: EtherType, -} - -impl EthernetHeader { - /// Construct an Ethernet header from a byte slice. - pub fn from_bytes(packet: &[u8]) -> Result { - if packet.len() < ETHERNET_HEADER_LEN { - return Err("Packet is too small for Ethernet header".to_string()); - } - match EthernetPacket::new(packet) { - Some(ethernet_packet) => Ok(EthernetHeader { - destination: ethernet_packet.get_destination(), - source: ethernet_packet.get_source(), - ethertype: ethernet_packet.get_ethertype(), - }), - None => Err("Failed to parse Ethernet packet".to_string()), - } - } - /// Construct an Ethernet header from a EthernetPacket. - pub(crate) fn from_packet(ethernet_packet: &EthernetPacket) -> EthernetHeader { - EthernetHeader { - destination: ethernet_packet.get_destination(), - source: ethernet_packet.get_source(), - ethertype: ethernet_packet.get_ethertype(), - } - } -} - -/// Represents an Ethernet packet. -#[packet] -pub struct Ethernet { - #[construct_with(u8, u8, u8, u8, u8, u8)] - pub destination: MacAddr, - #[construct_with(u8, u8, u8, u8, u8, u8)] - pub source: MacAddr, - #[construct_with(u16)] - pub ethertype: EtherType, - #[payload] - pub payload: Vec, -} - -#[test] -fn ethernet_header_test() { - let mut packet = [0u8; 14]; - { - let mut ethernet_header = MutableEthernetPacket::new(&mut packet[..]).unwrap(); - - let source = MacAddr(0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc); - ethernet_header.set_source(source); - assert_eq!(ethernet_header.get_source(), source); - - let dest = MacAddr(0xde, 0xf0, 0x12, 0x34, 0x45, 0x67); - ethernet_header.set_destination(dest); - assert_eq!(ethernet_header.get_destination(), dest); - - ethernet_header.set_ethertype(EtherType::Ipv6); - assert_eq!(ethernet_header.get_ethertype(), EtherType::Ipv6); - } - - let ref_packet = [ - 0xde, 0xf0, 0x12, 0x34, 0x45, 0x67, /* destination */ - 0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, /* source */ - 0x86, 0xdd, /* ethertype */ - ]; - assert_eq!(&ref_packet[..], &packet[..]); -} - /// Represents the Ethernet types. +#[repr(u16)] #[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum EtherType { @@ -185,84 +108,251 @@ impl EtherType { EtherType::Unknown(_) => "Unknown", } } -} - -impl PrimitiveValues for EtherType { - type T = (u16,); - fn to_primitive_values(&self) -> (u16,) { + pub fn value(&self) -> u16 { match *self { - EtherType::Ipv4 => (0x0800,), - EtherType::Arp => (0x0806,), - EtherType::WakeOnLan => (0x0842,), - EtherType::Trill => (0x22F3,), - EtherType::DECnet => (0x6003,), - EtherType::Rarp => (0x8035,), - EtherType::AppleTalk => (0x809B,), - EtherType::Aarp => (0x80F3,), - EtherType::Ipx => (0x8137,), - EtherType::Qnx => (0x8204,), - EtherType::Ipv6 => (0x86DD,), - EtherType::FlowControl => (0x8808,), - EtherType::CobraNet => (0x8819,), - EtherType::Mpls => (0x8847,), - EtherType::MplsMcast => (0x8848,), - EtherType::PppoeDiscovery => (0x8863,), - EtherType::PppoeSession => (0x8864,), - EtherType::Vlan => (0x8100,), - EtherType::PBridge => (0x88a8,), - EtherType::Lldp => (0x88cc,), - EtherType::Ptp => (0x88f7,), - EtherType::Cfm => (0x8902,), - EtherType::QinQ => (0x9100,), - EtherType::Rldp => (0x8899,), - EtherType::Unknown(n) => (n,), + EtherType::Ipv4 => 0x0800, + EtherType::Arp => 0x0806, + EtherType::WakeOnLan => 0x0842, + EtherType::Trill => 0x22F3, + EtherType::DECnet => 0x6003, + EtherType::Rarp => 0x8035, + EtherType::AppleTalk => 0x809B, + EtherType::Aarp => 0x80F3, + EtherType::Ipx => 0x8137, + EtherType::Qnx => 0x8204, + EtherType::Ipv6 => 0x86DD, + EtherType::FlowControl => 0x8808, + EtherType::CobraNet => 0x8819, + EtherType::Mpls => 0x8847, + EtherType::MplsMcast => 0x8848, + EtherType::PppoeDiscovery => 0x8863, + EtherType::PppoeSession => 0x8864, + EtherType::Vlan => 0x8100, + EtherType::PBridge => 0x88a8, + EtherType::Lldp => 0x88cc, + EtherType::Ptp => 0x88f7, + EtherType::Cfm => 0x8902, + EtherType::QinQ => 0x9100, + EtherType::Rldp => 0x8899, + EtherType::Unknown(value) => value, } } } impl fmt::Display for EtherType { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.name()) + } +} + +/// Represents the Ethernet Header. +#[derive(Clone, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct EthernetHeader { + /// Destination MAC address + pub destination: MacAddr, + /// Source MAC address + pub source: MacAddr, + /// EtherType + pub ethertype: EtherType, +} + +impl EthernetHeader { + /// Construct an Ethernet header from a byte slice. + pub fn from_bytes(packet: Bytes) -> Result { + if packet.len() < ETHERNET_HEADER_LEN { + return Err("Packet is too small for Ethernet header".to_string()); + } + match EthernetPacket::from_bytes(packet) { + Some(ethernet_packet) => Ok(EthernetHeader { + destination: ethernet_packet.get_destination(), + source: ethernet_packet.get_source(), + ethertype: ethernet_packet.get_ethertype(), + }), + None => Err("Failed to parse Ethernet packet".to_string()), + } + } + pub fn to_bytes(&self) -> Bytes { + let mut buf = Vec::with_capacity(ETHERNET_HEADER_LEN); + buf.extend_from_slice(&self.destination.octets()); + buf.extend_from_slice(&self.source.octets()); + buf.extend_from_slice(&self.ethertype.value().to_be_bytes()); + Bytes::from(buf) + } +} + +/// Represents an Ethernet packet. +#[derive(Clone, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct EthernetPacket { + /// The Ethernet header. + pub header: EthernetHeader, + pub payload: Bytes, +} + +impl Packet for EthernetPacket { + type Header = EthernetHeader; + + fn from_buf(bytes: &[u8]) -> Option { + if bytes.len() < ETHERNET_HEADER_LEN { + return None; + } + let destination = MacAddr::from_octets(bytes[0..MAC_ADDR_LEN].try_into().unwrap()); + let source = MacAddr::from_octets(bytes[MAC_ADDR_LEN..2 * MAC_ADDR_LEN].try_into().unwrap()); + let ethertype = EtherType::new(u16::from_be_bytes([bytes[12], bytes[13]])); + let payload = Bytes::copy_from_slice(&bytes[ETHERNET_HEADER_LEN..]); + + Some(EthernetPacket { + header: EthernetHeader { + destination, + source, + ethertype, + }, + payload, + }) + } + fn from_bytes(bytes: Bytes) -> Option { + Self::from_buf(&bytes) + } + fn to_bytes(&self) -> Bytes { + let mut buf = Vec::with_capacity(ETHERNET_HEADER_LEN + self.payload.len()); + buf.extend_from_slice(&self.header.to_bytes()); + buf.extend_from_slice(&self.payload); + Bytes::from(buf) + } + fn header(&self) -> Bytes { + self.header.to_bytes() + } + fn payload(&self) -> Bytes { + self.payload.clone() + } + fn header_len(&self) -> usize { + ETHERNET_HEADER_LEN + } + fn payload_len(&self) -> usize { + self.payload.len() + } + + fn total_len(&self) -> usize { + self.header_len() + self.payload_len() + } + + fn into_parts(self) -> (Self::Header, Bytes) { + (self.header, self.payload) + } +} + +impl EthernetPacket { + /// Create a new Ethernet packet. + pub fn new(header: EthernetHeader, payload: Bytes) -> Self { + EthernetPacket { header, payload } + } + /// Get the destination MAC address. + pub fn get_destination(&self) -> MacAddr { + self.header.destination + } + + /// Get the source MAC address. + pub fn get_source(&self) -> MacAddr { + self.header.source + } + + /// Get the EtherType. + pub fn get_ethertype(&self) -> EtherType { + self.header.ethertype + } + + pub fn ip_packet(&self) -> Option { + if self.get_ethertype() == EtherType::Ipv4 || self.get_ethertype() == EtherType::Ipv6 { + Some(self.payload.clone()) + } else { + None + } + } +} + +impl fmt::Display for EthernetPacket { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "{}", - match self { - EtherType::Ipv4 => "Ipv4", - EtherType::Arp => "Arp", - EtherType::WakeOnLan => "WakeOnLan", - EtherType::Trill => "Trill", - EtherType::DECnet => "DECnet", - EtherType::Rarp => "Rarp", - EtherType::AppleTalk => "AppleTalk", - EtherType::Aarp => "Aarp", - EtherType::Ipx => "Ipx", - EtherType::Qnx => "Qnx", - EtherType::Ipv6 => "Ipv6", - EtherType::FlowControl => "FlowControl", - EtherType::CobraNet => "CobraNet", - EtherType::Mpls => "Mpls", - EtherType::MplsMcast => "MplsMcast", - EtherType::PppoeDiscovery => "PppoeDiscovery", - EtherType::PppoeSession => "PppoeSession", - EtherType::Vlan => "Vlan", - EtherType::PBridge => "PBridge", - EtherType::Lldp => "Lldp", - EtherType::Ptp => "Ptp", - EtherType::Cfm => "Cfm", - EtherType::QinQ => "QinQ", - EtherType::Rldp => "Rldp", - EtherType::Unknown(_) => "unknown", - } + "EthernetPacket {{ destination: {}, source: {}, ethertype: {} }}", + self.get_destination(), + self.get_source(), + self.get_ethertype() ) } } -#[test] -fn ether_type_to_str() { - use std::format; - let ipv4 = EtherType::new(0x0800); - assert_eq!(format!("{}", ipv4), "Ipv4"); - let arp = EtherType::new(0x0806); - assert_eq!(format!("{}", arp), "Arp"); - let unknown = EtherType::new(0x0666); - assert_eq!(format!("{}", unknown), "unknown"); +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + use nex_core::mac::MacAddr; + + #[test] + fn test_ethernet_parse_basic() { + let raw = [ + 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, // dst + 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, // src + 0x08, 0x00, // EtherType: IPv4 + 0xde, 0xad, 0xbe, 0xef // Payload (dummy) + ]; + let packet = EthernetPacket::from_bytes(Bytes::copy_from_slice(&raw)).unwrap(); + assert_eq!(packet.get_destination(), MacAddr::from_octets([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff])); + assert_eq!(packet.get_source(), MacAddr::from_octets([0x11, 0x22, 0x33, 0x44, 0x55, 0x66])); + assert_eq!(packet.get_ethertype(), EtherType::Ipv4); + assert_eq!(packet.payload.len(), 4); + } + + #[test] + fn test_ethernet_serialize_roundtrip() { + let original = EthernetPacket { + header: EthernetHeader { + destination: MacAddr::from_octets([1, 2, 3, 4, 5, 6]), + source: MacAddr::from_octets([10, 20, 30, 40, 50, 60]), + ethertype: EtherType::Arp, + }, + payload: Bytes::from_static(&[0xde, 0xad, 0xbe, 0xef]), + }; + + let bytes = original.to_bytes(); + let parsed = EthernetPacket::from_bytes(bytes).unwrap(); + + assert_eq!(parsed, original); + } + + #[test] + fn test_ethernet_header_parse_and_serialize() { + let header = EthernetHeader { + destination: MacAddr::from_octets([1, 1, 1, 1, 1, 1]), + source: MacAddr::from_octets([2, 2, 2, 2, 2, 2]), + ethertype: EtherType::Ipv6, + }; + let bytes = header.to_bytes(); + let parsed = EthernetHeader::from_bytes(bytes.clone()).unwrap(); + + assert_eq!(header, parsed); + assert_eq!(bytes.len(), ETHERNET_HEADER_LEN); + } + + #[test] + fn test_ethernet_parse_too_short() { + let short = Bytes::from_static(&[0, 1, 2, 3]); // insufficient length + assert!(EthernetPacket::from_bytes(short).is_none()); + } + + #[test] + fn test_ethernet_unknown_ethertype() { + let raw = [ + 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, + 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, + 0xde, 0xad, // Unknown EtherType + 0x00, 0x11, 0x22, 0x33 + ]; + let packet = EthernetPacket::from_bytes(Bytes::copy_from_slice(&raw)).unwrap(); + match packet.get_ethertype() { + EtherType::Unknown(val) => assert_eq!(val, 0xdead), + _ => panic!("Expected unknown EtherType"), + } + } } diff --git a/nex-packet/src/flowcontrol.rs b/nex-packet/src/flowcontrol.rs new file mode 100644 index 0000000..f8afecc --- /dev/null +++ b/nex-packet/src/flowcontrol.rs @@ -0,0 +1,133 @@ +//! Ethernet Flow Control \[IEEE 802.3x\] abstraction. +use core::fmt; + +use bytes::{Buf, BufMut, Bytes}; +use nex_core::bitfield::u16be; + +use crate::packet::Packet; + +/// Represents the opcode field in an Ethernet Flow Control packet. +/// +/// Flow control opcodes are defined in IEEE 802.3x +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(u16)] +pub enum FlowControlOpcode { + Pause = 0x0001, + Unknown(u16), +} + +impl FlowControlOpcode { + pub fn new(value: u16) -> Self { + match value { + 0x0001 => FlowControlOpcode::Pause, + other => FlowControlOpcode::Unknown(other), + } + } + + pub fn value(&self) -> u16 { + match *self { + FlowControlOpcode::Pause => 0x0001, + FlowControlOpcode::Unknown(v) => v, + } + } +} + +impl fmt::Display for FlowControlOpcode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", match self { + FlowControlOpcode::Pause => "pause", + FlowControlOpcode::Unknown(_) => "unknown", + }) + } +} + +/// Represents an Ethernet Flow Control packet defined by IEEE 802.3x. +/// +/// [EtherTypes::FlowControl](crate::ethernet::EtherTypes::FlowControl) ethertype (0x8808). +pub struct FlowControlPacket { + pub command: FlowControlOpcode, + pub quanta: u16be, + pub payload: Bytes, +} + +impl Packet for FlowControlPacket { + type Header = (); + fn from_buf(mut bytes: &[u8]) -> Option { + if bytes.len() < 4 { + return None; + } + + let command = FlowControlOpcode::new(bytes.get_u16()); + let quanta = bytes.get_u16(); + + // Payload including padding; its contents are not specified by the standard + let payload = Bytes::copy_from_slice(bytes); + + Some(Self { + command, + quanta: quanta.into(), + payload, + }) + } + + fn from_bytes(bytes: Bytes) -> Option { + Self::from_buf(&bytes) + } + + fn to_bytes(&self) -> Bytes { + let mut buf = bytes::BytesMut::with_capacity(4 + self.payload.len()); + + buf.put_u16(self.command.value()); + buf.put_u16(self.quanta.into()); + buf.put_slice(&self.payload); + + buf.freeze() + } + fn header(&self) -> Bytes { + let mut buf = bytes::BytesMut::with_capacity(4); + + buf.put_u16(self.command.value()); + buf.put_u16(self.quanta.into()); + + buf.freeze() + } + + fn payload(&self) -> Bytes { + self.payload.clone() + } + + fn header_len(&self) -> usize { + 4 + } + + fn payload_len(&self) -> usize { + self.payload.len() + } + + fn total_len(&self) -> usize { + self.header_len() + self.payload_len() + } + + fn into_parts(self) -> (Self::Header, Bytes) { + ((), self.to_bytes()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn flowcontrol_pause_test() { + let packet = Bytes::from_static(&[ + 0x00, 0x01, // Opcode: Pause + 0x12, 0x34, // Quanta: 0x1234 + 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, // Padding ... + ]); + + let fc_packet = FlowControlPacket::from_bytes(packet.clone()).unwrap(); + assert_eq!(fc_packet.command, FlowControlOpcode::Pause); + assert_eq!(fc_packet.quanta, 0x1234); + assert_eq!(fc_packet.to_bytes(), packet); + } +} diff --git a/nex-packet/src/frame.rs b/nex-packet/src/frame.rs index 9fb8fa1..1e644d9 100644 --- a/nex-packet/src/frame.rs +++ b/nex-packet/src/frame.rs @@ -1,21 +1,10 @@ +use bytes::Bytes; use nex_core::mac::MacAddr; -use nex_macro_helper::packet::Packet; -use crate::arp::{ArpHeader, ArpPacket}; -use crate::ethernet::EthernetHeader; -use crate::ethernet::{EtherType, EthernetPacket, MutableEthernetPacket}; -use crate::icmp::{IcmpHeader, IcmpPacket}; -use crate::icmpv6::{Icmpv6Header, Icmpv6Packet}; -use crate::ip::IpNextLevelProtocol; -use crate::ipv4::{Ipv4Header, Ipv4Packet}; -use crate::ipv6::{Ipv6Header, Ipv6Packet}; -use crate::tcp::{TcpHeader, TcpPacket}; -use crate::udp::{UdpHeader, UdpPacket}; +use crate::{arp::{ArpHeader, ArpPacket}, ethernet::{EtherType, EthernetHeader, EthernetPacket}, icmp::{IcmpHeader, IcmpPacket}, icmpv6::{Icmpv6Header, Icmpv6Packet}, ip::IpNextProtocol, ipv4::{Ipv4Header, Ipv4Packet}, ipv6::{Ipv6Header, Ipv6Packet}, packet::Packet, tcp::{TcpHeader, TcpPacket}, udp::{UdpHeader, UdpPacket}}; + -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; -/// Represents a data link layer. #[derive(Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct DatalinkLayer { @@ -23,7 +12,6 @@ pub struct DatalinkLayer { pub arp: Option, } -/// Represents an IP layer. #[derive(Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct IpLayer { @@ -33,7 +21,6 @@ pub struct IpLayer { pub icmpv6: Option, } -/// Represents a transport layer. #[derive(Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct TransportLayer { @@ -41,172 +28,129 @@ pub struct TransportLayer { pub udp: Option, } -/// Parse options. #[derive(Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct ParseOption { - /// Parse from IP packet. pub from_ip_packet: bool, - /// Offset of the packet. - /// If `from_ip_packet` is true, this value is the offset of the IP packet. pub offset: usize, } -impl ParseOption { - /// Construct a new ParseOption. - pub fn new(from_ip_packet: bool, offset: usize) -> ParseOption { - ParseOption { - from_ip_packet, - offset, - } - } -} - impl Default for ParseOption { fn default() -> Self { - ParseOption { - from_ip_packet: false, - offset: 0, - } + Self { from_ip_packet: false, offset: 0 } } } -/// Represents a packet frame. #[derive(Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Frame { - /// The datalink layer. pub datalink: Option, - /// The IP layer. pub ip: Option, - /// The transport layer. pub transport: Option, - /// Rest of the packet that could not be parsed as a header. (Usually payload) - pub payload: Vec, - /// Packet length. + pub payload: Bytes, pub packet_len: usize, } impl Frame { - /// Construct a frame from a byte slice. - pub fn from_bytes(packet: &[u8], option: ParseOption) -> Frame { - parse_packet(packet, option) - } -} + pub fn from_buf(packet: &[u8], option: ParseOption) -> Option { + let mut frame = Frame { + datalink: None, + ip: None, + transport: None, + payload: Bytes::new(), + packet_len: packet.len(), + }; -fn create_dummy_ethernet_packet(packet: &[u8], offset: usize) -> Vec { - let mut buf: Vec = vec![0u8; packet.len() - offset + 14]; - match MutableEthernetPacket::new(&mut buf[..]) { - Some(mut fake_ethernet_frame) => match Ipv4Packet::new(&packet[offset..]) { - Some(ipv4_packet) => { - let version: u8 = ipv4_packet.get_version(); - if version == 4 { - fake_ethernet_frame.set_destination(MacAddr(0, 0, 0, 0, 0, 0)); - fake_ethernet_frame.set_source(MacAddr(0, 0, 0, 0, 0, 0)); - fake_ethernet_frame.set_ethertype(EtherType::Ipv4); - fake_ethernet_frame.set_payload(&packet[offset..]); - } else if version == 6 { - fake_ethernet_frame.set_destination(MacAddr(0, 0, 0, 0, 0, 0)); - fake_ethernet_frame.set_source(MacAddr(0, 0, 0, 0, 0, 0)); - fake_ethernet_frame.set_ethertype(EtherType::Ipv6); - fake_ethernet_frame.set_payload(&packet[offset..]); - } - return fake_ethernet_frame.packet().to_vec(); - } - None => { - return Vec::new(); - } - }, - None => { - return Vec::new(); + let ethernet_packet = if option.from_ip_packet { + create_dummy_ethernet_packet(packet, option.offset)? + } else { + EthernetPacket::from_buf(packet)? + }; + + let ether_type = ethernet_packet.get_ethertype(); + let (ether_header, ether_payload) = ethernet_packet.into_parts(); + frame.datalink = Some(DatalinkLayer { + ethernet: Some(ether_header), + arp: None, + }); + + match ether_type { + EtherType::Ipv4 => parse_ipv4_packet(ether_payload, &mut frame), + EtherType::Ipv6 => parse_ipv6_packet(ether_payload, &mut frame), + EtherType::Arp => parse_arp_packet(ether_payload, &mut frame), + _ => {} } + + Some(frame) } } -fn parse_packet(packet: &[u8], option: ParseOption) -> Frame { - let mut frame = Frame { - datalink: None, - ip: None, - transport: None, - payload: Vec::new(), - packet_len: packet.len(), - }; - let dummy_ethernet_packet: Vec; - let ethernet_packet = if option.from_ip_packet { - dummy_ethernet_packet = create_dummy_ethernet_packet(packet, option.offset); - match EthernetPacket::new(&dummy_ethernet_packet) { - Some(ethernet_packet) => ethernet_packet, - None => { - return frame; - } - } +pub fn create_dummy_ethernet_packet(packet: &[u8], offset: usize) -> Option { + if offset >= packet.len() { + return None; + } + + let payload = &packet[offset..]; + + let ethertype = if Ipv4Packet::from_buf(payload).is_some() { + EtherType::Ipv4 + } else if Ipv6Packet::from_buf(payload).is_some() { + EtherType::Ipv6 } else { - match EthernetPacket::new(packet) { - Some(ethernet_packet) => ethernet_packet, - None => { - return frame; - } - } + return None; }; - let ethernet_header = EthernetHeader::from_packet(ðernet_packet); - frame.datalink = Some(DatalinkLayer { - ethernet: Some(ethernet_header), - arp: None, - }); - match ethernet_packet.get_ethertype() { - EtherType::Ipv4 => { - parse_ipv4_packet(ðernet_packet, &mut frame); - } - EtherType::Ipv6 => { - parse_ipv6_packet(ðernet_packet, &mut frame); - } - EtherType::Arp => { - parse_arp_packet(ðernet_packet, &mut frame); - } - _ => {} - } - frame + + let header = EthernetHeader { + destination: MacAddr::zero(), + source: MacAddr::zero(), + ethertype, + }; + + Some(EthernetPacket { + header, + payload: Bytes::copy_from_slice(payload), + }) } -fn parse_arp_packet(ethernet_packet: &EthernetPacket, frame: &mut Frame) { - match ArpPacket::new(ethernet_packet.payload()) { +fn parse_arp_packet(packet: Bytes, frame: &mut Frame) { + match ArpPacket::from_buf(&packet) { Some(arp_packet) => { - let arp_header = ArpHeader::from_packet(&arp_packet); if let Some(datalink) = &mut frame.datalink { - datalink.arp = Some(arp_header); + datalink.arp = Some(arp_packet.header); } } None => { if let Some(datalink) = &mut frame.datalink { datalink.arp = None; } - frame.payload = ethernet_packet.payload().to_vec(); + frame.payload = packet; } } } -fn parse_ipv4_packet(ethernet_packet: &EthernetPacket, frame: &mut Frame) { - match Ipv4Packet::new(ethernet_packet.payload()) { +fn parse_ipv4_packet(packet: Bytes, frame: &mut Frame) { + match Ipv4Packet::from_bytes(packet) { Some(ipv4_packet) => { - let ipv4_header = Ipv4Header::from_packet(&ipv4_packet); + let (header, payload) = ipv4_packet.into_parts(); + let proto = header.next_level_protocol; frame.ip = Some(IpLayer { - ipv4: Some(ipv4_header), + ipv4: Some(header), ipv6: None, icmp: None, icmpv6: None, }); - match ipv4_packet.get_next_level_protocol() { - IpNextLevelProtocol::Tcp => { - parse_ipv4_tcp_packet(&ipv4_packet, frame); + match proto { + IpNextProtocol::Tcp => { + parse_tcp_packet(payload, frame); } - IpNextLevelProtocol::Udp => { - parse_ipv4_udp_packet(&ipv4_packet, frame); + IpNextProtocol::Udp => { + parse_udp_packet(payload, frame); } - IpNextLevelProtocol::Icmp => { - parse_icmp_packet(&ipv4_packet, frame); + IpNextProtocol::Icmp => { + parse_icmp_packet(payload, frame); } _ => { - frame.payload = ipv4_packet.payload().to_vec(); + frame.payload = payload; } } } @@ -221,28 +165,29 @@ fn parse_ipv4_packet(ethernet_packet: &EthernetPacket, frame: &mut Frame) { } } -fn parse_ipv6_packet(ethernet_packet: &EthernetPacket, frame: &mut Frame) { - match Ipv6Packet::new(ethernet_packet.payload()) { +fn parse_ipv6_packet(packet: Bytes, frame: &mut Frame) { + match Ipv6Packet::from_bytes(packet) { Some(ipv6_packet) => { - let ipv6_header = Ipv6Header::from_packet(&ipv6_packet); + let (header, payload) = ipv6_packet.into_parts(); + let proto = header.next_header; frame.ip = Some(IpLayer { ipv4: None, - ipv6: Some(ipv6_header), + ipv6: Some(header), icmp: None, icmpv6: None, }); - match ipv6_packet.get_next_header() { - IpNextLevelProtocol::Tcp => { - parse_ipv6_tcp_packet(&ipv6_packet, frame); + match proto { + IpNextProtocol::Tcp => { + parse_tcp_packet(payload, frame); } - IpNextLevelProtocol::Udp => { - parse_ipv6_udp_packet(&ipv6_packet, frame); + IpNextProtocol::Udp => { + parse_udp_packet(payload, frame); } - IpNextLevelProtocol::Icmpv6 => { - parse_icmpv6_packet(&ipv6_packet, frame); + IpNextProtocol::Icmpv6 => { + parse_icmpv6_packet(payload, frame); } _ => { - frame.payload = ipv6_packet.payload().to_vec(); + frame.payload = payload; } } } @@ -257,118 +202,78 @@ fn parse_ipv6_packet(ethernet_packet: &EthernetPacket, frame: &mut Frame) { } } -fn parse_ipv4_tcp_packet(ipv4_packet: &Ipv4Packet, frame: &mut Frame) { - match TcpPacket::new(ipv4_packet.payload()) { - Some(tcp_packet) => { - let tcp_header = TcpHeader::from_packet(&tcp_packet); - frame.transport = Some(TransportLayer { - tcp: Some(tcp_header), - udp: None, - }); - frame.payload = tcp_packet.payload().to_vec(); - } - None => { - frame.transport = Some(TransportLayer { - tcp: None, - udp: None, - }); - frame.payload = ipv4_packet.payload().to_vec(); - } - } -} - -fn parse_ipv6_tcp_packet(ipv6_packet: &Ipv6Packet, frame: &mut Frame) { - match TcpPacket::new(ipv6_packet.payload()) { +fn parse_tcp_packet(packet: Bytes, frame: &mut Frame) { + match TcpPacket::from_bytes(packet.clone()) { Some(tcp_packet) => { - let tcp_header = TcpHeader::from_packet(&tcp_packet); + let (header, payload) = tcp_packet.into_parts(); frame.transport = Some(TransportLayer { - tcp: Some(tcp_header), + tcp: Some(header), udp: None, }); - frame.payload = tcp_packet.payload().to_vec(); - } - None => { - frame.transport = Some(TransportLayer { - tcp: None, - udp: None, - }); - frame.payload = ipv6_packet.payload().to_vec(); - } - } -} - -fn parse_ipv4_udp_packet(ipv4_packet: &Ipv4Packet, frame: &mut Frame) { - match UdpPacket::new(ipv4_packet.payload()) { - Some(udp_packet) => { - let udp_header = UdpHeader::from_packet(&udp_packet); - frame.transport = Some(TransportLayer { - tcp: None, - udp: Some(udp_header), - }); - frame.payload = udp_packet.payload().to_vec(); + frame.payload = payload; } None => { frame.transport = Some(TransportLayer { tcp: None, udp: None, }); - frame.payload = ipv4_packet.payload().to_vec(); + frame.payload = packet; } } } -fn parse_ipv6_udp_packet(ipv6_packet: &Ipv6Packet, frame: &mut Frame) { - match UdpPacket::new(ipv6_packet.payload()) { +fn parse_udp_packet(packet: Bytes, frame: &mut Frame) { + match UdpPacket::from_bytes(packet.clone()) { Some(udp_packet) => { - let udp_header = UdpHeader::from_packet(&udp_packet); + let (header, payload) = udp_packet.into_parts(); frame.transport = Some(TransportLayer { tcp: None, - udp: Some(udp_header), + udp: Some(header), }); - frame.payload = udp_packet.payload().to_vec(); + frame.payload = payload; } None => { frame.transport = Some(TransportLayer { tcp: None, udp: None, }); - frame.payload = ipv6_packet.payload().to_vec(); + frame.payload = packet; } } } -fn parse_icmp_packet(ipv4_packet: &Ipv4Packet, frame: &mut Frame) { - match IcmpPacket::new(ipv4_packet.payload()) { +fn parse_icmp_packet(packet: Bytes, frame: &mut Frame) { + match IcmpPacket::from_bytes(packet.clone()) { Some(icmp_packet) => { - let icmp_header = IcmpHeader::from_packet(&icmp_packet); + let (header, payload) = icmp_packet.into_parts(); if let Some(ip) = &mut frame.ip { - ip.icmp = Some(icmp_header); + ip.icmp = Some(header); } - frame.payload = icmp_packet.payload().to_vec(); + frame.payload = payload; } None => { if let Some(ip) = &mut frame.ip { ip.icmp = None; } - frame.payload = ipv4_packet.payload().to_vec(); + frame.payload = packet; } } } -fn parse_icmpv6_packet(ipv6_packet: &Ipv6Packet, frame: &mut Frame) { - match Icmpv6Packet::new(ipv6_packet.payload()) { +fn parse_icmpv6_packet(packet: Bytes, frame: &mut Frame) { + match Icmpv6Packet::from_bytes(packet.clone()) { Some(icmpv6_packet) => { - let icmpv6_header = Icmpv6Header::from_packet(&icmpv6_packet); + let (header, payload) = icmpv6_packet.into_parts(); if let Some(ip) = &mut frame.ip { - ip.icmpv6 = Some(icmpv6_header); + ip.icmpv6 = Some(header); } - frame.payload = icmpv6_packet.payload().to_vec(); + frame.payload = payload; } None => { if let Some(ip) = &mut frame.ip { ip.icmpv6 = None; } - frame.payload = ipv6_packet.payload().to_vec(); + frame.payload = packet; } } } diff --git a/nex-packet/src/gre.rs b/nex-packet/src/gre.rs index 0d5a969..136cfae 100644 --- a/nex-packet/src/gre.rs +++ b/nex-packet/src/gre.rs @@ -1,18 +1,14 @@ //! GRE Packet abstraction. -#[cfg(test)] -use crate::Packet; - -use alloc::vec::Vec; - -use nex_macro::packet; -use nex_macro_helper::types::*; +use bytes::{Buf, Bytes}; +use nex_core::bitfield::{u1, u16be, u3, u32be, u5}; +use crate::packet::Packet; /// GRE (Generic Routing Encapsulation) Packet. /// /// See RFCs 1701, 2784, 2890, 7676, 2637 -#[packet] -pub struct Gre { +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct GrePacket { pub checksum_present: u1, pub routing_present: u1, pub key_present: u1, @@ -21,101 +17,272 @@ pub struct Gre { pub recursion_control: u3, pub zero_flags: u5, pub version: u3, - pub protocol_type: u16be, // 0x800 for ipv4 [basically an ethertype - #[length_fn = "gre_checksum_length"] - pub checksum: Vec, - #[length_fn = "gre_offset_length"] - pub offset: Vec, - #[length_fn = "gre_key_length"] - pub key: Vec, - #[length_fn = "gre_sequence_length"] - pub sequence: Vec, - #[length_fn = "gre_routing_length"] + pub protocol_type: u16be, // 0x800 for IPv4 + pub checksum: Vec, + pub offset: Vec, + pub key: Vec, + pub sequence: Vec, pub routing: Vec, - #[payload] - pub payload: Vec, + pub payload: Bytes, } -fn gre_checksum_length(gre: &GrePacket) -> usize { - (gre.get_checksum_present() | gre.get_routing_present()) as usize * 2 -} +impl Packet for GrePacket { + type Header = (); -fn gre_offset_length(gre: &GrePacket) -> usize { - (gre.get_checksum_present() | gre.get_routing_present()) as usize * 2 -} + fn from_buf(mut bytes: &[u8]) -> Option { + if bytes.remaining() < 4 { + return None; + } -fn gre_key_length(gre: &GrePacket) -> usize { - gre.get_key_present() as usize * 4 -} + let flags = bytes.get_u16(); + let protocol_type = bytes.get_u16(); -fn gre_sequence_length(gre: &GrePacket) -> usize { - gre.get_sequence_present() as usize * 4 -} + let checksum_present = ((flags >> 15) & 0x1) as u1; + let routing_present = ((flags >> 14) & 0x1) as u1; + let key_present = ((flags >> 13) & 0x1) as u1; + let sequence_present = ((flags >> 12) & 0x1) as u1; + let strict_source_route = ((flags >> 11) & 0x1) as u1; + let recursion_control = ((flags >> 8) & 0x7) as u3; + let zero_flags = ((flags >> 3) & 0x1f) as u5; + let version = (flags & 0x7) as u3; + + // Retrieve optional fields in order + let mut checksum = Vec::new(); + let mut offset = Vec::new(); + let mut key = Vec::new(); + let mut sequence = Vec::new(); + let routing = Vec::new(); -fn gre_routing_length(gre: &GrePacket) -> usize { - if 0 == gre.get_routing_present() { - 0 - } else { - panic!("Source routed GRE packets not supported") + if checksum_present != 0 || routing_present != 0 { + if bytes.remaining() < 4 { + return None; + } + checksum.push(bytes.get_u16()); + offset.push(bytes.get_u16()); + } + + if key_present != 0 { + if bytes.remaining() < 4 { + return None; + } + key.push(bytes.get_u32()); + } + + if sequence_present != 0 { + if bytes.remaining() < 4 { + return None; + } + sequence.push(bytes.get_u32()); + } + + if routing_present != 0 { + // Not implemented for this crate + panic!("Source routed GRE packets not supported"); + } + + let payload = Bytes::copy_from_slice(bytes); + + Some(Self { + checksum_present, + routing_present, + key_present, + sequence_present, + strict_source_route, + recursion_control, + zero_flags, + version, + protocol_type: protocol_type.into(), + checksum, + offset, + key, + sequence, + routing, + payload, + }) } -} -/// `u16be`, but we can't use that directly in a `Vec` :( -#[packet] -pub struct U16BE { - number: u16be, - #[length = "0"] - #[payload] - unused: Vec, -} + fn from_bytes(bytes: Bytes) -> Option { + Self::from_buf(&bytes) + } + + fn to_bytes(&self) -> Bytes { + use bytes::{BufMut, BytesMut}; + + let mut buf = BytesMut::with_capacity(self.header_len()); + + // Build the flags field + let mut flags: u16 = 0; + flags |= (self.checksum_present as u16) << 15; + flags |= (self.routing_present as u16) << 14; + flags |= (self.key_present as u16) << 13; + flags |= (self.sequence_present as u16) << 12; + flags |= (self.strict_source_route as u16) << 11; + flags |= (self.recursion_control as u16) << 8; + flags |= (self.zero_flags as u16) << 3; + flags |= self.version as u16; + + buf.put_u16(flags); + buf.put_u16(self.protocol_type.into()); + + if self.checksum_present != 0 || self.routing_present != 0 { + for c in &self.checksum { + buf.put_u16(*c); + } + for o in &self.offset { + buf.put_u16(*o); + } + } + + if self.key_present != 0 { + for k in &self.key { + buf.put_u32(*k); + } + } + + if self.sequence_present != 0 { + for s in &self.sequence { + buf.put_u32(*s); + } + } + + // Panic if routing_present is set (not supported by this implementation) + if self.routing_present != 0 { + panic!("to_bytes does not support source routed GRE packets"); + } + + buf.put_slice(&self.payload); + + buf.freeze() + } + fn header(&self) -> Bytes { + use bytes::{BufMut, BytesMut}; + + let mut buf = BytesMut::with_capacity(self.header_len()); + + // Build the flags field + let mut flags: u16 = 0; + flags |= (self.checksum_present as u16) << 15; + flags |= (self.routing_present as u16) << 14; + flags |= (self.key_present as u16) << 13; + flags |= (self.sequence_present as u16) << 12; + flags |= (self.strict_source_route as u16) << 11; + flags |= (self.recursion_control as u16) << 8; + flags |= (self.zero_flags as u16) << 3; + flags |= self.version as u16; + + buf.put_u16(flags); + buf.put_u16(self.protocol_type.into()); -/// `u32be`, but we can't use that directly in a `Vec` :( -#[packet] -pub struct U32BE { - number: u32be, - #[length = "0"] - #[payload] - unused: Vec, + if self.checksum_present != 0 || self.routing_present != 0 { + for c in &self.checksum { + buf.put_u16(*c); + } + for o in &self.offset { + buf.put_u16(*o); + } + } + + if self.key_present != 0 { + for k in &self.key { + buf.put_u32(*k); + } + } + + if self.sequence_present != 0 { + for s in &self.sequence { + buf.put_u32(*s); + } + } + + // Panic if routing_present is set (not supported by this implementation) + if self.routing_present != 0 { + panic!("header does not support source routed GRE packets"); + } + + buf.freeze() + } + + fn payload(&self) -> Bytes { + self.payload.clone() + } + + fn header_len(&self) -> usize { + 4 // base header: 2 bytes flags + 2 bytes protocol_type + + self.checksum_length() + + self.offset_length() + + self.key_length() + + self.sequence_length() + } + + fn payload_len(&self) -> usize { + self.payload.len() + } + + fn total_len(&self) -> usize { + self.header_len() + self.payload_len() + } + + fn into_parts(self) -> (Self::Header, Bytes) { + ((), self.to_bytes()) + } } -#[test] -fn gre_packet_test() { - let mut packet = [0u8; 4]; - { - let mut gre_packet = MutableGrePacket::new(&mut packet[..]).unwrap(); - gre_packet.set_protocol_type(0x0800); - assert_eq!(gre_packet.payload().len(), 0); +impl GrePacket { + pub fn checksum_length(&self) -> usize { + (self.checksum_present | self.routing_present) as usize * 2 } - let ref_packet = [ - 0x00, /* no flags */ - 0x00, /* no flags, version 0 */ - 0x08, /* protocol 0x0800 */ - 0x00, - ]; + pub fn offset_length(&self) -> usize { + (self.checksum_present | self.routing_present) as usize * 2 + } + + pub fn key_length(&self) -> usize { + self.key_present as usize * 4 + } - assert_eq!(&ref_packet[..], &packet[..]); + pub fn sequence_length(&self) -> usize { + self.sequence_present as usize * 4 + } + + pub fn routing_length(&self) -> usize { + if 0 == self.routing_present { + 0 + } else { + panic!("Source routed GRE packets not supported") + } + } } -#[test] -fn gre_checksum_test() { - let mut packet = [0u8; 8]; - { - let mut gre_packet = MutableGrePacket::new(&mut packet[..]).unwrap(); - gre_packet.set_checksum_present(1); - assert_eq!(gre_packet.payload().len(), 0); - assert_eq!(gre_packet.get_checksum().len(), 1); - assert_eq!(gre_packet.get_offset().len(), 1); - } - - let ref_packet = [ - 0x80, /* checksum on */ - 0x00, /* no flags, version 0 */ - 0x00, /* protocol 0x0000 */ - 0x00, 0x00, /* 16 bits of checksum */ - 0x00, 0x00, /* 16 bits of offset */ - 0x00, - ]; - - assert_eq!(&ref_packet[..], &packet[..]); +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn gre_packet_test() { + let packet = Bytes::from_static(&[ + 0x00, /* no flags */ + 0x00, /* no flags, version 0 */ + 0x08, /* protocol 0x0800 */ + 0x00, + ]); + + let gre_packet = GrePacket::from_buf(&mut packet.clone()).unwrap(); + + assert_eq!(&gre_packet.to_bytes(), &packet); + } + + #[test] + fn gre_checksum_test() { + let packet = Bytes::from_static(&[ + 0x80, /* checksum on */ + 0x00, /* no flags, version 0 */ + 0x00, /* protocol 0x0000 */ + 0x00, 0x00, /* 16 bits of checksum */ + 0x00, 0x00, /* 16 bits of offset */ + 0x00, + ]); + + let gre_packet = GrePacket::from_buf(&mut packet.clone()).unwrap(); + + assert_eq!(&gre_packet.to_bytes(), &packet); + } } diff --git a/nex-packet/src/icmp.rs b/nex-packet/src/icmp.rs index 386f95c..1f8972f 100644 --- a/nex-packet/src/icmp.rs +++ b/nex-packet/src/icmp.rs @@ -1,58 +1,21 @@ //! An ICMP packet abstraction. - -use crate::PrimitiveValues; - -use alloc::vec::Vec; - -use crate::ethernet::ETHERNET_HEADER_LEN; +use crate::{ethernet::ETHERNET_HEADER_LEN, packet::Packet}; use crate::ipv4::IPV4_HEADER_LEN; -use nex_macro::packet; -use nex_macro_helper::types::*; +use bytes::{BufMut, Bytes, BytesMut}; +use nex_core::bitfield::u16be; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -/// ICMPv4 Header Length. -pub const ICMPV4_HEADER_LEN: usize = echo_request::MutableEchoRequestPacket::minimum_packet_size(); +/// ICMP Common Header Length. +pub const ICMP_COMMON_HEADER_LEN: usize = 4; +/// ICMPv4 Header Length. Including the common header (4 bytes) and the type specific header (4 bytes). +pub const ICMPV4_HEADER_LEN: usize = 8; /// ICMPv4 Minimum Packet Length. pub const ICMPV4_PACKET_LEN: usize = ETHERNET_HEADER_LEN + IPV4_HEADER_LEN + ICMPV4_HEADER_LEN; /// ICMPv4 IP Packet Length. pub const ICMPV4_IP_PACKET_LEN: usize = IPV4_HEADER_LEN + ICMPV4_HEADER_LEN; -/// Represents the ICMPv4 header. -#[derive(Clone, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct IcmpHeader { - pub icmp_type: IcmpType, - pub icmp_code: IcmpCode, - pub checksum: u16be, -} - -impl IcmpHeader { - /// Construct an ICMPv4 header from a byte slice. - pub fn from_bytes(packet: &[u8]) -> Result { - if packet.len() < ICMPV4_HEADER_LEN { - return Err("Packet is too small for ICMPv4 header".to_string()); - } - match IcmpPacket::new(packet) { - Some(icmp_packet) => Ok(IcmpHeader { - icmp_type: icmp_packet.get_icmp_type(), - icmp_code: icmp_packet.get_icmp_code(), - checksum: icmp_packet.get_checksum(), - }), - None => Err("Failed to parse ICMPv4 packet".to_string()), - } - } - /// Construct an ICMPv4 header from a IcmpPacket. - pub(crate) fn from_packet(icmp_packet: &IcmpPacket) -> IcmpHeader { - IcmpHeader { - icmp_type: icmp_packet.get_icmp_type(), - icmp_code: icmp_packet.get_icmp_code(), - checksum: icmp_packet.get_checksum(), - } - } -} - /// Represents the "ICMP type" header field. #[repr(u8)] #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -121,70 +84,66 @@ impl IcmpType { } } /// Get the name of the ICMP type - pub fn name(&self) -> String { + pub fn name(&self) -> &'static str { match *self { - IcmpType::EchoReply => String::from("Echo Reply"), - IcmpType::DestinationUnreachable => String::from("Destination Unreachable"), - IcmpType::SourceQuench => String::from("Source Quench"), - IcmpType::RedirectMessage => String::from("Redirect Message"), - IcmpType::EchoRequest => String::from("Echo Request"), - IcmpType::RouterAdvertisement => String::from("Router Advertisement"), - IcmpType::RouterSolicitation => String::from("Router Solicitation"), - IcmpType::TimeExceeded => String::from("Time Exceeded"), - IcmpType::ParameterProblem => String::from("Parameter Problem"), - IcmpType::TimestampRequest => String::from("Timestamp Request"), - IcmpType::TimestampReply => String::from("Timestamp Reply"), - IcmpType::InformationRequest => String::from("Information Request"), - IcmpType::InformationReply => String::from("Information Reply"), - IcmpType::AddressMaskRequest => String::from("Address Mask Request"), - IcmpType::AddressMaskReply => String::from("Address Mask Reply"), - IcmpType::Traceroute => String::from("Traceroute"), - IcmpType::DatagramConversionError => String::from("Datagram Conversion Error"), - IcmpType::MobileHostRedirect => String::from("Mobile Host Redirect"), - IcmpType::IPv6WhereAreYou => String::from("IPv6 Where Are You"), - IcmpType::IPv6IAmHere => String::from("IPv6 I Am Here"), - IcmpType::MobileRegistrationRequest => String::from("Mobile Registration Request"), - IcmpType::MobileRegistrationReply => String::from("Mobile Registration Reply"), - IcmpType::DomainNameRequest => String::from("Domain Name Request"), - IcmpType::DomainNameReply => String::from("Domain Name Reply"), - IcmpType::SKIP => String::from("SKIP"), - IcmpType::Photuris => String::from("Photuris"), - IcmpType::Unknown(n) => format!("Unknown ({})", n), + IcmpType::EchoReply => "Echo Reply", + IcmpType::DestinationUnreachable => "Destination Unreachable", + IcmpType::SourceQuench => "Source Quench", + IcmpType::RedirectMessage => "Redirect Message", + IcmpType::EchoRequest => "Echo Request", + IcmpType::RouterAdvertisement => "Router Advertisement", + IcmpType::RouterSolicitation => "Router Solicitation", + IcmpType::TimeExceeded => "Time Exceeded", + IcmpType::ParameterProblem => "Parameter Problem", + IcmpType::TimestampRequest => "Timestamp Request", + IcmpType::TimestampReply => "Timestamp Reply", + IcmpType::InformationRequest => "Information Request", + IcmpType::InformationReply => "Information Reply", + IcmpType::AddressMaskRequest => "Address Mask Request", + IcmpType::AddressMaskReply => "Address Mask Reply", + IcmpType::Traceroute => "Traceroute", + IcmpType::DatagramConversionError => "Datagram Conversion Error", + IcmpType::MobileHostRedirect => "Mobile Host Redirect", + IcmpType::IPv6WhereAreYou => "IPv6 Where Are You", + IcmpType::IPv6IAmHere => "IPv6 I Am Here", + IcmpType::MobileRegistrationRequest => "Mobile Registration Request", + IcmpType::MobileRegistrationReply => "Mobile Registration Reply", + IcmpType::DomainNameRequest => "Domain Name Request", + IcmpType::DomainNameReply => "Domain Name Reply", + IcmpType::SKIP => "SKIP", + IcmpType::Photuris => "Photuris", + IcmpType::Unknown(_) => "Unknown", } } -} - -impl PrimitiveValues for IcmpType { - type T = (u8,); - fn to_primitive_values(&self) -> (u8,) { + pub fn value(&self) -> u8 { match *self { - IcmpType::EchoReply => (0,), - IcmpType::DestinationUnreachable => (3,), - IcmpType::SourceQuench => (4,), - IcmpType::RedirectMessage => (5,), - IcmpType::EchoRequest => (8,), - IcmpType::RouterAdvertisement => (9,), - IcmpType::RouterSolicitation => (10,), - IcmpType::TimeExceeded => (11,), - IcmpType::ParameterProblem => (12,), - IcmpType::TimestampRequest => (13,), - IcmpType::TimestampReply => (14,), - IcmpType::InformationRequest => (15,), - IcmpType::InformationReply => (16,), - IcmpType::AddressMaskRequest => (17,), - IcmpType::AddressMaskReply => (18,), - IcmpType::Traceroute => (30,), - IcmpType::DatagramConversionError => (31,), - IcmpType::MobileHostRedirect => (32,), - IcmpType::IPv6WhereAreYou => (33,), - IcmpType::IPv6IAmHere => (34,), - IcmpType::MobileRegistrationRequest => (35,), - IcmpType::MobileRegistrationReply => (36,), - IcmpType::DomainNameRequest => (37,), - IcmpType::DomainNameReply => (38,), - IcmpType::SKIP => (39,), - IcmpType::Photuris => (40,), - IcmpType::Unknown(n) => (n,), + IcmpType::EchoReply => 0, + IcmpType::DestinationUnreachable => 3, + IcmpType::SourceQuench => 4, + IcmpType::RedirectMessage => 5, + IcmpType::EchoRequest => 8, + IcmpType::RouterAdvertisement => 9, + IcmpType::RouterSolicitation => 10, + IcmpType::TimeExceeded => 11, + IcmpType::ParameterProblem => 12, + IcmpType::TimestampRequest => 13, + IcmpType::TimestampReply => 14, + IcmpType::InformationRequest => 15, + IcmpType::InformationReply => 16, + IcmpType::AddressMaskRequest => 17, + IcmpType::AddressMaskReply => 18, + IcmpType::Traceroute => 30, + IcmpType::DatagramConversionError => 31, + IcmpType::MobileHostRedirect => 32, + IcmpType::IPv6WhereAreYou => 33, + IcmpType::IPv6IAmHere => 34, + IcmpType::MobileRegistrationRequest => 35, + IcmpType::MobileRegistrationReply => 36, + IcmpType::DomainNameRequest => 37, + IcmpType::DomainNameReply => 38, + IcmpType::SKIP => 39, + IcmpType::Photuris => 40, + IcmpType::Unknown(n) => n, } } } @@ -199,94 +158,104 @@ impl IcmpCode { pub fn new(val: u8) -> IcmpCode { IcmpCode(val) } -} - -impl PrimitiveValues for IcmpCode { - type T = (u8,); - fn to_primitive_values(&self) -> (u8,) { - (self.0,) + pub fn value(&self) -> u8 { + self.0 } } -/// Represents a generic ICMP packet. -#[packet] -pub struct Icmp { - #[construct_with(u8)] +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct IcmpHeader { pub icmp_type: IcmpType, - #[construct_with(u8)] pub icmp_code: IcmpCode, - pub checksum: u16be, - // theoretically, the header is 64 bytes long, but since the "Rest Of Header" part depends on - // the ICMP type and ICMP code, we consider it's part of the payload. - // rest_of_header: u32be, - #[payload] - pub payload: Vec, + pub checksum: u16, } -/// Calculates a checksum of an ICMP packet. -pub fn checksum(packet: &IcmpPacket) -> u16be { - use crate::util; - use crate::Packet; - - util::checksum(packet.packet(), 1) +/// ICMP packet representation +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct IcmpPacket { + pub header: IcmpHeader, + pub payload: Bytes, } -#[cfg(test)] -mod checksum_tests { - use super::*; - use alloc::vec; +impl Packet for IcmpPacket { + type Header = IcmpHeader; - #[test] - fn checksum_zeros() { - let mut data = vec![0u8; 8]; - let expected = 65535; - let mut pkg = MutableIcmpPacket::new(&mut data[..]).unwrap(); - assert_eq!(checksum(&pkg.to_immutable()), expected); - pkg.set_checksum(123); - assert_eq!(checksum(&pkg.to_immutable()), expected); + fn from_buf(bytes: &[u8]) -> Option { + if bytes.len() < ICMPV4_HEADER_LEN { + return None; + } + let icmp_type = IcmpType::new(bytes[0]); + let icmp_code = IcmpCode::new(bytes[1]); + let checksum = u16::from_be_bytes([bytes[2], bytes[3]]); + let payload = Bytes::copy_from_slice(&bytes[ICMP_COMMON_HEADER_LEN..]); + Some(IcmpPacket { + header: IcmpHeader { + icmp_type, + icmp_code, + checksum, + }, + payload, + }) + } + fn from_bytes(bytes: Bytes) -> Option { + Self::from_buf(&bytes) } - #[test] - fn checksum_nonzero() { - let mut data = vec![255u8; 8]; - let expected = 0; - let mut pkg = MutableIcmpPacket::new(&mut data[..]).unwrap(); - assert_eq!(checksum(&pkg.to_immutable()), expected); - pkg.set_checksum(0); - assert_eq!(checksum(&pkg.to_immutable()), expected); + fn to_bytes(&self) -> Bytes { + let mut buf = BytesMut::with_capacity(ICMP_COMMON_HEADER_LEN + self.payload.len()); + buf.put_u8(self.header.icmp_type.value()); + buf.put_u8(self.header.icmp_code.value()); + buf.put_u16(self.header.checksum); + buf.extend_from_slice(&self.payload); + buf.freeze() } - #[test] - fn checksum_odd_bytes() { - let mut data = vec![191u8; 7]; - let expected = 49535; - let pkg = IcmpPacket::new(&mut data[..]).unwrap(); - assert_eq!(checksum(&pkg), expected); + fn header(&self) -> Bytes { + self.to_bytes().slice(..self.header_len()) + } + + fn payload(&self) -> Bytes { + self.payload.clone() + } + + fn header_len(&self) -> usize { + ICMP_COMMON_HEADER_LEN + } + + fn payload_len(&self) -> usize { + self.payload.len() + } + + fn total_len(&self) -> usize { + self.header_len() + self.payload_len() + } + + fn into_parts(self) -> (Self::Header, Bytes) { + (self.header, self.payload) } + } -pub mod echo_reply { - //! abstraction for ICMP "echo reply" packets. - //! - //! ```text - //! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - //! | Type | Code | Checksum | - //! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - //! | Identifier | Sequence Number | - //! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - //! | Data ... - //! +-+-+-+-+- - //! ``` - - use crate::icmp::{IcmpCode, IcmpType}; - use crate::PrimitiveValues; - - use alloc::vec::Vec; - - use nex_macro::packet; - use nex_macro_helper::types::*; +impl IcmpPacket { + pub fn with_computed_checksum(&self) -> Self { + let mut pkt = self.clone(); + pkt.header.checksum = checksum(&pkt).into(); + pkt + } +} - /// Represent the "identifier" field of the ICMP echo replay header. +/// Calculates a checksum of an ICMP packet. +pub fn checksum(packet: &IcmpPacket) -> u16be { + use crate::util; + util::checksum(&packet.to_bytes(), 1) +} + +pub mod echo_request { + use bytes::Bytes; + + use crate::icmp::{IcmpHeader, IcmpPacket, IcmpType}; + + /// Represents the identifier field. #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Identifier(pub u16); @@ -295,16 +264,12 @@ pub mod echo_reply { pub fn new(val: u16) -> Identifier { Identifier(val) } - } - - impl PrimitiveValues for Identifier { - type T = (u16,); - fn to_primitive_values(&self) -> (u16,) { - (self.0,) + pub fn value(&self) -> u16 { + self.0 } } - /// Represent the "sequence number" field of the ICMP echo replay header. + /// Represents the sequence number field. #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct SequenceNumber(pub u16); @@ -313,16 +278,12 @@ pub mod echo_reply { pub fn new(val: u16) -> SequenceNumber { SequenceNumber(val) } - } - - impl PrimitiveValues for SequenceNumber { - type T = (u16,); - fn to_primitive_values(&self) -> (u16,) { - (self.0,) + pub fn value(&self) -> u16 { + self.0 } } - /// Enumeration of available ICMP codes for ICMP echo replay packets. There is actually only + /// Enumeration of available ICMP codes for "echo reply" ICMP packets. There is actually only /// one, since the only valid ICMP code is 0. #[allow(non_snake_case)] #[allow(non_upper_case_globals)] @@ -332,43 +293,43 @@ pub mod echo_reply { pub const NoCode: IcmpCode = IcmpCode(0); } - /// Represents an ICMP echo reply packet. - #[packet] - pub struct EchoReply { - #[construct_with(u8)] - pub icmp_type: IcmpType, - #[construct_with(u8)] - pub icmp_code: IcmpCode, - pub checksum: u16be, - pub identifier: u16be, - pub sequence_number: u16be, - #[payload] - pub payload: Vec, + /// Represents an "echo request" ICMP packet. + #[derive(Clone, Debug, PartialEq, Eq)] + pub struct EchoRequestPacket { + pub header: IcmpHeader, + pub identifier: u16, + pub sequence_number: u16, + pub payload: Bytes, + } + + impl TryFrom for EchoRequestPacket { + type Error = &'static str; + + fn try_from(pkt: IcmpPacket) -> Result { + if pkt.header.icmp_type != IcmpType::EchoRequest { + return Err("Not an Echo Request"); + } + if pkt.payload.len() < 4 { + return Err("Payload too short for Echo Request"); + } + + Ok(Self { + header: pkt.header, + identifier: u16::from_be_bytes([pkt.payload[0], pkt.payload[1]]), + sequence_number: u16::from_be_bytes([pkt.payload[2], pkt.payload[3]]), + payload: pkt.payload.slice(4..), + }) + } } + } -pub mod echo_request { - //! abstraction for "echo request" ICMP packets. - //! - //! ```text - //! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - //! | Type | Code | Checksum | - //! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - //! | Identifier | Sequence Number | - //! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - //! | Data ... - //! +-+-+-+-+- - //! ``` - - use crate::icmp::{IcmpCode, IcmpType}; - use crate::PrimitiveValues; - - use alloc::vec::Vec; - - use nex_macro::packet; - use nex_macro_helper::types::*; +pub mod echo_reply { + use bytes::Bytes; - /// Represents the identifier field. + use crate::icmp::{IcmpHeader, IcmpPacket, IcmpType}; + + /// Represent the "identifier" field of the ICMP echo replay header. #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Identifier(pub u16); @@ -377,16 +338,12 @@ pub mod echo_request { pub fn new(val: u16) -> Identifier { Identifier(val) } - } - - impl PrimitiveValues for Identifier { - type T = (u16,); - fn to_primitive_values(&self) -> (u16,) { - (self.0,) + pub fn value(&self) -> u16 { + self.0 } } - /// Represents the sequence number field. + /// Represent the "sequence number" field of the ICMP echo replay header. #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct SequenceNumber(pub u16); @@ -395,16 +352,12 @@ pub mod echo_request { pub fn new(val: u16) -> SequenceNumber { SequenceNumber(val) } - } - - impl PrimitiveValues for SequenceNumber { - type T = (u16,); - fn to_primitive_values(&self) -> (u16,) { - (self.0,) + pub fn value(&self) -> u16 { + self.0 } } - /// Enumeration of available ICMP codes for "echo reply" ICMP packets. There is actually only + /// Enumeration of available ICMP codes for ICMP echo replay packets. There is actually only /// one, since the only valid ICMP code is 0. #[allow(non_snake_case)] #[allow(non_upper_case_globals)] @@ -414,40 +367,41 @@ pub mod echo_request { pub const NoCode: IcmpCode = IcmpCode(0); } - /// Represents an "echo request" ICMP packet. - #[packet] - pub struct EchoRequest { - #[construct_with(u8)] - pub icmp_type: IcmpType, - #[construct_with(u8)] - pub icmp_code: IcmpCode, - pub checksum: u16be, - pub identifier: u16be, - pub sequence_number: u16be, - #[payload] - pub payload: Vec, + /// Represents an ICMP echo reply packet. + #[derive(Clone, Debug, PartialEq, Eq)] + pub struct EchoReplyPacket { + pub header: IcmpHeader, + pub identifier: u16, + pub sequence_number: u16, + pub payload: Bytes, } -} -pub mod destination_unreachable { - //! abstraction for "destination unreachable" ICMP packets. - //! - //! ```text - //! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - //! | Type | Code | Checksum | - //! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - //! | unused | Next-Hop MTU | - //! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - //! | Internet Header + 64 bits of Original Data Datagram | - //! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - //! ``` + impl TryFrom for EchoReplyPacket { + type Error = &'static str; + + fn try_from(pkt: IcmpPacket) -> Result { + if pkt.header.icmp_type != IcmpType::EchoReply { + return Err("Not an Echo Reply"); + } + if pkt.payload.len() < 4 { + return Err("Payload too short for Echo Reply"); + } + + Ok(Self { + header: pkt.header, + identifier: u16::from_be_bytes([pkt.payload[0], pkt.payload[1]]).into(), + sequence_number: u16::from_be_bytes([pkt.payload[2], pkt.payload[3]]).into(), + payload: pkt.payload.slice(4..), + }) + } + } - use crate::icmp::{IcmpCode, IcmpType}; +} - use alloc::vec::Vec; +pub mod destination_unreachable { + use bytes::Bytes; - use nex_macro::packet; - use nex_macro_helper::types::*; + use crate::icmp::{IcmpHeader, IcmpPacket, IcmpType}; /// Enumeration of the recognized ICMP codes for "destination unreachable" ICMP packets. #[allow(non_snake_case)] @@ -489,39 +443,40 @@ pub mod destination_unreachable { } /// Represents an "echo request" ICMP packet. - #[packet] - pub struct DestinationUnreachable { - #[construct_with(u8)] - pub icmp_type: IcmpType, - #[construct_with(u8)] - pub icmp_code: IcmpCode, - pub checksum: u16be, - pub unused: u16be, - pub next_hop_mtu: u16be, - #[payload] - pub payload: Vec, + #[derive(Clone, Debug, PartialEq, Eq)] + pub struct DestinationUnreachablePacket { + pub header: IcmpHeader, + pub unused: u16, + pub next_hop_mtu: u16, + pub payload: Bytes, } -} -pub mod time_exceeded { - //! abstraction for "time exceeded" ICMP packets. - //! - //! ```text - //! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - //! | Type | Code | Checksum | - //! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - //! | unused | - //! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - //! | Internet Header + 64 bits of Original Data Datagram | - //! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - //! ``` + impl TryFrom for DestinationUnreachablePacket { + type Error = &'static str; + + fn try_from(pkt: IcmpPacket) -> Result { + if pkt.header.icmp_type != IcmpType::DestinationUnreachable { + return Err("Not a Destination Unreachable"); + } + if pkt.payload.len() < 4 { + return Err("Payload too short for Destination Unreachable"); + } + + Ok(Self { + header: pkt.header, + unused: u16::from_be_bytes([pkt.payload[0], pkt.payload[1]]).into(), + next_hop_mtu: u16::from_be_bytes([pkt.payload[2], pkt.payload[3]]).into(), + payload: pkt.payload.slice(4..), + }) + } + } - use crate::icmp::{IcmpCode, IcmpType}; +} - use alloc::vec::Vec; +pub mod time_exceeded { + use bytes::Bytes; - use nex_macro::packet; - use nex_macro_helper::types::*; + use crate::icmp::{IcmpHeader, IcmpPacket, IcmpType}; /// Enumeration of the recognized ICMP codes for "time exceeded" ICMP packets. #[allow(non_snake_case)] @@ -533,17 +488,137 @@ pub mod time_exceeded { /// ICMP code for "fragment reassembly time exceeded" packet. pub const FragmentReasemblyTimeExceeded: IcmpCode = IcmpCode(1); } - /// Represents an "echo request" ICMP packet. - #[packet] - pub struct TimeExceeded { - #[construct_with(u8)] - pub icmp_type: IcmpType, - #[construct_with(u8)] - pub icmp_code: IcmpCode, - pub checksum: u16be, - pub unused: u32be, - #[payload] - pub payload: Vec, + #[derive(Clone, Debug, PartialEq, Eq)] + pub struct TimeExceededPacket { + pub header: IcmpHeader, + pub unused: u32, + pub payload: Bytes, + } + + impl TryFrom for TimeExceededPacket { + type Error = &'static str; + + fn try_from(pkt: IcmpPacket) -> Result { + if pkt.header.icmp_type != IcmpType::TimeExceeded { + return Err("Not a Time Exceeded"); + } + if pkt.payload.len() < 4 { + return Err("Payload too short for Time Exceeded"); + } + + Ok(Self { + header: pkt.header, + unused: u32::from_be_bytes([ + pkt.payload[0], + pkt.payload[1], + pkt.payload[2], + pkt.payload[3], + ]) + .into(), + payload: pkt.payload.slice(4..), + }) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_echo_request_from_bytes() { + let raw_bytes = Bytes::from_static(&[ + 8, 0, 0x3a, 0xbc, // Type = 8 (Echo Request), Code = 0, Checksum = 0x3abc + 0x04, 0xd2, // Identifier = 0x04d2 (1234) + 0x00, 0x2a, // Sequence = 0x002a (42) + b'p', b'i', b'n', b'g', + ]); + + let parsed = IcmpPacket::from_bytes(raw_bytes.clone()).expect("Failed to parse ICMP"); + let echo = echo_request::EchoRequestPacket::try_from(parsed).expect("Failed to downcast"); + + assert_eq!(echo.header.icmp_type, IcmpType::EchoRequest); + assert_eq!(echo.header.icmp_code, IcmpCode(0)); + assert_eq!(echo.header.checksum, 0x3abc); + assert_eq!(echo.identifier, 1234); + assert_eq!(echo.sequence_number, 42); + assert_eq!(echo.payload, Bytes::from_static(b"ping")); + } + + #[test] + fn test_echo_reply_roundtrip() { + let identifier: u16 = 5678; + let sequence: u16 = 99; + let payload = Bytes::from_static(b"pong"); + + let header = IcmpHeader { + icmp_type: IcmpType::EchoReply, + icmp_code: IcmpCode(0), + checksum: 0, + }; + + let mut buf = BytesMut::with_capacity(4 + payload.len()); + buf.put_u16(identifier); + buf.put_u16(sequence); + buf.extend_from_slice(&payload); + + let pkt = IcmpPacket { header, payload: buf.freeze() }.with_computed_checksum(); + let bytes = pkt.to_bytes(); + + let parsed = IcmpPacket::from_bytes(bytes.clone()).expect("Failed to parse ICMP"); + let echo = echo_reply::EchoReplyPacket::try_from(parsed).expect("Failed to downcast"); + + assert_eq!(echo.identifier, identifier); + assert_eq!(echo.sequence_number, sequence); + assert_eq!(echo.payload, payload); + } + + #[test] + fn test_destination_unreachable() { + let unused: u16 = 0; + let mtu: u16 = 1500; + let payload = Bytes::from_static(b"bad ip"); + + let header = IcmpHeader { + icmp_type: IcmpType::DestinationUnreachable, + icmp_code: IcmpCode(3), // Port unreachable + checksum: 0, + }; + + let mut buf = BytesMut::with_capacity(4 + payload.len()); + buf.put_u16(unused); + buf.put_u16(mtu); + buf.extend_from_slice(&payload); + + let pkt = IcmpPacket { header, payload: buf.freeze() }.with_computed_checksum(); + let parsed = IcmpPacket::from_bytes(pkt.to_bytes()).unwrap(); + let unreachable = destination_unreachable::DestinationUnreachablePacket::try_from(parsed).unwrap(); + + assert_eq!(unreachable.next_hop_mtu, mtu); + assert_eq!(unreachable.payload, payload); + } + + #[test] + fn test_time_exceeded() { + let unused: u32 = 0xdeadbeef; + let payload = Bytes::from_static(b"timeout"); + + let header = IcmpHeader { + icmp_type: IcmpType::TimeExceeded, + icmp_code: IcmpCode(0), // TTL exceeded + checksum: 0, + }; + + let mut buf = BytesMut::with_capacity(4 + payload.len()); + buf.put_u32(unused); + buf.extend_from_slice(&payload); + + let pkt = IcmpPacket { header, payload: buf.freeze() }.with_computed_checksum(); + let parsed = IcmpPacket::from_bytes(pkt.to_bytes()).unwrap(); + let exceeded = time_exceeded::TimeExceededPacket::try_from(parsed).unwrap(); + + assert_eq!(exceeded.unused, unused); + assert_eq!(exceeded.payload, payload); } } diff --git a/nex-packet/src/icmpv6.rs b/nex-packet/src/icmpv6.rs index 8920048..1fad281 100644 --- a/nex-packet/src/icmpv6.rs +++ b/nex-packet/src/icmpv6.rs @@ -1,60 +1,22 @@ //! An ICMPv6 packet abstraction. -use crate::ip::IpNextLevelProtocol; -use crate::PrimitiveValues; - -use alloc::vec::Vec; - -use crate::ethernet::ETHERNET_HEADER_LEN; +use crate::{ethernet::ETHERNET_HEADER_LEN, packet::Packet}; use crate::ipv6::IPV6_HEADER_LEN; -use nex_macro::packet; -use nex_macro_helper::types::*; use std::net::Ipv6Addr; +use bytes::Bytes; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -/// ICMPv6 Header Length. -pub const ICMPV6_HEADER_LEN: usize = echo_request::MutableEchoRequestPacket::minimum_packet_size(); +/// ICMPv6 Common Header Length. +pub const ICMPV6_COMMON_HEADER_LEN: usize = 4; +/// ICMPv6 Header Length. Including the common header (4 bytes) and the type specific header (4 bytes). +pub const ICMPV6_HEADER_LEN: usize = 8; /// ICMPv6 Minimum Packet Length. pub const ICMPV6_PACKET_LEN: usize = ETHERNET_HEADER_LEN + IPV6_HEADER_LEN + ICMPV6_HEADER_LEN; /// ICMPv6 IP Packet Length. pub const ICMPV6_IP_PACKET_LEN: usize = IPV6_HEADER_LEN + ICMPV6_HEADER_LEN; -/// Represents the ICMPv6 header. -#[derive(Clone, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct Icmpv6Header { - pub icmpv6_type: Icmpv6Type, - pub icmpv6_code: Icmpv6Code, - pub checksum: u16be, -} - -impl Icmpv6Header { - /// Construct an ICMPv6 header from a byte slice. - pub fn from_bytes(packet: &[u8]) -> Result { - if packet.len() < ICMPV6_HEADER_LEN { - return Err("Packet is too small for ICMPv6 header".to_string()); - } - match Icmpv6Packet::new(packet) { - Some(icmpv6_packet) => Ok(Icmpv6Header { - icmpv6_type: icmpv6_packet.get_icmpv6_type(), - icmpv6_code: icmpv6_packet.get_icmpv6_code(), - checksum: icmpv6_packet.get_checksum(), - }), - None => Err("Failed to parse ICMPv6 packet".to_string()), - } - } - /// Construct an ICMPv6 header from a Icmpv6Packet. - pub(crate) fn from_packet(icmpv6_packet: &Icmpv6Packet) -> Icmpv6Header { - Icmpv6Header { - icmpv6_type: icmpv6_packet.get_icmpv6_type(), - icmpv6_code: icmpv6_packet.get_icmpv6_code(), - checksum: icmpv6_packet.get_checksum(), - } - } -} - /// Represents the ICMPv6 types. /// #[repr(u8)] @@ -146,51 +108,90 @@ impl Icmpv6Type { n => Icmpv6Type::Unknown(n), } } -} - -impl PrimitiveValues for Icmpv6Type { - type T = (u8,); - fn to_primitive_values(&self) -> (u8,) { - match *self { - Icmpv6Type::DestinationUnreachable => (1,), - Icmpv6Type::PacketTooBig => (2,), - Icmpv6Type::TimeExceeded => (3,), - Icmpv6Type::ParameterProblem => (4,), - Icmpv6Type::EchoRequest => (128,), - Icmpv6Type::EchoReply => (129,), - Icmpv6Type::MulticastListenerQuery => (130,), - Icmpv6Type::MulticastListenerReport => (131,), - Icmpv6Type::MulticastListenerDone => (132,), - Icmpv6Type::RouterSolicitation => (133,), - Icmpv6Type::RouterAdvertisement => (134,), - Icmpv6Type::NeighborSolicitation => (135,), - Icmpv6Type::NeighborAdvertisement => (136,), - Icmpv6Type::RedirectMessage => (137,), - Icmpv6Type::RouterRenumbering => (138,), - Icmpv6Type::NodeInformationQuery => (139,), - Icmpv6Type::NodeInformationResponse => (140,), - Icmpv6Type::InverseNeighborDiscoverySolicitation => (141,), - Icmpv6Type::InverseNeighborDiscoveryAdvertisement => (142,), - Icmpv6Type::Version2MulticastListenerReport => (143,), - Icmpv6Type::HomeAgentAddressDiscoveryRequest => (144,), - Icmpv6Type::HomeAgentAddressDiscoveryReply => (145,), - Icmpv6Type::MobilePrefixSolicitation => (146,), - Icmpv6Type::MobilePrefixAdvertisement => (147,), - Icmpv6Type::CertificationPathSolicitationMessage => (148,), - Icmpv6Type::CertificationPathAdvertisementMessage => (149,), - Icmpv6Type::ExperimentalMobilityProtocols => (150,), - Icmpv6Type::MulticastRouterAdvertisement => (151,), - Icmpv6Type::MulticastRouterSolicitation => (152,), - Icmpv6Type::MulticastRouterTermination => (153,), - Icmpv6Type::FMIPv6Messages => (154,), - Icmpv6Type::RPLControlMessage => (155,), - Icmpv6Type::ILNPv6LocatorUpdateMessage => (156,), - Icmpv6Type::DuplicateAddressRequest => (157,), - Icmpv6Type::DuplicateAddressConfirmation => (158,), - Icmpv6Type::MPLControlMessage => (159,), - Icmpv6Type::ExtendedEchoRequest => (160,), - Icmpv6Type::ExtendedEchoReply => (161,), - Icmpv6Type::Unknown(n) => (n,), + pub fn name(&self) -> &'static str { + match self { + Icmpv6Type::DestinationUnreachable => "Destination Unreachable", + Icmpv6Type::PacketTooBig => "Packet Too Big", + Icmpv6Type::TimeExceeded => "Time Exceeded", + Icmpv6Type::ParameterProblem => "Parameter Problem", + Icmpv6Type::EchoRequest => "Echo Request", + Icmpv6Type::EchoReply => "Echo Reply", + Icmpv6Type::MulticastListenerQuery => "Multicast Listener Query", + Icmpv6Type::MulticastListenerReport => "Multicast Listener Report", + Icmpv6Type::MulticastListenerDone => "Multicast Listener Done", + Icmpv6Type::RouterSolicitation => "Router Solicitation", + Icmpv6Type::RouterAdvertisement => "Router Advertisement", + Icmpv6Type::NeighborSolicitation => "Neighbor Solicitation", + Icmpv6Type::NeighborAdvertisement => "Neighbor Advertisement", + Icmpv6Type::RedirectMessage => "Redirect Message", + Icmpv6Type::RouterRenumbering => "Router Renumbering", + Icmpv6Type::NodeInformationQuery => "Node Information Query", + Icmpv6Type::NodeInformationResponse => "Node Information Response", + Icmpv6Type::InverseNeighborDiscoverySolicitation => "Inverse Neighbor Discovery Solicitation", + Icmpv6Type::InverseNeighborDiscoveryAdvertisement => "Inverse Neighbor Discovery Advertisement", + Icmpv6Type::Version2MulticastListenerReport => "Version 2 Multicast Listener Report", + Icmpv6Type::HomeAgentAddressDiscoveryRequest => "Home Agent Address Discovery Request", + Icmpv6Type::HomeAgentAddressDiscoveryReply => "Home Agent Address Discovery Reply", + Icmpv6Type::MobilePrefixSolicitation => "Mobile Prefix Solicitation", + Icmpv6Type::MobilePrefixAdvertisement => "Mobile Prefix Advertisement", + Icmpv6Type::CertificationPathSolicitationMessage => "Certification Path Solicitation Message", + Icmpv6Type::CertificationPathAdvertisementMessage => "Certification Path Advertisement Message", + Icmpv6Type::ExperimentalMobilityProtocols => "Experimental Mobility Protocols", + Icmpv6Type::MulticastRouterAdvertisement => "Multicast Router Advertisement", + Icmpv6Type::MulticastRouterSolicitation => "Multicast Router Solicitation", + Icmpv6Type::MulticastRouterTermination => "Multicast Router Termination", + Icmpv6Type::FMIPv6Messages => "FMIPv6 Messages", + Icmpv6Type::RPLControlMessage => "RPL Control Message", + Icmpv6Type::ILNPv6LocatorUpdateMessage => "ILNPv6 Locator Update Message", + Icmpv6Type::DuplicateAddressRequest => "Duplicate Address Request", + Icmpv6Type::DuplicateAddressConfirmation => "Duplicate Address Confirmation", + Icmpv6Type::MPLControlMessage => "MPL Control Message", + Icmpv6Type::ExtendedEchoRequest => "Extended Echo Request", + Icmpv6Type::ExtendedEchoReply => "Extended Echo Reply", + Icmpv6Type::Unknown(_) => "Unknown", + } + } + pub fn value(&self) -> u8 { + match self { + Icmpv6Type::DestinationUnreachable => 1, + Icmpv6Type::PacketTooBig => 2, + Icmpv6Type::TimeExceeded => 3, + Icmpv6Type::ParameterProblem => 4, + Icmpv6Type::EchoRequest => 128, + Icmpv6Type::EchoReply => 129, + Icmpv6Type::MulticastListenerQuery => 130, + Icmpv6Type::MulticastListenerReport => 131, + Icmpv6Type::MulticastListenerDone => 132, + Icmpv6Type::RouterSolicitation => 133, + Icmpv6Type::RouterAdvertisement => 134, + Icmpv6Type::NeighborSolicitation => 135, + Icmpv6Type::NeighborAdvertisement => 136, + Icmpv6Type::RedirectMessage => 137, + Icmpv6Type::RouterRenumbering => 138, + Icmpv6Type::NodeInformationQuery => 139, + Icmpv6Type::NodeInformationResponse => 140, + Icmpv6Type::InverseNeighborDiscoverySolicitation => 141, + Icmpv6Type::InverseNeighborDiscoveryAdvertisement => 142, + Icmpv6Type::Version2MulticastListenerReport => 143, + Icmpv6Type::HomeAgentAddressDiscoveryRequest => 144, + Icmpv6Type::HomeAgentAddressDiscoveryReply => 145, + Icmpv6Type::MobilePrefixSolicitation => 146, + Icmpv6Type::MobilePrefixAdvertisement => 147, + Icmpv6Type::CertificationPathSolicitationMessage => 148, + Icmpv6Type::CertificationPathAdvertisementMessage => 149, + Icmpv6Type::ExperimentalMobilityProtocols => 150, + Icmpv6Type::MulticastRouterAdvertisement => 151, + Icmpv6Type::MulticastRouterSolicitation => 152, + Icmpv6Type::MulticastRouterTermination => 153, + Icmpv6Type::FMIPv6Messages => 154, + Icmpv6Type::RPLControlMessage => 155, + Icmpv6Type::ILNPv6LocatorUpdateMessage => 156, + Icmpv6Type::DuplicateAddressRequest => 157, + Icmpv6Type::DuplicateAddressConfirmation => 158, + Icmpv6Type::MPLControlMessage => 159, + Icmpv6Type::ExtendedEchoRequest => 160, + Icmpv6Type::ExtendedEchoReply => 161, + Icmpv6Type::Unknown(n) => *n, } } } @@ -205,64 +206,105 @@ impl Icmpv6Code { pub fn new(val: u8) -> Icmpv6Code { Icmpv6Code(val) } -} - -impl PrimitiveValues for Icmpv6Code { - type T = (u8,); - fn to_primitive_values(&self) -> (u8,) { - (self.0,) + /// Get the value of the `Icmpv6Code`. + pub fn value(&self) -> u8 { + self.0 } } -/// Represents a generic ICMPv6 packet [RFC 4443 § 2.1] -/// -/// ```text -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | Type | Code | Checksum | -/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -/// | | -/// + Message Body + -/// | | -/// ``` -/// -/// [RFC 4443 § 2.1]: https://tools.ietf.org/html/rfc4443#section-2.1 -#[packet] -pub struct Icmpv6 { - #[construct_with(u8)] +/// Represents the ICMPv6 header. +#[derive(Clone, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct Icmpv6Header { pub icmpv6_type: Icmpv6Type, - #[construct_with(u8)] pub icmpv6_code: Icmpv6Code, - pub checksum: u16be, - #[payload] - pub payload: Vec, + pub checksum: u16, +} + +/// ICMP packet representation +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Icmpv6Packet { + pub header: Icmpv6Header, + pub payload: Bytes, +} + +impl Packet for Icmpv6Packet { + type Header = Icmpv6Header; + + fn from_buf(bytes: &[u8]) -> Option { + if bytes.len() < ICMPV6_HEADER_LEN { + return None; + } + let icmpv6_type = Icmpv6Type::new(bytes[0]); + let icmpv6_code = Icmpv6Code::new(bytes[1]); + let checksum = u16::from_be_bytes([bytes[2], bytes[3]]); + let header = Icmpv6Header { + icmpv6_type, + icmpv6_code, + checksum, + }; + let payload = Bytes::copy_from_slice(&bytes[ICMPV6_COMMON_HEADER_LEN..]); + Some(Icmpv6Packet { header, payload }) + } + fn from_bytes(bytes: Bytes) -> Option { + Self::from_buf(&bytes) + } + fn to_bytes(&self) -> Bytes { + let mut bytes = Vec::with_capacity(ICMPV6_COMMON_HEADER_LEN + self.payload.len()); + bytes.push(self.header.icmpv6_type.value()); + bytes.push(self.header.icmpv6_code.value()); + bytes.extend_from_slice(&self.header.checksum.to_be_bytes()); + bytes.extend_from_slice(&self.payload); + Bytes::from(bytes) + } + fn header(&self) -> Bytes { + self.to_bytes().slice(..self.header_len()) + } + + fn payload(&self) -> Bytes { + self.payload.clone() + } + + fn header_len(&self) -> usize { + ICMPV6_COMMON_HEADER_LEN + } + + fn payload_len(&self) -> usize { + self.payload.len() + } + + fn total_len(&self) -> usize { + self.header_len() + self.payload_len() + } + + fn into_parts(self) -> (Self::Header, Bytes) { + (self.header, self.payload) + } } /// Calculates a checksum of an ICMPv6 packet. -pub fn checksum(packet: &Icmpv6Packet, source: &Ipv6Addr, destination: &Ipv6Addr) -> u16be { +pub fn checksum(packet: &Icmpv6Packet, source: &Ipv6Addr, destination: &Ipv6Addr) -> u16 { use crate::util; - use crate::Packet; - util::ipv6_checksum( - packet.packet(), - 1, + &packet.to_bytes(), + 1, // skip the checksum field &[], source, destination, - IpNextLevelProtocol::Icmpv6, + crate::ip::IpNextProtocol::Icmpv6, ) } #[cfg(test)] mod checksum_tests { use super::*; - use alloc::vec; #[test] fn checksum_echo_request() { // The equivalent of your typical ping -6 ::1%lo let lo = &Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1); - let mut data = vec![ - 0x80, // Icmpv6 Type + let data = Bytes::from_static(&[ + 0x80, // Icmpv6 Type (Echo Request) 0x00, // Code 0xff, 0xff, // Checksum 0x00, 0x00, // Id @@ -272,13 +314,17 @@ mod checksum_tests { 0x77, 0x6f, 0x75, 0x6e, 0x64, 0x20, 0x20, 0x74, 0x69, 0x73, 0x20, 0x62, 0x75, 0x74, 0x20, 0x61, 0x20, 0x73, 0x63, 0x72, 0x61, 0x74, 0x63, 0x68, 0x20, 0x20, 0x6b, 0x6e, 0x69, 0x67, 0x68, 0x74, 0x73, 0x20, 0x6f, 0x66, 0x20, 0x6e, 0x69, 0x20, 0x20, 0x20, - ]; - let mut pkg = MutableIcmpv6Packet::new(&mut data[..]).unwrap(); - assert_eq!(checksum(&pkg.to_immutable(), lo, lo), 0x1d2e); - - // Check - pkg.set_icmpv6_type(Icmpv6Type::new(0x81)); - assert_eq!(checksum(&pkg.to_immutable(), lo, lo), 0x1c2e); + ]); + let mut pkg = Icmpv6Packet::from_bytes(data.clone()).unwrap(); + assert_eq!(pkg.header.icmpv6_type, Icmpv6Type::EchoRequest); + assert_eq!(pkg.header.icmpv6_code, Icmpv6Code::new(0)); + assert_eq!(pkg.header.checksum, 0xffff); + assert_eq!(pkg.to_bytes(), data); + assert_eq!(checksum(&pkg, lo, lo), 0x1d2e); + + // Change type to Echo Reply + pkg.header.icmpv6_type = Icmpv6Type::new(0x81); + assert_eq!(checksum(&pkg, lo, lo), 0x1c2e); } } @@ -287,22 +333,21 @@ pub mod ndp { //! //! [RFC 4861]: https://tools.ietf.org/html/rfc4861 - use crate::icmpv6::{Icmpv6Code, Icmpv6Type}; - use crate::Packet; - use crate::PrimitiveValues; + use bytes::Bytes; + use nex_core::bitfield::{self, u24be, u32be}; - use alloc::vec::Vec; - - use nex_macro::packet; - use nex_macro_helper::types::*; + use crate::icmpv6::{Icmpv6Code, Icmpv6Header, Icmpv6Packet, Icmpv6Type, ICMPV6_HEADER_LEN}; + use crate::packet::Packet; use std::net::Ipv6Addr; /// NDP SOL Packet Length. - pub const NDP_SOL_PACKET_LEN: usize = NeighborSolicitPacket::minimum_packet_size(); + pub const NDP_SOL_PACKET_LEN: usize = 24; /// NDP ADV Packet Length. - pub const NDP_ADV_PACKET_LEN: usize = NeighborAdvertPacket::minimum_packet_size(); + pub const NDP_ADV_PACKET_LEN: usize = 24; + /// NDP REDIRECT Packet Length. + pub const NDP_REDIRECT_PACKET_LEN: usize = 40; /// NDP OPT Packet Length. - pub const NDP_OPT_PACKET_LEN: usize = NdpOptionPacket::minimum_packet_size(); + pub const NDP_OPT_PACKET_LEN: usize = 2; #[allow(non_snake_case)] #[allow(non_upper_case_globals)] @@ -321,12 +366,9 @@ pub mod ndp { pub fn new(value: u8) -> NdpOptionType { NdpOptionType(value) } - } - - impl PrimitiveValues for NdpOptionType { - type T = (u8,); - fn to_primitive_values(&self) -> (u8,) { - (self.0,) + /// Get the value of the `NdpOptionType`. + pub fn value(&self) -> u8 { + self.0 } } @@ -427,27 +469,90 @@ pub mod ndp { /// ``` /// /// [RFC 4861 § 4.6]: https://tools.ietf.org/html/rfc4861#section-4.6 - #[packet] - pub struct NdpOption { - #[construct_with(u8)] + #[derive(Clone, Debug, PartialEq, Eq)] + pub struct NdpOptionPacket { pub option_type: NdpOptionType, - #[construct_with(u8)] pub length: u8, - #[length_fn = "ndp_option_payload_length"] - #[payload] - pub data: Vec, + pub payload: Bytes, } - /// Calculate a length of a `NdpOption`'s payload. - fn ndp_option_payload_length(option: &NdpOptionPacket) -> usize { - let len = option.get_length(); - if len > 0 { - ((len * 8) - 2) as usize - } else { - 0 + impl Packet for NdpOptionPacket { + type Header = (); + fn from_buf(bytes: &[u8]) -> Option { + if bytes.len() < 2 { + return None; + } + + let option_type = NdpOptionType::new(bytes[0]); + let length = bytes[1]; // unit: 8 bytes + + let total_len = (length as usize) * 8; + if bytes.len() < total_len { + return None; + } + + let data_len = total_len - 2; + let payload = Bytes::copy_from_slice(&bytes[2..2 + data_len]); + + Some(Self { + option_type, + length, + payload, + }) + } + fn from_bytes(bytes: Bytes) -> Option { + Self::from_buf(&bytes) + } + + fn to_bytes(&self) -> Bytes { + let mut bytes = Vec::with_capacity(NDP_OPT_PACKET_LEN + self.payload.len()); + bytes.push(self.option_type.value()); + bytes.push(self.length); + bytes.extend_from_slice(&self.payload); + Bytes::from(bytes) + } + + fn header(&self) -> Bytes { + self.to_bytes().slice(..NDP_OPT_PACKET_LEN) + } + + fn payload(&self) -> Bytes { + self.payload.clone() + } + + fn header_len(&self) -> usize { + NDP_OPT_PACKET_LEN + } + + fn payload_len(&self) -> usize { + self.payload.len() + } + + fn total_len(&self) -> usize { + self.header_len() + self.payload_len() + } + + fn into_parts(self) -> (Self::Header, Bytes) { + ((), self.payload) } } + impl NdpOptionPacket { + /// Calculate the length of the option's payload. + pub fn option_payload_length(&self) -> usize { + //let len = option.get_length(); + let len = self.payload.len(); + if len > 0 { + ((len * 8) - 2) as usize + } else { + 0 + } + } + } + + /// Calculate a length of a `NdpOption`'s payload. + + /// Router Solicitation Message [RFC 4861 § 4.1] /// /// ```text @@ -460,27 +565,136 @@ pub mod ndp { /// ``` /// /// [RFC 4861 § 4.1]: https://tools.ietf.org/html/rfc4861#section-4.1 - #[packet] - pub struct RouterSolicit { - #[construct_with(u8)] - pub icmpv6_type: Icmpv6Type, - #[construct_with(u8)] - pub icmpv6_code: Icmpv6Code, - pub checksum: u16be, - pub reserved: u32be, - #[length_fn = "rs_ndp_options_length"] - pub options: Vec, - #[payload] - #[length = "0"] - pub payload: Vec, + #[derive(Clone, Debug, PartialEq, Eq)] + pub struct RouterSolicitPacket { + pub header: Icmpv6Header, + pub reserved: u32, + pub options: Vec, + pub payload: Bytes, + } + + impl TryFrom for RouterSolicitPacket { + type Error = &'static str; + + fn try_from(value: Icmpv6Packet) -> Result { + if value.header.icmpv6_type != Icmpv6Type::RouterSolicitation { + return Err("Not a Router Solicitation packet"); + } + if value.payload.len() < 8 { + return Err("Payload too short for Router Solicitation"); + } + let reserved = u32::from_be_bytes([value.payload[0], value.payload[1], value.payload[2], value.payload[3]]); + let options = value.payload.slice(4..).chunks(8).map(|chunk| { + let option_type = NdpOptionType::new(chunk[0]); + let length = chunk[1]; + let payload = Bytes::from(chunk[2..].to_vec()); + NdpOptionPacket { option_type, length, payload } + }).collect(); + Ok(RouterSolicitPacket { + header: value.header, + reserved, + options, + payload: Bytes::new(), + }) + } + } + + impl Packet for RouterSolicitPacket { + type Header = (); + fn from_buf(bytes: &[u8]) -> Option { + if bytes.len() < NDP_SOL_PACKET_LEN { + return None; + } + + let icmpv6_type = Icmpv6Type::new(bytes[0]); + let icmpv6_code = Icmpv6Code::new(bytes[1]); + let checksum = u16::from_be_bytes([bytes[2], bytes[3]]); + let header = Icmpv6Header { + icmpv6_type, + icmpv6_code, + checksum, + }; + let reserved = u32::from_be_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]); + + let mut options = Vec::new(); + let mut i = 8; + while i + 2 <= bytes.len() { + let option_type = NdpOptionType::new(bytes[i]); + let length = bytes[i + 1]; + let option_len = (length as usize) * 8; + + if i + option_len > bytes.len() { + break; + } + + let payload = Bytes::copy_from_slice(&bytes[i + 2..i + option_len]); + options.push(NdpOptionPacket { + option_type, + length, + payload, + }); + i += option_len; + } + + let payload = Bytes::copy_from_slice(&bytes[i..]); + + Some(RouterSolicitPacket { + header, + reserved, + options, + payload, + }) + } + fn from_bytes(bytes: Bytes) -> Option { + Self::from_buf(&bytes) + } + + fn to_bytes(&self) -> Bytes { + let mut bytes = Vec::with_capacity(NDP_SOL_PACKET_LEN); + bytes.push(self.header.icmpv6_type.value()); + bytes.push(self.header.icmpv6_code.value()); + bytes.extend_from_slice(&self.header.checksum.to_be_bytes()); + bytes.extend_from_slice(&self.reserved.to_be_bytes()); + for option in &self.options { + bytes.push(option.option_type.value()); + bytes.push(option.length); + bytes.extend_from_slice(&option.payload); + } + Bytes::from(bytes) + } + + fn header(&self) -> Bytes { + self.to_bytes().slice(..ICMPV6_HEADER_LEN) + } + + fn payload(&self) -> Bytes { + self.payload.clone() + } + + fn header_len(&self) -> usize { + ICMPV6_HEADER_LEN + 4 // 4 for reserved + } + + fn payload_len(&self) -> usize { + self.payload.len() + } + + fn total_len(&self) -> usize { + self.header_len() + self.payload_len() + } + fn into_parts(self) -> (Self::Header, Bytes) { + ((), self.payload) + } } - /// Router Solicit packet calculation for the length of the options. - fn rs_ndp_options_length(pkt: &RouterSolicitPacket) -> usize { - if pkt.packet().len() > 8 { - pkt.packet().len() - 8 - } else { - 0 + impl RouterSolicitPacket { + /// Router Solicit packet calculation for the length of the options. + pub fn options_length(&self) -> usize { + if self.to_bytes().len() > 8 { + self.to_bytes().len() - 8 + } else { + 0 + } } } @@ -513,31 +727,159 @@ pub mod ndp { /// ``` /// /// [RFC 4861 § 4.2]: https://tools.ietf.org/html/rfc4861#section-4.2 - #[packet] - pub struct RouterAdvert { - #[construct_with(u8)] - pub icmpv6_type: Icmpv6Type, - #[construct_with(u8)] - pub icmpv6_code: Icmpv6Code, - pub checksum: u16be, + #[derive(Clone, Debug, PartialEq, Eq)] + pub struct RouterAdvertPacket { + pub header: Icmpv6Header, pub hop_limit: u8, pub flags: u8, - pub lifetime: u16be, - pub reachable_time: u32be, - pub retrans_time: u32be, - #[length_fn = "ra_ndp_options_length"] - pub options: Vec, - #[payload] - #[length = "0"] - pub payload: Vec, + pub lifetime: u16, + pub reachable_time: u32, + pub retrans_time: u32, + pub options: Vec, + pub payload: Bytes, } - /// Router Advert packet calculation for the length of the options. - fn ra_ndp_options_length(pkt: &RouterAdvertPacket) -> usize { - if pkt.packet().len() > 16 { - pkt.packet().len() - 16 - } else { - 0 + impl TryFrom for RouterAdvertPacket { + type Error = &'static str; + + fn try_from(value: Icmpv6Packet) -> Result { + if value.header.icmpv6_type != Icmpv6Type::RouterAdvertisement { + return Err("Not a Router Advertisement packet"); + } + if value.payload.len() < 16 { + return Err("Payload too short for Router Advertisement"); + } + let hop_limit = value.payload[0]; + let flags = value.payload[1]; + let lifetime = u16::from_be_bytes([value.payload[2], value.payload[3]]); + let reachable_time = u32::from_be_bytes([value.payload[4], value.payload[5], value.payload[6], value.payload[7]]); + let retrans_time = u32::from_be_bytes([value.payload[8], value.payload[9], value.payload[10], value.payload[11]]); + let options = value.payload.slice(12..).chunks(8).map(|chunk| { + let option_type = NdpOptionType::new(chunk[0]); + let length = chunk[1]; + let payload = Bytes::from(chunk[2..].to_vec()); + NdpOptionPacket { option_type, length, payload } + }).collect(); + Ok(RouterAdvertPacket { + header: value.header, + hop_limit, + flags, + lifetime, + reachable_time, + retrans_time, + options, + payload: Bytes::new(), + }) + } + } + impl Packet for RouterAdvertPacket { + type Header = (); + fn from_buf(bytes: &[u8]) -> Option { + if bytes.len() < NDP_ADV_PACKET_LEN { + return None; + } + + let icmpv6_type = Icmpv6Type::new(bytes[0]); + let icmpv6_code = Icmpv6Code::new(bytes[1]); + let checksum = u16::from_be_bytes([bytes[2], bytes[3]]); + let header = Icmpv6Header { + icmpv6_type, + icmpv6_code, + checksum, + }; + + let hop_limit = bytes[4]; + let flags = bytes[5]; + let lifetime = u16::from_be_bytes([bytes[6], bytes[7]]); + let reachable_time = u32::from_be_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]); + let retrans_time = u32::from_be_bytes([bytes[12], bytes[13], bytes[14], bytes[15]]); + + let mut options = Vec::new(); + let mut i = 16; + while i + 2 <= bytes.len() { + let option_type = NdpOptionType::new(bytes[i]); + let length = bytes[i + 1]; + let option_len = (length as usize) * 8; + + if i + option_len > bytes.len() { + break; + } + + let payload = Bytes::copy_from_slice(&bytes[i + 2..i + option_len]); + options.push(NdpOptionPacket { + option_type, + length, + payload, + }); + i += option_len; + } + + let payload = Bytes::copy_from_slice(&bytes[i..]); + + Some(RouterAdvertPacket { + header, + hop_limit, + flags, + lifetime, + reachable_time, + retrans_time, + options, + payload, + }) + } + + fn from_bytes(bytes: Bytes) -> Option { + Self::from_buf(&bytes) + } + + fn to_bytes(&self) -> Bytes { + let mut bytes = Vec::with_capacity(NDP_ADV_PACKET_LEN); + bytes.push(self.header.icmpv6_type.value()); + bytes.push(self.header.icmpv6_code.value()); + bytes.extend_from_slice(&self.header.checksum.to_be_bytes()); + bytes.push(self.hop_limit); + bytes.push(self.flags); + bytes.extend_from_slice(&self.lifetime.to_be_bytes()); + bytes.extend_from_slice(&self.reachable_time.to_be_bytes()); + bytes.extend_from_slice(&self.retrans_time.to_be_bytes()); + for option in &self.options { + bytes.push(option.option_type.value()); + bytes.push(option.length); + bytes.extend_from_slice(&option.payload); + } + Bytes::from(bytes) + } + + fn header(&self) -> Bytes { + self.to_bytes().slice(..ICMPV6_HEADER_LEN + 16) // 16 for the fixed part of the Router Advert + } + fn payload(&self) -> Bytes { + self.payload.clone() + } + fn header_len(&self) -> usize { + ICMPV6_HEADER_LEN + 16 // 16 for the fixed part of the Router Advert + } + + fn payload_len(&self) -> usize { + self.payload.len() + } + + fn total_len(&self) -> usize { + self.header_len() + self.payload_len() + } + fn into_parts(self) -> (Self::Header, Bytes) { + ((), self.payload) + } + } + + impl RouterAdvertPacket { + /// Router Advert packet calculation for the length of the options. + pub fn options_length(&self) -> usize { + if self.to_bytes().len() > 16 { + self.to_bytes().len() - 16 + } else { + 0 + } } } @@ -562,29 +904,161 @@ pub mod ndp { /// ``` /// /// [RFC 4861 § 4.3]: https://tools.ietf.org/html/rfc4861#section-4.3 - #[packet] - pub struct NeighborSolicit { - #[construct_with(u8)] - pub icmpv6_type: Icmpv6Type, - #[construct_with(u8)] - pub icmpv6_code: Icmpv6Code, - pub checksum: u16be, - pub reserved: u32be, - #[construct_with(u16, u16, u16, u16, u16, u16, u16, u16)] + #[derive(Clone, Debug, PartialEq, Eq)] + pub struct NeighborSolicitPacket { + pub header: Icmpv6Header, + pub reserved: u32, pub target_addr: Ipv6Addr, - #[length_fn = "ns_ndp_options_length"] - pub options: Vec, - #[payload] - #[length = "0"] - pub payload: Vec, + pub options: Vec, + pub payload: Bytes, } - /// Neighbor Solicit packet calculation for the length of the options. - fn ns_ndp_options_length(pkt: &NeighborSolicitPacket) -> usize { - if pkt.packet().len() > 24 { - pkt.packet().len() - 24 - } else { - 0 + impl TryFrom for NeighborSolicitPacket { + type Error = &'static str; + + fn try_from(value: Icmpv6Packet) -> Result { + if value.header.icmpv6_type != Icmpv6Type::NeighborSolicitation { + return Err("Not a Neighbor Solicitation packet"); + } + if value.payload.len() < 24 { + return Err("Payload too short for Neighbor Solicitation"); + } + let reserved = u32::from_be_bytes([value.payload[0], value.payload[1], value.payload[2], value.payload[3]]); + let target_addr = Ipv6Addr::new( + u16::from_be_bytes([value.payload[4], value.payload[5]]), + u16::from_be_bytes([value.payload[6], value.payload[7]]), + u16::from_be_bytes([value.payload[8], value.payload[9]]), + u16::from_be_bytes([value.payload[10], value.payload[11]]), + u16::from_be_bytes([value.payload[12], value.payload[13]]), + u16::from_be_bytes([value.payload[14], value.payload[15]]), + u16::from_be_bytes([value.payload[16], value.payload[17]]), + u16::from_be_bytes([value.payload[18], value.payload[19]]), + ); + let options = value.payload.slice(20..).chunks(8).map(|chunk| { + let option_type = NdpOptionType::new(chunk[0]); + let length = chunk[1]; + let payload: Bytes = Bytes::from(chunk[2..].to_vec()); + NdpOptionPacket { option_type, length, payload } + }).collect(); + Ok(NeighborSolicitPacket { + header: value.header, + reserved, + target_addr, + options, + payload: Bytes::new(), + }) + } + } + + impl Packet for NeighborSolicitPacket { + type Header = (); + fn from_buf(bytes: &[u8]) -> Option { + if bytes.len() < 24 { + return None; + } + + let icmpv6_type = Icmpv6Type::new(bytes[0]); + let icmpv6_code = Icmpv6Code::new(bytes[1]); + let checksum = u16::from_be_bytes([bytes[2], bytes[3]]); + let reserved = u32::from_be_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]); + let target_addr = Ipv6Addr::new( + u16::from_be_bytes([bytes[8], bytes[9]]), + u16::from_be_bytes([bytes[10], bytes[11]]), + u16::from_be_bytes([bytes[12], bytes[13]]), + u16::from_be_bytes([bytes[14], bytes[15]]), + u16::from_be_bytes([bytes[16], bytes[17]]), + u16::from_be_bytes([bytes[18], bytes[19]]), + u16::from_be_bytes([bytes[20], bytes[21]]), + u16::from_be_bytes([bytes[22], bytes[23]]), + ); + + let mut options = Vec::new(); + let mut i = 24; + while i + 2 <= bytes.len() { + let option_type = NdpOptionType::new(bytes[i]); + let length = bytes[i + 1]; + let option_len = (length as usize) * 8; + + if option_len < 2 || i + option_len > bytes.len() { + break; + } + + let payload = Bytes::copy_from_slice(&bytes[i + 2..i + option_len]); + options.push(NdpOptionPacket { + option_type, + length, + payload, + }); + + i += option_len; + } + + let payload = Bytes::copy_from_slice(&bytes[i..]); + + Some(NeighborSolicitPacket { + header: Icmpv6Header { + icmpv6_type, + icmpv6_code, + checksum, + }, + reserved, + target_addr, + options, + payload, + }) + } + fn from_bytes(bytes: Bytes) -> Option { + Self::from_buf(&bytes) + } + + fn to_bytes(&self) -> Bytes { + let mut bytes = Vec::with_capacity(NDP_SOL_PACKET_LEN); + bytes.push(self.header.icmpv6_type.value()); + bytes.push(self.header.icmpv6_code.value()); + bytes.extend_from_slice(&self.header.checksum.to_be_bytes()); + bytes.extend_from_slice(&self.reserved.to_be_bytes()); + for (_, segment) in self.target_addr.segments().iter().enumerate() { + bytes.extend_from_slice(&segment.to_be_bytes()); + } + for option in &self.options { + bytes.push(option.option_type.value()); + bytes.push(option.length); + bytes.extend_from_slice(&option.payload); + } + Bytes::from(bytes) + } + fn header(&self) -> Bytes { + self.to_bytes().slice(..ICMPV6_HEADER_LEN + 24) // 24 for the fixed part of the Neighbor Solicit + } + fn payload(&self) -> Bytes { + self.payload.clone() + } + fn header_len(&self) -> usize { + ICMPV6_HEADER_LEN + 24 // 24 for the fixed part of the Neighbor Solicit + } + + fn payload_len(&self) -> usize { + self.payload.len() + } + + fn total_len(&self) -> usize { + self.header_len() + self.payload_len() + } + + fn into_parts(self) -> (Self::Header, Bytes) { + ((), self.payload) + } + } + + impl NeighborSolicitPacket { + /// Neighbor Solicit packet calculation for the length of the options. + pub fn options_length(&self) -> usize { + // Calculate the length of the options in the Neighbor Solicitation packet. + if self.to_bytes().len() > 24 { + self.to_bytes().len() - 24 + } else { + 0 + } } } @@ -623,30 +1097,178 @@ pub mod ndp { /// ``` /// /// [RFC 4861 § 4.4]: https://tools.ietf.org/html/rfc4861#section-4.4 - #[packet] - pub struct NeighborAdvert { - #[construct_with(u8)] - pub icmpv6_type: Icmpv6Type, - #[construct_with(u8)] - pub icmpv6_code: Icmpv6Code, - pub checksum: u16be, + #[derive(Clone, Debug, PartialEq, Eq)] + pub struct NeighborAdvertPacket { + pub header: Icmpv6Header, pub flags: u8, pub reserved: u24be, - #[construct_with(u16, u16, u16, u16, u16, u16, u16, u16)] pub target_addr: Ipv6Addr, - #[length_fn = "na_ndp_options_length"] - pub options: Vec, - #[payload] - #[length = "0"] - pub payload: Vec, + pub options: Vec, + pub payload: Bytes, } - /// Neighbor Advert packet calculation for the length of the options. - fn na_ndp_options_length(pkt: &NeighborAdvertPacket) -> usize { - if pkt.packet().len() > 24 { - pkt.packet().len() - 24 - } else { - 0 + impl TryFrom for NeighborAdvertPacket { + type Error = &'static str; + + fn try_from(value: Icmpv6Packet) -> Result { + if value.header.icmpv6_type != Icmpv6Type::NeighborAdvertisement { + return Err("Not a Neighbor Advert packet"); + } + // The fixed part of a Neighbor Advertisement message is 20 bytes: + // 1 byte for flags, 3 bytes reserved, and 16 bytes for the target address. + // See RFC 4861 Section 4.4. + // Some packets may not include any options, so 20 bytes is the minimum length. + if value.payload.len() < 20 { + return Err("Payload too short for Neighbor Advert"); + } + let flags = value.payload[0]; + let reserved = bitfield::utils::u24be_from_bytes([value.payload[1], value.payload[2], value.payload[3]]); + let target_addr = Ipv6Addr::new( + u16::from_be_bytes([value.payload[4], value.payload[5]]), + u16::from_be_bytes([value.payload[6], value.payload[7]]), + u16::from_be_bytes([value.payload[8], value.payload[9]]), + u16::from_be_bytes([value.payload[10], value.payload[11]]), + u16::from_be_bytes([value.payload[12], value.payload[13]]), + u16::from_be_bytes([value.payload[14], value.payload[15]]), + u16::from_be_bytes([value.payload[16], value.payload[17]]), + u16::from_be_bytes([value.payload[18], value.payload[19]]), + ); + let options = value.payload.slice(20..).chunks(8).map(|chunk| { + let option_type = NdpOptionType::new(chunk[0]); + let length = chunk[1]; + let payload = Bytes::from(chunk[2..].to_vec()); + NdpOptionPacket { option_type, length, payload } + }).collect(); + Ok(NeighborAdvertPacket { + header: value.header, + flags, + reserved, + target_addr, + options, + payload: Bytes::new(), + }) + } + } + + impl Packet for NeighborAdvertPacket { + type Header = (); + fn from_buf(bytes: &[u8]) -> Option { + if bytes.len() < 24 { + return None; + } + + let icmpv6_type = Icmpv6Type::new(bytes[0]); + let icmpv6_code = Icmpv6Code::new(bytes[1]); + let checksum = u16::from_be_bytes([bytes[2], bytes[3]]); + let header = Icmpv6Header { + icmpv6_type, + icmpv6_code, + checksum, + }; + + let flags = bytes[4]; + let reserved = bitfield::utils::u24be_from_bytes([bytes[5], bytes[6], bytes[7]]); + + let target_addr = Ipv6Addr::new( + u16::from_be_bytes([bytes[8], bytes[9]]), + u16::from_be_bytes([bytes[10], bytes[11]]), + u16::from_be_bytes([bytes[12], bytes[13]]), + u16::from_be_bytes([bytes[14], bytes[15]]), + u16::from_be_bytes([bytes[16], bytes[17]]), + u16::from_be_bytes([bytes[18], bytes[19]]), + u16::from_be_bytes([bytes[20], bytes[21]]), + u16::from_be_bytes([bytes[22], bytes[23]]), + ); + + let mut options = Vec::new(); + let mut i = 24; + while i + 2 <= bytes.len() { + let option_type = NdpOptionType::new(bytes[i]); + let length = bytes[i + 1]; + let option_len = (length as usize) * 8; + + if option_len < 2 || i + option_len > bytes.len() { + break; + } + + let payload = Bytes::copy_from_slice(&bytes[i + 2..i + option_len]); + options.push(NdpOptionPacket { + option_type, + length, + payload, + }); + + i += option_len; + } + + let payload = Bytes::copy_from_slice(&bytes[i..]); + + Some(NeighborAdvertPacket { + header, + flags, + reserved, + target_addr, + options, + payload, + }) + } + fn from_bytes(bytes: Bytes) -> Option { + Self::from_buf(&bytes) + } + + fn to_bytes(&self) -> Bytes { + let mut bytes = Vec::with_capacity(NDP_ADV_PACKET_LEN); + bytes.push(self.header.icmpv6_type.value()); + bytes.push(self.header.icmpv6_code.value()); + bytes.extend_from_slice(&self.header.checksum.to_be_bytes()); + + // Combine flags and reserved (flags in the most significant 8 bits) + let flags_reserved = (self.flags as u32) << 24 | (self.reserved & 0x00FF_FFFF); + bytes.extend_from_slice(&flags_reserved.to_be_bytes()); + + for segment in self.target_addr.segments().iter() { + bytes.extend_from_slice(&segment.to_be_bytes()); + } + + for option in &self.options { + bytes.push(option.option_type.value()); + bytes.push(option.length); + bytes.extend_from_slice(&option.payload); + } + + Bytes::from(bytes) + } + fn header(&self) -> Bytes { + self.to_bytes().slice(..ICMPV6_HEADER_LEN + 24) // 24 for the fixed part of the Neighbor Advert + } + fn payload(&self) -> Bytes { + self.payload.clone() + } + fn header_len(&self) -> usize { + ICMPV6_HEADER_LEN + 24 // 24 for the fixed part of the Neighbor Advert + } + fn payload_len(&self) -> usize { + self.payload.len() + } + + fn total_len(&self) -> usize { + self.header_len() + self.payload_len() + } + + fn into_parts(self) -> (Self::Header, Bytes) { + ((), self.payload) + } + } + + impl NeighborAdvertPacket { + /// Neighbor Advert packet calculation for the length of the options. + pub fn options_length(&self) -> usize { + // Calculate the length of the options in the Neighbor Advert packet. + if self.to_bytes().len() > 24 { + self.to_bytes().len() - 24 + } else { + 0 + } } } @@ -679,31 +1301,190 @@ pub mod ndp { /// ``` /// /// [RFC 4861 § 4.5]: https://tools.ietf.org/html/rfc4861#section-4.5 - #[packet] - pub struct Redirect { - #[construct_with(u8)] - pub icmpv6_type: Icmpv6Type, - #[construct_with(u8)] - pub icmpv6_code: Icmpv6Code, - pub checksum: u16be, + #[derive(Clone, Debug, PartialEq, Eq)] + pub struct RedirectPacket { + pub header: Icmpv6Header, pub reserved: u32be, - #[construct_with(u16, u16, u16, u16, u16, u16, u16, u16)] pub target_addr: Ipv6Addr, - #[construct_with(u16, u16, u16, u16, u16, u16, u16, u16)] pub dest_addr: Ipv6Addr, - #[length_fn = "redirect_options_length"] - pub options: Vec, - #[payload] - #[length = "0"] - pub payload: Vec, + pub options: Vec, + pub payload: Bytes, } - /// Redirect packet calculation for the length of the options. - fn redirect_options_length(pkt: &RedirectPacket) -> usize { - if pkt.packet().len() > 40 { - pkt.packet().len() - 40 - } else { - 0 + impl TryFrom for RedirectPacket { + type Error = &'static str; + + fn try_from(value: Icmpv6Packet) -> Result { + if value.header.icmpv6_type != Icmpv6Type::RedirectMessage { + return Err("Not a Redirect packet"); + } + if value.payload.len() < 40 { + return Err("Payload too short for Redirect"); + } + let reserved = u32be::from_be_bytes([value.payload[0], value.payload[1], value.payload[2], value.payload[3]]); + let target_addr = Ipv6Addr::new( + u16::from_be_bytes([value.payload[4], value.payload[5]]), + u16::from_be_bytes([value.payload[6], value.payload[7]]), + u16::from_be_bytes([value.payload[8], value.payload[9]]), + u16::from_be_bytes([value.payload[10], value.payload[11]]), + u16::from_be_bytes([value.payload[12], value.payload[13]]), + u16::from_be_bytes([value.payload[14], value.payload[15]]), + u16::from_be_bytes([value.payload[16], value.payload[17]]), + u16::from_be_bytes([value.payload[18], value.payload[19]]), + ); + let dest_addr = Ipv6Addr::new( + u16::from_be_bytes([value.payload[20], value.payload[21]]), + u16::from_be_bytes([value.payload[22], value.payload[23]]), + u16::from_be_bytes([value.payload[24], value.payload[25]]), + u16::from_be_bytes([value.payload[26], value.payload[27]]), + u16::from_be_bytes([value.payload[28], value.payload[29]]), + u16::from_be_bytes([value.payload[30], value.payload[31]]), + u16::from_be_bytes([value.payload[32], value.payload[33]]), + u16::from_be_bytes([value.payload[34], value.payload[35]]), + ); + let options = value.payload.slice(36..).chunks(8).map(|chunk| { + let option_type = NdpOptionType::new(chunk[0]); + let length = chunk[1]; + let payload = Bytes::from(chunk[2..].to_vec()); + NdpOptionPacket { option_type, length, payload } + }).collect(); + Ok(RedirectPacket { + header: value.header, + reserved, + target_addr, + dest_addr, + options, + payload: Bytes::new(), + }) + } + } + + impl Packet for RedirectPacket { + type Header = (); + fn from_buf(bytes: &[u8]) -> Option { + if bytes.len() < 40 { + return None; + } + + let icmpv6_type = Icmpv6Type::new(bytes[0]); + let icmpv6_code = Icmpv6Code::new(bytes[1]); + let checksum = u16::from_be_bytes([bytes[2], bytes[3]]); + let header = Icmpv6Header { + icmpv6_type, + icmpv6_code, + checksum, + }; + + let reserved = u32be::from_be_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]); + + let target_addr = Ipv6Addr::new( + u16::from_be_bytes([bytes[8], bytes[9]]), + u16::from_be_bytes([bytes[10], bytes[11]]), + u16::from_be_bytes([bytes[12], bytes[13]]), + u16::from_be_bytes([bytes[14], bytes[15]]), + u16::from_be_bytes([bytes[16], bytes[17]]), + u16::from_be_bytes([bytes[18], bytes[19]]), + u16::from_be_bytes([bytes[20], bytes[21]]), + u16::from_be_bytes([bytes[22], bytes[23]]), + ); + + let dest_addr = Ipv6Addr::new( + u16::from_be_bytes([bytes[24], bytes[25]]), + u16::from_be_bytes([bytes[26], bytes[27]]), + u16::from_be_bytes([bytes[28], bytes[29]]), + u16::from_be_bytes([bytes[30], bytes[31]]), + u16::from_be_bytes([bytes[32], bytes[33]]), + u16::from_be_bytes([bytes[34], bytes[35]]), + u16::from_be_bytes([bytes[36], bytes[37]]), + u16::from_be_bytes([bytes[38], bytes[39]]), + ); + + let mut options = Vec::new(); + let mut i = 40; + while i + 2 <= bytes.len() { + let option_type = NdpOptionType::new(bytes[i]); + let length = bytes[i + 1]; + let option_len = (length as usize) * 8; + + if option_len < 2 || i + option_len > bytes.len() { + break; + } + + let payload = Bytes::copy_from_slice(&bytes[i + 2..i + option_len]); + options.push(NdpOptionPacket { + option_type, + length, + payload, + }); + + i += option_len; + } + + let payload = Bytes::copy_from_slice(&bytes[i..]); + + Some(RedirectPacket { + header, + reserved, + target_addr, + dest_addr, + options, + payload, + }) + } + fn from_bytes(bytes: Bytes) -> Option { + Self::from_buf(&bytes) + } + fn to_bytes(&self) -> Bytes { + let mut bytes = Vec::with_capacity(NDP_REDIRECT_PACKET_LEN); + bytes.push(self.header.icmpv6_type.value()); + bytes.push(self.header.icmpv6_code.value()); + bytes.extend_from_slice(&self.header.checksum.to_be_bytes()); + bytes.extend_from_slice(&self.reserved.to_be_bytes()); + for (_, segment) in self.target_addr.segments().iter().enumerate() { + bytes.extend_from_slice(&segment.to_be_bytes()); + } + for (_, segment) in self.dest_addr.segments().iter().enumerate() { + bytes.extend_from_slice(&segment.to_be_bytes()); + } + for option in &self.options { + bytes.push(option.option_type.value()); + bytes.push(option.length); + bytes.extend_from_slice(&option.payload); + } + Bytes::from(bytes) + } + fn header(&self) -> Bytes { + self.to_bytes().slice(..ICMPV6_HEADER_LEN + 40) // 40 for the fixed part of the Redirect + } + fn payload(&self) -> Bytes { + self.payload.clone() + } + fn header_len(&self) -> usize { + ICMPV6_HEADER_LEN + 40 // 40 for the fixed part of the Redirect + } + + fn payload_len(&self) -> usize { + self.payload.len() + } + + fn total_len(&self) -> usize { + self.header_len() + self.payload_len() + } + + fn into_parts(self) -> (Self::Header, Bytes) { + ((), self.payload) + } + } + + impl RedirectPacket { + /// Redirect packet calculation for the length of the options. + pub fn options_length(&self) -> usize { + // Calculate the length of the options in the Redirect packet. + if self.to_bytes().len() > 40 { + self.to_bytes().len() - 40 + } else { + 0 + } } } @@ -711,76 +1492,87 @@ pub mod ndp { mod ndp_tests { use super::*; use crate::icmpv6::{Icmpv6Code, Icmpv6Type}; - use alloc::vec; #[test] fn basic_option_parsing() { - let mut data = vec![ + let data = Bytes::from_static(&[ 0x02, 0x01, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01, // Extra bytes to confuse the parsing 0x00, 0x00, 0x00, - ]; - let pkg = MutableNdpOptionPacket::new(&mut data[..]).unwrap(); - assert_eq!(pkg.get_option_type(), NdpOptionTypes::TargetLLAddr); - assert_eq!(pkg.get_length(), 0x01); - assert_eq!(pkg.payload().len(), 6); - assert_eq!(pkg.payload(), &[0x06, 0x05, 0x04, 0x03, 0x02, 0x01]); + ]); + let pkg = NdpOptionPacket::from_bytes(data).unwrap(); + assert_eq!(pkg.option_type, NdpOptionTypes::TargetLLAddr); + assert_eq!(pkg.length, 0x01); + assert_eq!(pkg.payload.len(), 6); + assert_eq!(pkg.payload.as_ref(), &[0x06, 0x05, 0x04, 0x03, 0x02, 0x01]); } #[test] fn basic_rs_parse() { - let mut data = vec![ + let data = Bytes::from_static(&[ 0x85, // Type 0x00, // Code 0x00, 0x00, // Checksum 0x00, 0x00, 0x00, 0x00, // Reserved 0x02, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - ]; + ]); - let pkg = MutableRouterSolicitPacket::new(&mut data[..]).unwrap(); - assert_eq!(pkg.get_icmpv6_type(), Icmpv6Type::RouterSolicitation); - assert_eq!(pkg.get_icmpv6_code(), Icmpv6Code(0)); - assert_eq!(pkg.get_checksum(), 0); - assert_eq!(pkg.get_reserved(), 0); - assert_eq!(pkg.get_options().len(), 2); + let pkg = RouterSolicitPacket::from_bytes(data).unwrap(); + assert_eq!(pkg.header.icmpv6_type, Icmpv6Type::RouterSolicitation); + assert_eq!(pkg.header.icmpv6_code, Icmpv6Code(0)); + assert_eq!(pkg.header.checksum, 0); + assert_eq!(pkg.reserved, 0); + assert_eq!(pkg.options.len(), 2); - let option = &pkg.get_options()[0]; + let option = &pkg.options[0]; assert_eq!(option.option_type, NdpOptionTypes::TargetLLAddr); assert_eq!(option.length, 0x01); - assert_eq!(option.data, &[0x00, 0x00, 0x00, 0x00, 0x00, 0x00]); - assert_eq!(option.data.len(), 6); + assert_eq!(option.payload.as_ref(), &[0x00, 0x00, 0x00, 0x00, 0x00, 0x00]); + assert_eq!(option.payload.len(), 6); - let option = &pkg.get_options()[1]; + let option = &pkg.options[1]; assert_eq!(option.option_type, NdpOptionTypes::SourceLLAddr); assert_eq!(option.length, 1); - assert_eq!(option.data, &[0x00, 0x00, 0x00, 0x00, 0x00, 0x00]); + assert_eq!(option.payload.as_ref(), &[0x00, 0x00, 0x00, 0x00, 0x00, 0x00]); } #[test] fn basic_rs_create() { - let ref_packet = vec![ - 0x85, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, - ]; - let mut packet = [0u8; 16]; - let options = vec![NdpOption { + use crate::icmpv6::ndp::{NdpOptionPacket, RouterSolicitPacket}; + + let options = vec![NdpOptionPacket { option_type: NdpOptionTypes::SourceLLAddr, length: 1, - data: vec![0x00, 0x00, 0x00, 0x00, 0x00, 0x00], + payload: Bytes::from_static(&[0x00, 0x00, 0x00, 0x00, 0x00, 0x00]), }]; - { - let mut rs_packet = MutableRouterSolicitPacket::new(&mut packet[..]).unwrap(); - rs_packet.set_icmpv6_type(Icmpv6Type::RouterSolicitation); - rs_packet.set_icmpv6_code(Icmpv6Code(0)); - rs_packet.set_options(&options[..]); - } - assert_eq!(&ref_packet[..], &packet[..]); + + let packet = RouterSolicitPacket { + header: Icmpv6Header { + icmpv6_type: Icmpv6Type::RouterSolicitation, + icmpv6_code: Icmpv6Code(0), + checksum: 0, + }, + reserved: 0, + options, + payload: Bytes::new(), + }; + + let bytes = packet.to_bytes(); + + let expected = Bytes::from_static(&[ + 0x85, 0x00, 0x00, 0x00, // Type, Code, Checksum + 0x00, 0x00, 0x00, 0x00, // Reserved + 0x01, 0x01, // Option Type, Length + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Option Data + ]); + + assert_eq!(bytes, expected); } #[test] fn basic_ra_parse() { - let mut data = vec![ + let data = Bytes::from_static(&[ 0x86, // Type 0x00, // Code 0x00, 0x00, // Checksum @@ -791,161 +1583,221 @@ pub mod ndp { 0x87, 0x65, 0x43, 0x21, // Retrans 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Source Link-Layer 0x05, 0x01, 0x00, 0x00, 0x57, 0x68, 0x61, 0x74, // MTU - ]; - let pkg = MutableRouterAdvertPacket::new(&mut data[..]).unwrap(); - assert_eq!(pkg.get_icmpv6_type(), Icmpv6Type::RouterAdvertisement); - assert_eq!(pkg.get_icmpv6_code(), Icmpv6Code(0)); - assert_eq!(pkg.get_checksum(), 0x00); - assert_eq!(pkg.get_hop_limit(), 0xff); - assert_eq!(pkg.get_flags(), RouterAdvertFlags::ManagedAddressConf); - assert_eq!(pkg.get_lifetime(), 0x900); - assert_eq!(pkg.get_reachable_time(), 0x12345678); - assert_eq!(pkg.get_retrans_time(), 0x87654321); - assert_eq!(pkg.get_options().len(), 2); - - let option = &pkg.get_options()[0]; + ]); + let pkg = RouterAdvertPacket::from_bytes(data).unwrap(); + assert_eq!(pkg.header.icmpv6_type, Icmpv6Type::RouterAdvertisement); + assert_eq!(pkg.header.icmpv6_code, Icmpv6Code(0)); + assert_eq!(pkg.header.checksum, 0x00); + assert_eq!(pkg.hop_limit, 0xff); + assert_eq!(pkg.flags, RouterAdvertFlags::ManagedAddressConf); + assert_eq!(pkg.lifetime, 0x900); + assert_eq!(pkg.reachable_time, 0x12345678); + assert_eq!(pkg.retrans_time, 0x87654321); + assert_eq!(pkg.options.len(), 2); + + let option = &pkg.options[0]; assert_eq!(option.option_type, NdpOptionTypes::SourceLLAddr); assert_eq!(option.length, 1); - assert_eq!(option.data, &[0x00, 0x00, 0x00, 0x00, 0x00, 0x00]); + assert_eq!(option.payload.as_ref(), &[0x00, 0x00, 0x00, 0x00, 0x00, 0x00]); - let option = &pkg.get_options()[1]; + let option = &pkg.options[1]; assert_eq!(option.option_type, NdpOptionTypes::MTU); assert_eq!(option.length, 1); - assert_eq!(option.data, &[0x00, 0x00, 0x57, 0x68, 0x61, 0x74]); + assert_eq!(option.payload.as_ref(), &[0x00, 0x00, 0x57, 0x68, 0x61, 0x74]); } #[test] fn basic_ra_create() { - let ref_packet = vec![ - 0x86, 0x00, 0x00, 0x00, 0xff, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x05, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - ]; - let mut packet = [0u8; 24]; - let options = vec![NdpOption { + use crate::icmpv6::ndp::{NdpOptionPacket, RouterAdvertPacket, RouterAdvertFlags}; + + let options = vec![NdpOptionPacket { option_type: NdpOptionTypes::MTU, length: 1, - data: vec![0x00, 0x00, 0x00, 0x00, 0x00, 0x00], + payload: Bytes::from_static(&[0x00, 0x00, 0x00, 0x00, 0x00, 0x00]), }]; - { - let mut ra_packet = MutableRouterAdvertPacket::new(&mut packet[..]).unwrap(); - ra_packet.set_icmpv6_type(Icmpv6Type::RouterAdvertisement); - ra_packet.set_icmpv6_code(Icmpv6Code(0)); - ra_packet.set_hop_limit(0xff); - ra_packet.set_flags(RouterAdvertFlags::ManagedAddressConf); - ra_packet.set_options(&options[..]); - } - assert_eq!(&ref_packet[..], &packet[..]); + + let packet = RouterAdvertPacket { + header: Icmpv6Header { + icmpv6_type: Icmpv6Type::RouterAdvertisement, + icmpv6_code: Icmpv6Code(0), + checksum: 0, + }, + hop_limit: 0xff, + flags: RouterAdvertFlags::ManagedAddressConf, + lifetime: 0, + reachable_time: 0, + retrans_time: 0, + options, + payload: Bytes::new(), + }; + + let bytes = packet.to_bytes(); + let expected = Bytes::from_static(&[ + 0x86, 0x00, 0x00, 0x00, // header + 0xff, 0x80, 0x00, 0x00, // hop limit, flags, lifetime + 0x00, 0x00, 0x00, 0x00, // reachable + 0x00, 0x00, 0x00, 0x00, // retrans + 0x05, 0x01, // option type + len + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // option data + ]); + + assert_eq!(bytes, expected); } #[test] fn basic_ns_parse() { - let mut data = vec![ + let data = Bytes::from_static(&[ 0x87, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, - ]; - let pkg = MutableNeighborSolicitPacket::new(&mut data[..]).unwrap(); - assert_eq!(pkg.get_icmpv6_type(), Icmpv6Type::NeighborSolicitation); - assert_eq!(pkg.get_icmpv6_code(), Icmpv6Code(0)); - assert_eq!(pkg.get_checksum(), 0x00); - assert_eq!(pkg.get_reserved(), 0x00); + ]); + let pkg = NeighborSolicitPacket::from_bytes(data).unwrap(); + assert_eq!(pkg.header.icmpv6_type, Icmpv6Type::NeighborSolicitation); + assert_eq!(pkg.header.icmpv6_code, Icmpv6Code(0)); + assert_eq!(pkg.header.checksum, 0x00); + assert_eq!(pkg.reserved, 0x00); assert_eq!( - pkg.get_target_addr(), + pkg.target_addr, Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 1) ); } #[test] fn basic_ns_create() { - let ref_packet = vec![ - 0x87, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, - ]; - let mut packet = [0u8; 24]; - { - let mut ns_packet = MutableNeighborSolicitPacket::new(&mut packet[..]).unwrap(); - ns_packet.set_icmpv6_type(Icmpv6Type::NeighborSolicitation); - ns_packet.set_icmpv6_code(Icmpv6Code(0)); - ns_packet.set_target_addr(Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 1)); - } - assert_eq!(&ref_packet[..], &packet[..]); + use crate::icmpv6::ndp::NeighborSolicitPacket; + + let packet = NeighborSolicitPacket { + header: Icmpv6Header { + icmpv6_type: Icmpv6Type::NeighborSolicitation, + icmpv6_code: Icmpv6Code(0), + checksum: 0, + }, + reserved: 0, + target_addr: Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 1), + options: vec![], + payload: Bytes::new(), + }; + + let bytes = packet.to_bytes(); + + let expected = Bytes::from_static(&[ + 0x87, 0x00, 0x00, 0x00, // header + 0x00, 0x00, 0x00, 0x00, // reserved + 0xff, 0x02, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, // target + ]); + + assert_eq!(bytes, expected); } #[test] fn basic_na_parse() { - let mut data = vec![ + let data = Bytes::from_static(&[ 0x88, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, - ]; - let pkg = MutableNeighborAdvertPacket::new(&mut data[..]).unwrap(); - assert_eq!(pkg.get_icmpv6_type(), Icmpv6Type::NeighborAdvertisement); - assert_eq!(pkg.get_icmpv6_code(), Icmpv6Code(0)); - assert_eq!(pkg.get_checksum(), 0x00); - assert_eq!(pkg.get_reserved(), 0x00); - assert_eq!(pkg.get_flags(), 0x80); + ]); + let pkg = NeighborAdvertPacket::from_bytes(data).unwrap(); + assert_eq!(pkg.header.icmpv6_type, Icmpv6Type::NeighborAdvertisement); + assert_eq!(pkg.header.icmpv6_code, Icmpv6Code(0)); + assert_eq!(pkg.header.checksum, 0x00); + assert_eq!(pkg.reserved, 0x00); + assert_eq!(pkg.flags, 0x80); assert_eq!( - pkg.get_target_addr(), + pkg.target_addr, Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 1) ); } #[test] fn basic_na_create() { - let ref_packet = vec![ - 0x88, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, - ]; - let mut packet = [0u8; 24]; - { - let mut na_packet = MutableNeighborAdvertPacket::new(&mut packet[..]).unwrap(); - na_packet.set_icmpv6_type(Icmpv6Type::NeighborAdvertisement); - na_packet.set_icmpv6_code(Icmpv6Code(0)); - na_packet.set_target_addr(Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 1)); - na_packet.set_flags(NeighborAdvertFlags::Router); - } - assert_eq!(&ref_packet[..], &packet[..]); + use crate::icmpv6::ndp::{NeighborAdvertPacket, NeighborAdvertFlags}; + + let packet = NeighborAdvertPacket { + header: Icmpv6Header { + icmpv6_type: Icmpv6Type::NeighborAdvertisement, + icmpv6_code: Icmpv6Code(0), + checksum: 0, + }, + flags: NeighborAdvertFlags::Router, + reserved: 0, + target_addr: Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 1), + options: vec![], + payload: Bytes::new(), + }; + + let bytes = packet.to_bytes(); + + let expected = Bytes::from_static(&[ + 0x88, 0x00, 0x00, 0x00, // header + 0x80, 0x00, 0x00, 0x00, // flags + reserved + 0xff, 0x02, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, + ]); + + assert_eq!(bytes, expected); } #[test] fn basic_redirect_parse() { - let mut data = vec![ + let data = Bytes::from_static(&[ 0x89, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - ]; - let pkg = MutableRedirectPacket::new(&mut data[..]).unwrap(); - assert_eq!(pkg.get_icmpv6_type(), Icmpv6Type::RedirectMessage); - assert_eq!(pkg.get_icmpv6_code(), Icmpv6Code(0)); - assert_eq!(pkg.get_checksum(), 0x00); - assert_eq!(pkg.get_reserved(), 0x00); + ]); + let pkg = RedirectPacket::from_bytes(data).unwrap(); + assert_eq!(pkg.header.icmpv6_type, Icmpv6Type::RedirectMessage); + assert_eq!(pkg.header.icmpv6_code, Icmpv6Code(0)); + assert_eq!(pkg.header.checksum, 0x00); + assert_eq!(pkg.reserved, 0x00); assert_eq!( - pkg.get_target_addr(), + pkg.target_addr, Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 1) ); - assert_eq!(pkg.get_dest_addr(), Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)); + assert_eq!(pkg.dest_addr, Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)); } #[test] fn basic_redirect_create() { - let ref_packet = vec![ - 0x89, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - ]; - let mut packet = [0u8; 40]; - { - let mut rdr_packet = MutableRedirectPacket::new(&mut packet[..]).unwrap(); - rdr_packet.set_icmpv6_type(Icmpv6Type::RedirectMessage); - rdr_packet.set_icmpv6_code(Icmpv6Code(0)); - rdr_packet.set_target_addr(Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 1)); - rdr_packet.set_dest_addr(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)); - } - assert_eq!(&ref_packet[..], &packet[..]); + use crate::icmpv6::ndp::RedirectPacket; + + let packet = RedirectPacket { + header: Icmpv6Header { + icmpv6_type: Icmpv6Type::RedirectMessage, + icmpv6_code: Icmpv6Code(0), + checksum: 0, + }, + reserved: 0, + target_addr: Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 1), + dest_addr: Ipv6Addr::UNSPECIFIED, + options: vec![], + payload: Bytes::new(), + }; + + let bytes = packet.to_bytes(); + + let expected = Bytes::from_static(&[ + 0x89, 0x00, 0x00, 0x00, // header + 0x00, 0x00, 0x00, 0x00, // reserved + 0xff, 0x02, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, // target + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, // dest + ]); + + assert_eq!(bytes, expected); } } } -pub mod echo_reply { - //! abstraction for "echo reply" ICMPv6 packets. +pub mod echo_request { + //! abstraction for "echo request" ICMPv6 packets. //! //! ```text //! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ @@ -957,13 +1809,9 @@ pub mod echo_reply { //! +-+-+-+-+- //! ``` - use crate::icmpv6::{Icmpv6Code, Icmpv6Type}; - use crate::PrimitiveValues; - - use alloc::vec::Vec; + use bytes::Bytes; - use nex_macro::packet; - use nex_macro_helper::types::*; + use crate::{icmpv6::{Icmpv6Code, Icmpv6Header, Icmpv6Packet, Icmpv6Type}, packet::Packet}; /// Represents the identifier field. #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -974,12 +1822,9 @@ pub mod echo_reply { pub fn new(val: u16) -> Identifier { Identifier(val) } - } - - impl PrimitiveValues for Identifier { - type T = (u16,); - fn to_primitive_values(&self) -> (u16,) { - (self.0,) + /// Get the value of the identifier. + pub fn value(&self) -> u16 { + self.0 } } @@ -992,12 +1837,9 @@ pub mod echo_reply { pub fn new(val: u16) -> SequenceNumber { SequenceNumber(val) } - } - - impl PrimitiveValues for SequenceNumber { - type T = (u16,); - fn to_primitive_values(&self) -> (u16,) { - (self.0,) + /// Get the value of the sequence number. + pub fn value(&self) -> u16 { + self.0 } } @@ -1011,23 +1853,102 @@ pub mod echo_reply { pub const NoCode: Icmpv6Code = Icmpv6Code(0); } - /// Represents an "echo reply" ICMPv6 packet. - #[packet] - pub struct EchoReply { - #[construct_with(u8)] - pub icmpv6_type: Icmpv6Type, - #[construct_with(u8)] - pub icmpv6_code: Icmpv6Code, - pub checksum: u16be, - pub identifier: u16be, - pub sequence_number: u16be, - #[payload] - pub payload: Vec, + /// Represents an "echo request" ICMPv6 packet. + #[derive(Clone, Debug, PartialEq, Eq)] + pub struct EchoRequestPacket { + pub header: Icmpv6Header, + pub identifier: u16, + pub sequence_number: u16, + pub payload: Bytes, + } + + impl TryFrom for EchoRequestPacket { + type Error = &'static str; + + fn try_from(value: Icmpv6Packet) -> Result { + if value.header.icmpv6_type != Icmpv6Type::EchoRequest { + return Err("Not an Echo Request packet"); + } + if value.payload.len() < 8 { + return Err("Payload too short for Echo Request"); + } + let identifier = u16::from_be_bytes([value.payload[0], value.payload[1]]); + let sequence_number = u16::from_be_bytes([value.payload[2], value.payload[3]]); + Ok(EchoRequestPacket { + header: value.header, + identifier, + sequence_number, + payload: value.payload.slice(4..), + }) + } } + + impl Packet for EchoRequestPacket { + type Header = (); + fn from_buf(bytes: &[u8]) -> Option { + if bytes.len() < 8 { + return None; + } + let icmpv6_type = Icmpv6Type::new(bytes[0]); + let icmpv6_code = Icmpv6Code::new(bytes[1]); + let checksum = u16::from_be_bytes([bytes[2], bytes[3]]); + let identifier = u16::from_be_bytes([bytes[4], bytes[5]]); + let sequence_number = u16::from_be_bytes([bytes[6], bytes[7]]); + Some(EchoRequestPacket { + header: Icmpv6Header { + icmpv6_type, + icmpv6_code, + checksum, + }, + identifier, + sequence_number, + payload: Bytes::copy_from_slice(&bytes[8..]), + }) + } + fn from_bytes(bytes: Bytes) -> Option { + Self::from_buf(&bytes) + } + + fn to_bytes(&self) -> Bytes { + let mut bytes = Vec::with_capacity(8 + self.payload.len()); + bytes.push(self.header.icmpv6_type.value()); + bytes.push(self.header.icmpv6_code.value()); + bytes.extend_from_slice(&self.header.checksum.to_be_bytes()); + bytes.extend_from_slice(&self.identifier.to_be_bytes()); + bytes.extend_from_slice(&self.sequence_number.to_be_bytes()); + bytes.extend_from_slice(&self.payload); + Bytes::from(bytes) + } + + fn header(&self) -> Bytes { + self.to_bytes().slice(..8) + } + + fn payload(&self) -> Bytes { + self.payload.clone() + } + + fn header_len(&self) -> usize { + 8 // Header length for echo request + } + + fn payload_len(&self) -> usize { + self.payload.len() + } + + fn total_len(&self) -> usize { + self.header_len() + self.payload_len() + } + + fn into_parts(self) -> (Self::Header, Bytes) { + ((), self.payload) + } + } + } -pub mod echo_request { - //! abstraction for "echo request" ICMPv6 packets. +pub mod echo_reply { + //! abstraction for "echo reply" ICMPv6 packets. //! //! ```text //! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ @@ -1038,15 +1959,10 @@ pub mod echo_request { //! | Data ... //! +-+-+-+-+- //! ``` + + use bytes::Bytes; - use crate::icmpv6::{Icmpv6Code, Icmpv6Type}; - use crate::PrimitiveValues; - - use alloc::vec::Vec; - - use nex_macro::packet; - use nex_macro_helper::types::*; - + use crate::{icmpv6::{Icmpv6Code, Icmpv6Header, Icmpv6Packet, Icmpv6Type}, packet::Packet}; /// Represents the identifier field. #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Identifier(pub u16); @@ -1056,12 +1972,9 @@ pub mod echo_request { pub fn new(val: u16) -> Identifier { Identifier(val) } - } - - impl PrimitiveValues for Identifier { - type T = (u16,); - fn to_primitive_values(&self) -> (u16,) { - (self.0,) + /// Get the value of the identifier. + pub fn value(&self) -> u16 { + self.0 } } @@ -1074,12 +1987,9 @@ pub mod echo_request { pub fn new(val: u16) -> SequenceNumber { SequenceNumber(val) } - } - - impl PrimitiveValues for SequenceNumber { - type T = (u16,); - fn to_primitive_values(&self) -> (u16,) { - (self.0,) + /// Get the value of the sequence number. + pub fn value(&self) -> u16 { + self.0 } } @@ -1093,17 +2003,182 @@ pub mod echo_request { pub const NoCode: Icmpv6Code = Icmpv6Code(0); } - /// Represents an "echo request" ICMPv6 packet. - #[packet] - pub struct EchoRequest { - #[construct_with(u8)] - pub icmpv6_type: Icmpv6Type, - #[construct_with(u8)] - pub icmpv6_code: Icmpv6Code, - pub checksum: u16be, - pub identifier: u16be, - pub sequence_number: u16be, - #[payload] - pub payload: Vec, + /// Represents an "echo reply" ICMPv6 packet. + #[derive(Clone, Debug, PartialEq, Eq)] + pub struct EchoReplyPacket { + pub header: Icmpv6Header, + pub identifier: u16, + pub sequence_number: u16, + pub payload: Bytes, + } + impl TryFrom for EchoReplyPacket { + type Error = &'static str; + + fn try_from(value: Icmpv6Packet) -> Result { + if value.header.icmpv6_type != Icmpv6Type::EchoReply { + return Err("Not an Echo Reply packet"); + } + if value.payload.len() < 8 { + return Err("Payload too short for Echo Reply"); + } + let identifier = u16::from_be_bytes([value.payload[0], value.payload[1]]); + let sequence_number = u16::from_be_bytes([value.payload[2], value.payload[3]]); + Ok(EchoReplyPacket { + header: value.header, + identifier, + sequence_number, + payload: value.payload.slice(4..), + }) + } + } + impl Packet for EchoReplyPacket { + type Header = (); + fn from_buf(bytes: &[u8]) -> Option { + if bytes.len() < 8 { + return None; + } + let icmpv6_type = Icmpv6Type::new(bytes[0]); + let icmpv6_code = Icmpv6Code::new(bytes[1]); + let checksum = u16::from_be_bytes([bytes[2], bytes[3]]); + let identifier = u16::from_be_bytes([bytes[4], bytes[5]]); + let sequence_number = u16::from_be_bytes([bytes[6], bytes[7]]); + Some(EchoReplyPacket { + header: Icmpv6Header { + icmpv6_type, + icmpv6_code, + checksum, + }, + identifier, + sequence_number, + payload: Bytes::copy_from_slice(&bytes[8..]), + }) + } + fn from_bytes(bytes: Bytes) -> Option { + Self::from_buf(&bytes) + } + + fn to_bytes(&self) -> Bytes { + let mut bytes = Vec::with_capacity(8 + self.payload.len()); + bytes.push(self.header.icmpv6_type.value()); + bytes.push(self.header.icmpv6_code.value()); + bytes.extend_from_slice(&self.header.checksum.to_be_bytes()); + bytes.extend_from_slice(&self.identifier.to_be_bytes()); + bytes.extend_from_slice(&self.sequence_number.to_be_bytes()); + bytes.extend_from_slice(&self.payload); + Bytes::from(bytes) + } + + fn header(&self) -> Bytes { + self.to_bytes().slice(..8) + } + + fn payload(&self) -> Bytes { + self.payload.clone() + } + + fn header_len(&self) -> usize { + 8 // Header length for echo reply + } + + fn payload_len(&self) -> usize { + self.payload.len() + } + + fn total_len(&self) -> usize { + self.header_len() + self.payload_len() + } + + fn into_parts(self) -> (Self::Header, Bytes) { + ((), self.payload) + } + } + +} + +#[cfg(test)] +mod echo_tests { + use super::*; + use crate::icmpv6::{echo_reply::EchoReplyPacket, echo_request::EchoRequestPacket, Icmpv6Code, Icmpv6Type}; + + #[test] + fn test_echo_request_parse() { + let raw = Bytes::from_static(&[ + 0x80, 0x00, 0xbe, 0xef, // header: type, code, checksum + 0x12, 0x34, // identifier + 0x56, 0x78, // sequence number + b'p', b'i', b'n', b'g', b'!', + ]); + + let parsed = EchoRequestPacket::from_bytes(raw.clone()).expect("Failed to parse Echo Request packet"); + + assert_eq!(parsed.header.icmpv6_type, Icmpv6Type::EchoRequest); + assert_eq!(parsed.header.icmpv6_code, Icmpv6Code(0)); + assert_eq!(parsed.header.checksum, 0xbeef); + assert_eq!(parsed.identifier, 0x1234); + assert_eq!(parsed.sequence_number, 0x5678); + assert_eq!(parsed.payload, Bytes::from_static(b"ping!")); + } + + #[test] + fn test_echo_request_create() { + let payload = Bytes::from_static(b"hello"); + let packet = EchoRequestPacket { + header: Icmpv6Header { + icmpv6_type: Icmpv6Type::EchoRequest, + icmpv6_code: Icmpv6Code(0), + checksum: 0, + }, + identifier: 0x1234, + sequence_number: 0x5678, + payload: payload.clone(), + }; + let bytes = packet.to_bytes(); + let parsed = EchoRequestPacket::from_bytes(bytes).unwrap(); + + assert_eq!(parsed.identifier, 0x1234); + assert_eq!(parsed.sequence_number, 0x5678); + assert_eq!(parsed.payload, payload); + } + + #[test] + fn test_echo_reply_parse() { + let raw = Bytes::from_static(&[ + 0x81, 0x00, 0x12, 0x34, // header: type, code, checksum + 0xab, 0xcd, // identifier + 0x56, 0x78, // sequence number + b'h', b'e', b'l', b'l', b'o', + ]); + + let parsed = EchoReplyPacket::from_bytes(raw.clone()).expect("Failed to parse Echo Reply packet"); + + assert_eq!(parsed.header.icmpv6_type, Icmpv6Type::EchoReply); + assert_eq!(parsed.header.icmpv6_code, Icmpv6Code(0)); + assert_eq!(parsed.header.checksum, 0x1234); + assert_eq!(parsed.identifier, 0xabcd); + assert_eq!(parsed.sequence_number, 0x5678); + assert_eq!(parsed.payload, Bytes::from_static(b"hello")); + } + + #[test] + fn test_echo_reply_create() { + let payload = Bytes::from_static(b"world"); + let packet = EchoReplyPacket { + header: Icmpv6Header { + icmpv6_type: Icmpv6Type::EchoReply, + icmpv6_code: Icmpv6Code(0), + checksum: 0, + }, + identifier: 0xabcd, + sequence_number: 0x1234, + payload: payload.clone(), + }; + + let bytes = packet.to_bytes(); + let parsed = EchoReplyPacket::from_bytes(bytes).expect("Failed to parse Echo Reply packet"); + + assert_eq!(parsed.header.icmpv6_type, Icmpv6Type::EchoReply); + assert_eq!(parsed.identifier, 0xabcd); + assert_eq!(parsed.sequence_number, 0x1234); + assert_eq!(parsed.payload, payload); } } diff --git a/nex-packet/src/ip.rs b/nex-packet/src/ip.rs index 9873536..6e0389d 100644 --- a/nex-packet/src/ip.rs +++ b/nex-packet/src/ip.rs @@ -1,13 +1,13 @@ -use crate::PrimitiveValues; - #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -/// IP Next Level Protocol +/// IP Next-Level Protocol +/// IPv4: RFC5237 +/// IPv6: RFC7045 #[repr(u8)] #[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub enum IpNextLevelProtocol { +pub enum IpNextProtocol { /// IPv6 Hop-by-Hop Option \[RFC2460\] Hopopt = 0, /// Internet Control Message \[RFC792\] @@ -302,8 +302,8 @@ pub enum IpNextLevelProtocol { Reserved = 255, } -impl IpNextLevelProtocol { - /// IpNextLevelProtocol from u8 +impl IpNextProtocol { + /// IpNextProtocol from u8 pub fn new(n: u8) -> Self { match n { 0 => Self::Hopopt, @@ -456,159 +456,156 @@ impl IpNextLevelProtocol { } pub fn as_str(&self) -> &'static str { match self { - IpNextLevelProtocol::Hopopt => "Hopopt", - IpNextLevelProtocol::Icmp => "Icmp", - IpNextLevelProtocol::Igmp => "Igmp", - IpNextLevelProtocol::Ggp => "Ggp", - IpNextLevelProtocol::Ipv4 => "Ipv4", - IpNextLevelProtocol::St => "St", - IpNextLevelProtocol::Tcp => "Tcp", - IpNextLevelProtocol::Cbt => "Cbt", - IpNextLevelProtocol::Egp => "Egp", - IpNextLevelProtocol::Igp => "Igp", - IpNextLevelProtocol::BbnRccMon => "BbnRccMon", - IpNextLevelProtocol::NvpII => "NvpII", - IpNextLevelProtocol::Pup => "Pup", - IpNextLevelProtocol::Argus => "Argus", - IpNextLevelProtocol::Emcon => "Emcon", - IpNextLevelProtocol::Xnet => "Xnet", - IpNextLevelProtocol::Chaos => "Chaos", - IpNextLevelProtocol::Udp => "Udp", - IpNextLevelProtocol::Mux => "Mux", - IpNextLevelProtocol::DcnMeas => "DcnMeas", - IpNextLevelProtocol::Hmp => "Hmp", - IpNextLevelProtocol::Prm => "Prm", - IpNextLevelProtocol::XnsIdp => "XnsIdp", - IpNextLevelProtocol::Trunk1 => "Trunk1", - IpNextLevelProtocol::Trunk2 => "Trunk2", - IpNextLevelProtocol::Leaf1 => "Leaf1", - IpNextLevelProtocol::Leaf2 => "Leaf2", - IpNextLevelProtocol::Rdp => "Rdp", - IpNextLevelProtocol::Irtp => "Irtp", - IpNextLevelProtocol::IsoTp4 => "IsoTp4", - IpNextLevelProtocol::Netblt => "Netblt", - IpNextLevelProtocol::MfeNsp => "MfeNsp", - IpNextLevelProtocol::MeritInp => "MeritInp", - IpNextLevelProtocol::Dccp => "Dccp", - IpNextLevelProtocol::ThreePc => "ThreePc", - IpNextLevelProtocol::Idpr => "Idpr", - IpNextLevelProtocol::Xtp => "Xtp", - IpNextLevelProtocol::Ddp => "Ddp", - IpNextLevelProtocol::IdprCmtp => "IdprCmtp", - IpNextLevelProtocol::TpPlusPlus => "TpPlusPlus", - IpNextLevelProtocol::Il => "Il", - IpNextLevelProtocol::Ipv6 => "Ipv6", - IpNextLevelProtocol::Sdrp => "Sdrp", - IpNextLevelProtocol::Ipv6Route => "Ipv6Route", - IpNextLevelProtocol::Ipv6Frag => "Ipv6Frag", - IpNextLevelProtocol::Idrp => "Idrp", - IpNextLevelProtocol::Rsvp => "Rsvp", - IpNextLevelProtocol::Gre => "Gre", - IpNextLevelProtocol::Dsr => "Dsr", - IpNextLevelProtocol::Bna => "Bna", - IpNextLevelProtocol::Esp => "Esp", - IpNextLevelProtocol::Ah => "Ah", - IpNextLevelProtocol::INlsp => "INlsp", - IpNextLevelProtocol::Swipe => "Swipe", - IpNextLevelProtocol::Narp => "Narp", - IpNextLevelProtocol::Mobile => "Mobile", - IpNextLevelProtocol::Tlsp => "Tlsp", - IpNextLevelProtocol::Skip => "Skip", - IpNextLevelProtocol::Icmpv6 => "Icmpv6", - IpNextLevelProtocol::Ipv6NoNxt => "Ipv6NoNxt", - IpNextLevelProtocol::Ipv6Opts => "Ipv6Opts", - IpNextLevelProtocol::HostInternal => "HostInternal", - IpNextLevelProtocol::Cftp => "Cftp", - IpNextLevelProtocol::LocalNetwork => "LocalNetwork", - IpNextLevelProtocol::SatExpak => "SatExpak", - IpNextLevelProtocol::Kryptolan => "Kryptolan", - IpNextLevelProtocol::Rvd => "Rvd", - IpNextLevelProtocol::Ippc => "Ippc", - IpNextLevelProtocol::DistributedFs => "DistributedFs", - IpNextLevelProtocol::SatMon => "SatMon", - IpNextLevelProtocol::Visa => "Visa", - IpNextLevelProtocol::Ipcv => "Ipcv", - IpNextLevelProtocol::Cpnx => "Cpnx", - IpNextLevelProtocol::Cphb => "Cphb", - IpNextLevelProtocol::Wsn => "Wsn", - IpNextLevelProtocol::Pvp => "Pvp", - IpNextLevelProtocol::BrSatMon => "BrSatMon", - IpNextLevelProtocol::SunNd => "SunNd", - IpNextLevelProtocol::WbMon => "WbMon", - IpNextLevelProtocol::WbExpak => "WbExpak", - IpNextLevelProtocol::IsoIp => "IsoIp", - IpNextLevelProtocol::Vmtp => "Vmtp", - IpNextLevelProtocol::SecureVmtp => "SecureVmtp", - IpNextLevelProtocol::Vines => "Vines", - IpNextLevelProtocol::TtpOrIptm => "TtpOrIptm", - IpNextLevelProtocol::NsfnetIgp => "NsfnetIgp", - IpNextLevelProtocol::Dgp => "Dgp", - IpNextLevelProtocol::Tcf => "Tcf", - IpNextLevelProtocol::Eigrp => "Eigrp", - IpNextLevelProtocol::OspfigP => "OspfigP", - IpNextLevelProtocol::SpriteRpc => "SpriteRpc", - IpNextLevelProtocol::Larp => "Larp", - IpNextLevelProtocol::Mtp => "Mtp", - IpNextLevelProtocol::Ax25 => "Ax25", - IpNextLevelProtocol::IpIp => "IpIp", - IpNextLevelProtocol::Micp => "Micp", - IpNextLevelProtocol::SccSp => "SccSp", - IpNextLevelProtocol::Etherip => "Etherip", - IpNextLevelProtocol::Encap => "Encap", - IpNextLevelProtocol::PrivEncryption => "PrivEncryption", - IpNextLevelProtocol::Gmtp => "Gmtp", - IpNextLevelProtocol::Ifmp => "Ifmp", - IpNextLevelProtocol::Pnni => "Pnni", - IpNextLevelProtocol::Pim => "Pim", - IpNextLevelProtocol::Aris => "Aris", - IpNextLevelProtocol::Scps => "Scps", - IpNextLevelProtocol::Qnx => "Qnx", - IpNextLevelProtocol::AN => "AN", - IpNextLevelProtocol::IpComp => "IpComp", - IpNextLevelProtocol::Snp => "Snp", - IpNextLevelProtocol::CompaqPeer => "CompaqPeer", - IpNextLevelProtocol::IpxInIp => "IpxInIp", - IpNextLevelProtocol::Vrrp => "Vrrp", - IpNextLevelProtocol::Pgm => "Pgm", - IpNextLevelProtocol::ZeroHop => "ZeroHop", - IpNextLevelProtocol::L2tp => "L2tp", - IpNextLevelProtocol::Ddx => "Ddx", - IpNextLevelProtocol::Iatp => "Iatp", - IpNextLevelProtocol::Stp => "Stp", - IpNextLevelProtocol::Srp => "Srp", - IpNextLevelProtocol::Uti => "Uti", - IpNextLevelProtocol::Smp => "Smp", - IpNextLevelProtocol::Sm => "Sm", - IpNextLevelProtocol::Ptp => "Ptp", - IpNextLevelProtocol::IsisOverIpv4 => "IsisOverIpv4", - IpNextLevelProtocol::Fire => "Fire", - IpNextLevelProtocol::Crtp => "Crtp", - IpNextLevelProtocol::Crudp => "Crudp", - IpNextLevelProtocol::Sscopmce => "Sscopmce", - IpNextLevelProtocol::Iplt => "Iplt", - IpNextLevelProtocol::Sps => "Sps", - IpNextLevelProtocol::Pipe => "Pipe", - IpNextLevelProtocol::Sctp => "Sctp", - IpNextLevelProtocol::Fc => "Fc", - IpNextLevelProtocol::RsvpE2eIgnore => "RsvpE2eIgnore", - IpNextLevelProtocol::MobilityHeader => "MobilityHeader", - IpNextLevelProtocol::UdpLite => "UdpLite", - IpNextLevelProtocol::MplsInIp => "MplsInIp", - IpNextLevelProtocol::Manet => "Manet", - IpNextLevelProtocol::Hip => "Hip", - IpNextLevelProtocol::Shim6 => "Shim6", - IpNextLevelProtocol::Wesp => "Wesp", - IpNextLevelProtocol::Rohc => "Rohc", - IpNextLevelProtocol::Test1 => "Test1", - IpNextLevelProtocol::Test2 => "Test2", - IpNextLevelProtocol::Reserved => "Reserved", + IpNextProtocol::Hopopt => "Hopopt", + IpNextProtocol::Icmp => "Icmp", + IpNextProtocol::Igmp => "Igmp", + IpNextProtocol::Ggp => "Ggp", + IpNextProtocol::Ipv4 => "Ipv4", + IpNextProtocol::St => "St", + IpNextProtocol::Tcp => "Tcp", + IpNextProtocol::Cbt => "Cbt", + IpNextProtocol::Egp => "Egp", + IpNextProtocol::Igp => "Igp", + IpNextProtocol::BbnRccMon => "BbnRccMon", + IpNextProtocol::NvpII => "NvpII", + IpNextProtocol::Pup => "Pup", + IpNextProtocol::Argus => "Argus", + IpNextProtocol::Emcon => "Emcon", + IpNextProtocol::Xnet => "Xnet", + IpNextProtocol::Chaos => "Chaos", + IpNextProtocol::Udp => "Udp", + IpNextProtocol::Mux => "Mux", + IpNextProtocol::DcnMeas => "DcnMeas", + IpNextProtocol::Hmp => "Hmp", + IpNextProtocol::Prm => "Prm", + IpNextProtocol::XnsIdp => "XnsIdp", + IpNextProtocol::Trunk1 => "Trunk1", + IpNextProtocol::Trunk2 => "Trunk2", + IpNextProtocol::Leaf1 => "Leaf1", + IpNextProtocol::Leaf2 => "Leaf2", + IpNextProtocol::Rdp => "Rdp", + IpNextProtocol::Irtp => "Irtp", + IpNextProtocol::IsoTp4 => "IsoTp4", + IpNextProtocol::Netblt => "Netblt", + IpNextProtocol::MfeNsp => "MfeNsp", + IpNextProtocol::MeritInp => "MeritInp", + IpNextProtocol::Dccp => "Dccp", + IpNextProtocol::ThreePc => "ThreePc", + IpNextProtocol::Idpr => "Idpr", + IpNextProtocol::Xtp => "Xtp", + IpNextProtocol::Ddp => "Ddp", + IpNextProtocol::IdprCmtp => "IdprCmtp", + IpNextProtocol::TpPlusPlus => "TpPlusPlus", + IpNextProtocol::Il => "Il", + IpNextProtocol::Ipv6 => "Ipv6", + IpNextProtocol::Sdrp => "Sdrp", + IpNextProtocol::Ipv6Route => "Ipv6Route", + IpNextProtocol::Ipv6Frag => "Ipv6Frag", + IpNextProtocol::Idrp => "Idrp", + IpNextProtocol::Rsvp => "Rsvp", + IpNextProtocol::Gre => "Gre", + IpNextProtocol::Dsr => "Dsr", + IpNextProtocol::Bna => "Bna", + IpNextProtocol::Esp => "Esp", + IpNextProtocol::Ah => "Ah", + IpNextProtocol::INlsp => "INlsp", + IpNextProtocol::Swipe => "Swipe", + IpNextProtocol::Narp => "Narp", + IpNextProtocol::Mobile => "Mobile", + IpNextProtocol::Tlsp => "Tlsp", + IpNextProtocol::Skip => "Skip", + IpNextProtocol::Icmpv6 => "Icmpv6", + IpNextProtocol::Ipv6NoNxt => "Ipv6NoNxt", + IpNextProtocol::Ipv6Opts => "Ipv6Opts", + IpNextProtocol::HostInternal => "HostInternal", + IpNextProtocol::Cftp => "Cftp", + IpNextProtocol::LocalNetwork => "LocalNetwork", + IpNextProtocol::SatExpak => "SatExpak", + IpNextProtocol::Kryptolan => "Kryptolan", + IpNextProtocol::Rvd => "Rvd", + IpNextProtocol::Ippc => "Ippc", + IpNextProtocol::DistributedFs => "DistributedFs", + IpNextProtocol::SatMon => "SatMon", + IpNextProtocol::Visa => "Visa", + IpNextProtocol::Ipcv => "Ipcv", + IpNextProtocol::Cpnx => "Cpnx", + IpNextProtocol::Cphb => "Cphb", + IpNextProtocol::Wsn => "Wsn", + IpNextProtocol::Pvp => "Pvp", + IpNextProtocol::BrSatMon => "BrSatMon", + IpNextProtocol::SunNd => "SunNd", + IpNextProtocol::WbMon => "WbMon", + IpNextProtocol::WbExpak => "WbExpak", + IpNextProtocol::IsoIp => "IsoIp", + IpNextProtocol::Vmtp => "Vmtp", + IpNextProtocol::SecureVmtp => "SecureVmtp", + IpNextProtocol::Vines => "Vines", + IpNextProtocol::TtpOrIptm => "TtpOrIptm", + IpNextProtocol::NsfnetIgp => "NsfnetIgp", + IpNextProtocol::Dgp => "Dgp", + IpNextProtocol::Tcf => "Tcf", + IpNextProtocol::Eigrp => "Eigrp", + IpNextProtocol::OspfigP => "OspfigP", + IpNextProtocol::SpriteRpc => "SpriteRpc", + IpNextProtocol::Larp => "Larp", + IpNextProtocol::Mtp => "Mtp", + IpNextProtocol::Ax25 => "Ax25", + IpNextProtocol::IpIp => "IpIp", + IpNextProtocol::Micp => "Micp", + IpNextProtocol::SccSp => "SccSp", + IpNextProtocol::Etherip => "Etherip", + IpNextProtocol::Encap => "Encap", + IpNextProtocol::PrivEncryption => "PrivEncryption", + IpNextProtocol::Gmtp => "Gmtp", + IpNextProtocol::Ifmp => "Ifmp", + IpNextProtocol::Pnni => "Pnni", + IpNextProtocol::Pim => "Pim", + IpNextProtocol::Aris => "Aris", + IpNextProtocol::Scps => "Scps", + IpNextProtocol::Qnx => "Qnx", + IpNextProtocol::AN => "AN", + IpNextProtocol::IpComp => "IpComp", + IpNextProtocol::Snp => "Snp", + IpNextProtocol::CompaqPeer => "CompaqPeer", + IpNextProtocol::IpxInIp => "IpxInIp", + IpNextProtocol::Vrrp => "Vrrp", + IpNextProtocol::Pgm => "Pgm", + IpNextProtocol::ZeroHop => "ZeroHop", + IpNextProtocol::L2tp => "L2tp", + IpNextProtocol::Ddx => "Ddx", + IpNextProtocol::Iatp => "Iatp", + IpNextProtocol::Stp => "Stp", + IpNextProtocol::Srp => "Srp", + IpNextProtocol::Uti => "Uti", + IpNextProtocol::Smp => "Smp", + IpNextProtocol::Sm => "Sm", + IpNextProtocol::Ptp => "Ptp", + IpNextProtocol::IsisOverIpv4 => "IsisOverIpv4", + IpNextProtocol::Fire => "Fire", + IpNextProtocol::Crtp => "Crtp", + IpNextProtocol::Crudp => "Crudp", + IpNextProtocol::Sscopmce => "Sscopmce", + IpNextProtocol::Iplt => "Iplt", + IpNextProtocol::Sps => "Sps", + IpNextProtocol::Pipe => "Pipe", + IpNextProtocol::Sctp => "Sctp", + IpNextProtocol::Fc => "Fc", + IpNextProtocol::RsvpE2eIgnore => "RsvpE2eIgnore", + IpNextProtocol::MobilityHeader => "MobilityHeader", + IpNextProtocol::UdpLite => "UdpLite", + IpNextProtocol::MplsInIp => "MplsInIp", + IpNextProtocol::Manet => "Manet", + IpNextProtocol::Hip => "Hip", + IpNextProtocol::Shim6 => "Shim6", + IpNextProtocol::Wesp => "Wesp", + IpNextProtocol::Rohc => "Rohc", + IpNextProtocol::Test1 => "Test1", + IpNextProtocol::Test2 => "Test2", + IpNextProtocol::Reserved => "Reserved", } } -} - -impl PrimitiveValues for IpNextLevelProtocol { - type T = (u8,); - fn to_primitive_values(&self) -> (u8,) { - (*self as u8,) + pub fn value(&self) -> u8 { + *self as u8 } } + diff --git a/nex-packet/src/ipv4.rs b/nex-packet/src/ipv4.rs index 2262fb9..9205174 100644 --- a/nex-packet/src/ipv4.rs +++ b/nex-packet/src/ipv4.rs @@ -1,122 +1,23 @@ //! An IPv4 packet abstraction. -use crate::ip::IpNextLevelProtocol; -use crate::PrimitiveValues; - -use alloc::vec::Vec; - -use nex_macro::packet; -use nex_macro_helper::types::*; - +use crate::{ip::IpNextProtocol, packet::Packet}; +use bytes::{BufMut, Bytes, BytesMut}; +use nex_core::bitfield::*; use std::net::Ipv4Addr; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; /// IPv4 Header Length -pub const IPV4_HEADER_LEN: usize = MutableIpv4Packet::minimum_packet_size(); +pub const IPV4_HEADER_LEN: usize = 20; /// IPv4 Header Byte Unit (32 bits) pub const IPV4_HEADER_LENGTH_BYTE_UNITS: usize = 4; -/// Represents the IPv4 option header. -#[derive(Clone, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct Ipv4OptionHeader { - pub copied: u1, - pub class: u2, - pub number: Ipv4OptionType, - pub length: Option, -} - -/// Represents the IPv4 header. -#[derive(Clone, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct Ipv4Header { - pub version: u4, - pub header_length: u4, - pub dscp: u6, - pub ecn: u2, - pub total_length: u16be, - pub identification: u16be, - pub flags: u3, - pub fragment_offset: u13be, - pub ttl: u8, - pub next_level_protocol: IpNextLevelProtocol, - pub checksum: u16be, - pub source: Ipv4Addr, - pub destination: Ipv4Addr, - pub options: Vec, -} - -impl Ipv4Header { - /// Construct an IPv4 header from a byte slice. - pub fn from_bytes(packet: &[u8]) -> Result { - if packet.len() < IPV4_HEADER_LEN { - return Err("Packet is too small for IPv4 header".to_string()); - } - match Ipv4Packet::new(packet) { - Some(ipv4_packet) => Ok(Ipv4Header { - version: ipv4_packet.get_version(), - header_length: ipv4_packet.get_header_length(), - dscp: ipv4_packet.get_dscp(), - ecn: ipv4_packet.get_ecn(), - total_length: ipv4_packet.get_total_length(), - identification: ipv4_packet.get_identification(), - flags: ipv4_packet.get_flags(), - fragment_offset: ipv4_packet.get_fragment_offset(), - ttl: ipv4_packet.get_ttl(), - next_level_protocol: ipv4_packet.get_next_level_protocol(), - checksum: ipv4_packet.get_checksum(), - source: ipv4_packet.get_source(), - destination: ipv4_packet.get_destination(), - options: ipv4_packet - .get_options_iter() - .map(|o| Ipv4OptionHeader { - copied: o.get_copied(), - class: o.get_class(), - number: o.get_number(), - length: o.get_length().first().cloned(), - }) - .collect(), - }), - None => Err("Failed to parse IPv4 packet".to_string()), - } - } - /// Construct an IPv4 header from a Ipv4Packet. - pub(crate) fn from_packet(ipv4_packet: &Ipv4Packet) -> Ipv4Header { - Ipv4Header { - version: ipv4_packet.get_version(), - header_length: ipv4_packet.get_header_length(), - dscp: ipv4_packet.get_dscp(), - ecn: ipv4_packet.get_ecn(), - total_length: ipv4_packet.get_total_length(), - identification: ipv4_packet.get_identification(), - flags: ipv4_packet.get_flags(), - fragment_offset: ipv4_packet.get_fragment_offset(), - ttl: ipv4_packet.get_ttl(), - next_level_protocol: ipv4_packet.get_next_level_protocol(), - checksum: ipv4_packet.get_checksum(), - source: ipv4_packet.get_source(), - destination: ipv4_packet.get_destination(), - options: ipv4_packet - .get_options_iter() - .map(|o| Ipv4OptionHeader { - copied: o.get_copied(), - class: o.get_class(), - number: o.get_number(), - length: o.get_length().first().cloned(), - }) - .collect(), - } - } -} - /// Represents the IPv4 header flags. #[allow(non_snake_case)] #[allow(non_upper_case_globals)] pub mod Ipv4Flags { - use nex_macro_helper::types::*; - + use nex_core::bitfield::*; /// Don't Fragment flag. pub const DontFragment: u3 = 0b010; /// More Fragments flag. @@ -221,47 +122,62 @@ impl Ipv4OptionType { _ => Ipv4OptionType::Unknown(n), } } -} - -impl PrimitiveValues for Ipv4OptionType { - type T = (u8,); - fn to_primitive_values(&self) -> (u8,) { + pub fn value(&self) -> u8 { match *self { - Ipv4OptionType::EOL => (0,), - Ipv4OptionType::NOP => (1,), - Ipv4OptionType::SEC => (2,), - Ipv4OptionType::LSR => (3,), - Ipv4OptionType::TS => (4,), - Ipv4OptionType::ESEC => (5,), - Ipv4OptionType::CIPSO => (6,), - Ipv4OptionType::RR => (7,), - Ipv4OptionType::SID => (8,), - Ipv4OptionType::SSR => (9,), - Ipv4OptionType::ZSU => (10,), - Ipv4OptionType::MTUP => (11,), - Ipv4OptionType::MTUR => (12,), - Ipv4OptionType::FINN => (13,), - Ipv4OptionType::VISA => (14,), - Ipv4OptionType::ENCODE => (15,), - Ipv4OptionType::IMITD => (16,), - Ipv4OptionType::EIP => (17,), - Ipv4OptionType::TR => (18,), - Ipv4OptionType::ADDEXT => (19,), - Ipv4OptionType::RTRALT => (20,), - Ipv4OptionType::SDB => (21,), - Ipv4OptionType::Unassigned => (22,), - Ipv4OptionType::DPS => (23,), - Ipv4OptionType::UMP => (24,), - Ipv4OptionType::QS => (25,), - Ipv4OptionType::EXP => (30,), - Ipv4OptionType::Unknown(n) => (n,), + Ipv4OptionType::EOL => 0, + Ipv4OptionType::NOP => 1, + Ipv4OptionType::SEC => 2, + Ipv4OptionType::LSR => 3, + Ipv4OptionType::TS => 4, + Ipv4OptionType::ESEC => 5, + Ipv4OptionType::CIPSO => 6, + Ipv4OptionType::RR => 7, + Ipv4OptionType::SID => 8, + Ipv4OptionType::SSR => 9, + Ipv4OptionType::ZSU => 10, + Ipv4OptionType::MTUP => 11, + Ipv4OptionType::MTUR => 12, + Ipv4OptionType::FINN => 13, + Ipv4OptionType::VISA => 14, + Ipv4OptionType::ENCODE => 15, + Ipv4OptionType::IMITD => 16, + Ipv4OptionType::EIP => 17, + Ipv4OptionType::TR => 18, + Ipv4OptionType::ADDEXT => 19, + Ipv4OptionType::RTRALT => 20, + Ipv4OptionType::SDB => 21, + Ipv4OptionType::Unassigned => 22, + Ipv4OptionType::DPS => 23, + Ipv4OptionType::UMP => 24, + Ipv4OptionType::QS => 25, + Ipv4OptionType::EXP => 30, + Ipv4OptionType::Unknown(n) => n, } } } -/// Represents an IPv4 Packet. -#[packet] -pub struct Ipv4 { +/// Represents the IPv4 option header. +#[derive(Clone, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct Ipv4OptionHeader { + pub copied: u1, + pub class: u2, + pub number: Ipv4OptionType, + pub length: Option, +} + +/// Represents the IPv4 Option field. +#[derive(Clone, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct Ipv4OptionPacket { + pub header: Ipv4OptionHeader, + pub data: Bytes, +} + +/// Represents the IPv4 header. +#[derive(Clone, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct Ipv4Header { pub version: u4, pub header_length: u4, pub dscp: u6, @@ -271,272 +187,390 @@ pub struct Ipv4 { pub flags: u3, pub fragment_offset: u13be, pub ttl: u8, - #[construct_with(u8)] - pub next_level_protocol: IpNextLevelProtocol, + pub next_level_protocol: IpNextProtocol, pub checksum: u16be, - #[construct_with(u8, u8, u8, u8)] pub source: Ipv4Addr, - #[construct_with(u8, u8, u8, u8)] pub destination: Ipv4Addr, - #[length_fn = "ipv4_options_length"] - pub options: Vec, - #[length_fn = "ipv4_payload_length"] - #[payload] - pub payload: Vec, + pub options: Vec, } -/// Calculates a checksum of an IPv4 packet header. -/// The checksum field of the packet is regarded as zeros during the calculation. -pub fn checksum(packet: &Ipv4Packet) -> u16be { - use crate::util; - use crate::Packet; - - let min = Ipv4Packet::minimum_packet_size(); - let max = packet.packet().len(); - let header_length = match packet.get_header_length() as usize * 4 { - length if length < min => min, - length if length > max => max, - length => length, - }; - let data = &packet.packet()[..header_length]; - util::checksum(data, 5) +/// Represents an IPv4 Packet. +#[derive(Clone, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct Ipv4Packet { + pub header: Ipv4Header, + pub payload: Bytes, } -#[cfg(test)] -mod checksum_tests { - use super::*; - use alloc::vec; +impl Packet for Ipv4Packet { + type Header = Ipv4Header; + + fn from_buf(bytes: &[u8]) -> Option { + if bytes.len() < IPV4_HEADER_LEN { + return None; + } - #[test] - fn checksum_zeros() { - let mut data = vec![0; 20]; - let expected = 64255; - let mut pkg = MutableIpv4Packet::new(&mut data[..]).unwrap(); - pkg.set_header_length(5); - assert_eq!(checksum(&pkg.to_immutable()), expected); - pkg.set_checksum(123); - assert_eq!(checksum(&pkg.to_immutable()), expected); - } + let version = (bytes[0] & 0xF0) >> 4; + let header_length = (bytes[0] & 0x0F) as usize; + let total_length = u16::from_be_bytes([bytes[2], bytes[3]]) as usize; - #[test] - fn checksum_nonzero() { - let mut data = vec![255; 20]; - let expected = 2560; - let mut pkg = MutableIpv4Packet::new(&mut data[..]).unwrap(); - pkg.set_header_length(5); - assert_eq!(checksum(&pkg.to_immutable()), expected); - pkg.set_checksum(123); - assert_eq!(checksum(&pkg.to_immutable()), expected); - } + if bytes.len() < total_length || header_length < 5 { + return None; + } - #[test] - fn checksum_too_small_header_length() { - let mut data = vec![148; 20]; - let expected = 51910; - let mut pkg = MutableIpv4Packet::new(&mut data[..]).unwrap(); - pkg.set_header_length(0); - assert_eq!(checksum(&pkg.to_immutable()), expected); - } + let ihl_bytes = header_length * 4; + if ihl_bytes < IPV4_HEADER_LEN || ihl_bytes > total_length { + return None; + } + let payload = Bytes::copy_from_slice(&bytes[ihl_bytes..total_length]); + + let mut options = Vec::new(); + let mut i = IPV4_HEADER_LEN; + + while i < ihl_bytes { + let b = bytes[i]; + let copied = (b >> 7) & 0x01; + let class = (b >> 5) & 0x03; + let number = Ipv4OptionType::new(b & 0b0001_1111); + + match number { + Ipv4OptionType::EOL => { + options.push(Ipv4OptionPacket { + header: Ipv4OptionHeader { + copied, + class, + number, + length: None, + }, + data: Bytes::new(), + }); + break; + } + Ipv4OptionType::NOP => { + options.push(Ipv4OptionPacket { + header: Ipv4OptionHeader { + copied, + class, + number, + length: None, + }, + data: Bytes::new(), + }); + i += 1; + } + _ => { + if i + 2 > ihl_bytes { + break; + } + let len = bytes[i + 1] as usize; + if len < 2 || i + len > ihl_bytes { + break; + } + + let data = Bytes::copy_from_slice(&bytes[i + 2..i + len]); + + options.push(Ipv4OptionPacket { + header: Ipv4OptionHeader { + copied, + class, + number, + length: Some(len as u8), + }, + data, + }); + + i += len; + } + } + } - #[test] - fn checksum_too_large_header_length() { - let mut data = vec![148; 20]; - let expected = 51142; - let mut pkg = MutableIpv4Packet::new(&mut data[..]).unwrap(); - pkg.set_header_length(99); - assert_eq!(checksum(&pkg.to_immutable()), expected); + Some(Self { + header: Ipv4Header { + version: version as u4, + header_length: header_length as u4, + dscp: (bytes[1] >> 2) as u6, + ecn: (bytes[1] & 0x03) as u2, + total_length: u16::from_be_bytes([bytes[2], bytes[3]]) as u16be, + identification: u16::from_be_bytes([bytes[4], bytes[5]]) as u16be, + flags: (bytes[6] >> 5) as u3, + fragment_offset: ((u16::from_be_bytes([bytes[6], bytes[7]])) & 0x1FFF) as u13be, + ttl: bytes[8], + next_level_protocol: IpNextProtocol::new(bytes[9]), + checksum: u16::from_be_bytes([bytes[10], bytes[11]]) as u16be, + source: Ipv4Addr::new(bytes[12], bytes[13], bytes[14], bytes[15]), + destination: Ipv4Addr::new(bytes[16], bytes[17], bytes[18], bytes[19]), + options, + }, + payload, + }) } -} - -fn ipv4_options_length(ipv4: &Ipv4Packet) -> usize { - // the header_length unit is the "word" - // - and a word is made of 4 bytes, - // - and the header length (without the options) is 5 words long - (ipv4.get_header_length() as usize * 4).saturating_sub(20) -} - -#[test] -fn ipv4_options_length_test() { - let mut packet = [0u8; 20]; - let mut ip_header = MutableIpv4Packet::new(&mut packet[..]).unwrap(); - ip_header.set_header_length(5); - assert_eq!(ipv4_options_length(&ip_header.to_immutable()), 0); -} - -fn ipv4_payload_length(ipv4: &Ipv4Packet) -> usize { - (ipv4.get_total_length() as usize).saturating_sub(ipv4.get_header_length() as usize * 4) -} - -#[test] -fn ipv4_payload_length_test() { - let mut packet = [0u8; 30]; - let mut ip_header = MutableIpv4Packet::new(&mut packet[..]).unwrap(); - ip_header.set_header_length(5); - ip_header.set_total_length(20); - assert_eq!(ipv4_payload_length(&ip_header.to_immutable()), 0); - // just comparing with 0 is prone to false positives in this case. - // for instance if one forgets to set total_length, one always gets 0 - ip_header.set_total_length(30); - assert_eq!(ipv4_payload_length(&ip_header.to_immutable()), 10); -} - -/// Represents the IPv4 Option field. -#[packet] -pub struct Ipv4Option { - copied: u1, - class: u2, - #[construct_with(u5)] - number: Ipv4OptionType, - #[length_fn = "ipv4_option_length"] - // The length field is an optional field, using a Vec is a way to implement - // it - length: Vec, - #[length_fn = "ipv4_option_payload_length"] - #[payload] - data: Vec, -} - -/// This function gets the 'length' of the length field of the IPv4Option packet -/// Few options (EOOL, NOP) are 1 bytes long, and then have a length field equal -/// to 0. -fn ipv4_option_length(option: &Ipv4OptionPacket) -> usize { - match option.get_number() { - Ipv4OptionType::EOL => 0, - Ipv4OptionType::NOP => 0, - _ => 1, + + fn from_bytes(bytes: Bytes) -> Option { + Self::from_buf(&bytes) } -} -fn ipv4_option_payload_length(ipv4_option: &Ipv4OptionPacket) -> usize { - match ipv4_option.get_length().first() { - Some(len) => (*len as usize).saturating_sub(2), - None => 0, - } -} + fn to_bytes(&self) -> Bytes { + // 1. Version/IHL + DSCP/ECN + let mut tmp_buf = BytesMut::with_capacity(60); // max header size + for option in &self.header.options { + let number = option.header.number.value(); + let type_byte = (option.header.copied << 7) + | (option.header.class << 5) + | (number & 0b0001_1111); + tmp_buf.put_u8(type_byte); + + match option.header.number { + Ipv4OptionType::EOL | Ipv4OptionType::NOP => {} + _ => { + let len = option.header.length.unwrap_or((option.data.len() + 2) as u8); + tmp_buf.put_u8(len); + tmp_buf.extend_from_slice(&option.data); + } + } + } -#[test] -fn ipv4_packet_test() { - use crate::ip::IpNextLevelProtocol; - use crate::Packet; - use crate::PacketSize; + // padding + while tmp_buf.len() % 4 != 0 { + tmp_buf.put_u8(0); + } - let mut packet = [0u8; 200]; - { - let mut ip_header = MutableIpv4Packet::new(&mut packet[..]).unwrap(); - ip_header.set_version(4); - assert_eq!(ip_header.get_version(), 4); + let header_len = IPV4_HEADER_LEN + tmp_buf.len(); + + let total_len_expected = header_len + self.payload.len(); + // Check if the total length exceeds the header's total_length field + if total_len_expected > self.header.total_length as usize { + panic!( + "Payload too long: header {} + payload {} = {} > total_length {}", + header_len, + self.payload.len(), + total_len_expected, + self.header.total_length + ); + } - ip_header.set_header_length(5); - assert_eq!(ip_header.get_header_length(), 5); + let header_len_words = (header_len / 4) as u8; - ip_header.set_dscp(4); - assert_eq!(ip_header.get_dscp(), 4); + let mut buf = BytesMut::with_capacity(self.total_len()); - ip_header.set_ecn(1); - assert_eq!(ip_header.get_ecn(), 1); + buf.put_u8((self.header.version << 4 | header_len_words) as u8); + buf.put_u8((self.header.dscp << 2 | self.header.ecn) as u8); - ip_header.set_total_length(115); - assert_eq!(ip_header.get_total_length(), 115); - assert_eq!(95, ip_header.payload().len()); - assert_eq!(ip_header.get_total_length(), ip_header.packet_size() as u16); + // 2. Fixed header fields + buf.put_u16(self.header.total_length); + buf.put_u16(self.header.identification); + buf.put_u16(((self.header.flags as u16) << 13) | self.header.fragment_offset); + buf.put_u8(self.header.ttl); + buf.put_u8(self.header.next_level_protocol.value()); + buf.put_u16(self.header.checksum); + buf.extend_from_slice(&self.header.source.octets()); + buf.extend_from_slice(&self.header.destination.octets()); - ip_header.set_identification(257); - assert_eq!(ip_header.get_identification(), 257); + // 3. options + buf.extend_from_slice(&tmp_buf); - ip_header.set_flags(Ipv4Flags::DontFragment as u3); - assert_eq!(ip_header.get_flags(), 2); + // 4. payload + buf.extend_from_slice(&self.payload); - ip_header.set_fragment_offset(257); - assert_eq!(ip_header.get_fragment_offset(), 257); + buf.freeze() + } - ip_header.set_ttl(64); - assert_eq!(ip_header.get_ttl(), 64); + fn header(&self) -> Bytes { + self.to_bytes().slice(..self.header_len()) + } - ip_header.set_next_level_protocol(IpNextLevelProtocol::Udp); - assert_eq!( - ip_header.get_next_level_protocol(), - IpNextLevelProtocol::Udp - ); + fn payload(&self) -> Bytes { + self.payload.clone() + } + + fn header_len(&self) -> usize { + self.header.header_length as usize * 4 + } - ip_header.set_source(Ipv4Addr::new(192, 168, 0, 1)); - assert_eq!(ip_header.get_source(), Ipv4Addr::new(192, 168, 0, 1)); + fn payload_len(&self) -> usize { + self.payload.len() + } - ip_header.set_destination(Ipv4Addr::new(192, 168, 0, 199)); - assert_eq!(ip_header.get_destination(), Ipv4Addr::new(192, 168, 0, 199)); + fn total_len(&self) -> usize { + self.header_len() + self.payload_len() + } - let imm_header = checksum(&ip_header.to_immutable()); - ip_header.set_checksum(imm_header); - assert_eq!(ip_header.get_checksum(), 0xb64e); + fn into_parts(self) -> (Self::Header, Bytes) { + (self.header, self.payload) } +} - let ref_packet = [ - 0x45, /* ver/ihl */ - 0x11, /* dscp/ecn */ - 0x00, 0x73, /* total len */ - 0x01, 0x01, /* identification */ - 0x41, 0x01, /* flags/frag offset */ - 0x40, /* ttl */ - 0x11, /* proto */ - 0xb6, 0x4e, /* checksum */ - 0xc0, 0xa8, 0x00, 0x01, /* source ip */ - 0xc0, 0xa8, 0x00, 0xc7, /* dest ip */ - ]; - - assert_eq!(&ref_packet[..], &packet[..ref_packet.len()]); +impl Ipv4Packet { + pub fn with_computed_checksum(mut self) -> Self { + self.header.checksum = checksum(&self); + self + } } -#[test] -fn ipv4_packet_option_test() { - use alloc::vec; +/// Calculates a checksum of an IPv4 packet header. +/// The checksum field of the packet is regarded as zeros during the calculation. +pub fn checksum(packet: &Ipv4Packet) -> u16be { + use crate::util; - let mut packet = [0u8; 3]; - { - let mut ipv4_options = MutableIpv4OptionPacket::new(&mut packet[..]).unwrap(); + let bytes = packet.to_bytes(); + let len = packet.header_len(); + util::checksum(&bytes[..len], 5) +} - ipv4_options.set_copied(1); - assert_eq!(ipv4_options.get_copied(), 1); +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_ipv4_packet_round_trip() { + let raw = Bytes::from_static(&[ + 0x45, 0x00, 0x00, 0x1c, // Version + IHL, DSCP + ECN, Total Length (28) + 0x1c, 0x46, 0x40, 0x00, // Identification, Flags + Fragment Offset + 0x40, 0x06, 0xb1, 0xe6, // TTL, Protocol (TCP), Header checksum + 0xc0, 0xa8, 0x00, 0x01, // Source: 192.168.0.1 + 0xc0, 0xa8, 0x00, 0xc7, // Destination: 192.168.0.199 + // Payload (8 bytes) + 0xde, 0xad, 0xbe, 0xef, + 0xca, 0xfe, 0xba, 0xbe, + ]); + + let packet = Ipv4Packet::from_bytes(raw.clone()).expect("Failed to parse Ipv4Packet"); + assert_eq!(packet.header.version, 4); + assert_eq!(packet.header.header_length, 5); + assert_eq!(packet.header.total_length, 28u16); + assert_eq!(packet.header.source, Ipv4Addr::new(192, 168, 0, 1)); + assert_eq!(packet.header.destination, Ipv4Addr::new(192, 168, 0, 199)); + assert_eq!(packet.payload, Bytes::from_static(&[0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0xba, 0xbe])); + + let serialized = packet.to_bytes(); + assert_eq!(&serialized[..], &raw[..]); + } - ipv4_options.set_class(0); - assert_eq!(ipv4_options.get_class(), 0); + #[test] + fn test_ipv4_packet_with_options_round_trip() { + let raw = Bytes::from_static(&[ + // IPv4 header (20bytes + 8bytes option + 4bytes payload = 32bytes -> IHL=7) + 0x47, 0x00, 0x00, 0x20, // [0-3] Version(4), IHL(7=28bytes), DSCP/ECN, Total Length=32 bytes + 0x12, 0x34, 0x40, 0x00, // [4-7] Identification, Flags=DF(0x40), Fragment Offset + 0x40, 0x11, 0x00, 0x00, // [8-11] TTL=64, Protocol=17(UDP), Header Checksum (0 for now) + 0xc0, 0xa8, 0x00, 0x01, // [12-15] Source IP = 192.168.0.1 + 0xc0, 0xa8, 0x00, 0x02, // [16-19] Destination IP = 192.168.0.2 + + // IPv4 options (8bytes) + // Option 1: 1byte NOP + 0x01, // [20] NOP (No Operation) + + // Option 2: 4bytes + 0x87, 0x04, 0x12, 0x34, // [21-24] Option Type=RR(7), Copied=1, Class=0, Length=4, Data=[0x12, 0x34] + + // Option 3: EOL (End of Options List) with padding + 0x00, // [25] EOL (End of Options List) + 0x00, // [26] Padding + 0x00, // [27] Padding + + // Payload 4bytes + 0xde, 0xad, 0xbe, 0xef, // [28-31] Payload: deadbeef + ]); + + let packet = Ipv4Packet::from_bytes(raw.clone()).expect("Failed to parse Ipv4Packet"); + + assert_eq!(packet.header.version, 4); + assert_eq!(packet.header.header_length, 7); + assert_eq!(packet.header.total_length, 32); + assert_eq!(packet.header.source, Ipv4Addr::new(192, 168, 0, 1)); + assert_eq!(packet.header.destination, Ipv4Addr::new(192, 168, 0, 2)); - ipv4_options.set_number(Ipv4OptionType::new(3)); - assert_eq!(ipv4_options.get_number(), Ipv4OptionType::LSR); + assert_eq!( + packet.payload, + Bytes::from_static(&[0xde, 0xad, 0xbe, 0xef]) + ); - ipv4_options.set_length(&vec![3]); - assert_eq!(ipv4_options.get_length(), vec![3]); + assert_eq!(packet.header.options.len(), 3); + assert_eq!(packet.header.options[0].header.number, Ipv4OptionType::NOP); + assert_eq!(packet.header.options[1].header.copied, 1); + assert_eq!(packet.header.options[1].header.class, 0); + assert_eq!(packet.header.options[1].header.number, Ipv4OptionType::RR); + assert_eq!(packet.header.options[1].header.number.value(), 7); + assert_eq!(packet.header.options[1].header.length, Some(4)); + assert_eq!(packet.header.options[1].data.as_ref(), &[0x12, 0x34]); + assert_eq!(packet.header.options[2].header.number, Ipv4OptionType::EOL); + + let serialized = packet.to_bytes(); + assert_eq!(&serialized[..], &raw[..]); + } - ipv4_options.set_data(&vec![16]); + #[test] + fn ipv4_option_packet_test() { + let option = Ipv4OptionPacket { + header: Ipv4OptionHeader { + copied: 1, + class: 0, + number: Ipv4OptionType::LSR, + length: Some(3), + }, + data: Bytes::from_static(&[0x10]), + }; + + let mut buf = BytesMut::new(); + let ty = (option.header.copied << 7) | (option.header.class << 5) | (option.header.number.value() & 0x1F); + buf.put_u8(ty); + buf.put_u8(3); + buf.put_slice(&[0x10]); + + assert_eq!(buf.freeze(), Bytes::from_static(&[0x83, 0x03, 0x10])); } - let ref_packet = [ - 0x83, /* copy / class / number */ - 0x03, /* length */ - 0x10, /* data */ - ]; + #[test] + #[should_panic(expected = "Payload too long")] + fn ipv4_payload_too_long_should_panic() { + let packet = Ipv4Packet { + header: Ipv4Header { + version: 4, + header_length: 5, + dscp: 0, + ecn: 0, + total_length: 24, // Header 20 + payload 4 = 24 but ... + identification: 0, + flags: 0, + fragment_offset: 0, + ttl: 64, + next_level_protocol: IpNextProtocol::Udp, + checksum: 0, + source: Ipv4Addr::LOCALHOST, + destination: Ipv4Addr::LOCALHOST, + options: vec![], + }, + payload: Bytes::from_static(&[0, 1, 2, 3, 4, 5]), // 6 bytes payload + }; + + // This should panic because the payload length exceeds the total_length specified in the header + let _ = packet.to_bytes(); + } - assert_eq!(&ref_packet[..], &packet[..]); + #[test] + fn test_ipv4_checksum() { + let raw = Bytes::from_static(&[ + 0x45, 0x00, 0x00, 0x14, + 0x00, 0x00, 0x40, 0x00, + 0x40, 0x06, 0x00, 0x00, // checksum: 0 + 0x0a, 0x00, 0x00, 0x01, + 0x0a, 0x00, 0x00, 0x02, + ]); + + let mut packet = Ipv4Packet::from_bytes(raw.clone()).expect("Failed to parse"); + let computed = checksum(&packet); + packet.header.checksum = computed; + + let serialized = packet.to_bytes(); + let reparsed = Ipv4Packet::from_bytes(serialized).expect("Reparse failed"); + + // Check if the checksum matches + assert_eq!(reparsed.header.checksum, computed); + + // Check if the serialized bytes match the original raw bytes + let mut raw_copy = raw.to_vec(); + raw_copy[10] = (computed >> 8) as u8; + raw_copy[11] = (computed & 0xff) as u8; + assert_eq!(&packet.to_bytes()[..], &raw_copy[..]); + } } -#[test] -fn ipv4_packet_set_payload_test() { - use crate::Packet; - - let mut packet = [0u8; 25]; // allow 20 byte header and 5 byte payload - let mut ip_packet = MutableIpv4Packet::new(&mut packet[..]).unwrap(); - ip_packet.set_total_length(25); - ip_packet.set_header_length(5); - let payload = b"stuff"; // 5 bytes - ip_packet.set_payload(&payload[..]); - assert_eq!(ip_packet.payload(), payload); -} -#[test] -#[should_panic(expected = "index 25 out of range for slice of length 24")] -fn ipv4_packet_set_payload_test_panic() { - let mut packet = [0u8; 24]; // allow 20 byte header and 4 byte payload - let mut ip_packet = MutableIpv4Packet::new(&mut packet[..]).unwrap(); - ip_packet.set_total_length(25); - ip_packet.set_header_length(5); - let payload = b"stuff"; // 5 bytes - ip_packet.set_payload(&payload[..]); // panic -} diff --git a/nex-packet/src/ipv6.rs b/nex-packet/src/ipv6.rs index 48c7c17..58fa680 100644 --- a/nex-packet/src/ipv6.rs +++ b/nex-packet/src/ipv6.rs @@ -1,349 +1,547 @@ -//! An IPv6 packet abstraction. - -use crate::ip::IpNextLevelProtocol; - -use alloc::vec::Vec; - -use nex_macro::packet; -use nex_macro_helper::types::*; - use std::net::Ipv6Addr; +use bytes::{Bytes, BytesMut, BufMut}; +use crate::packet::Packet; +use crate::ip::IpNextProtocol; -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -/// IPv6 Header Length. -pub const IPV6_HEADER_LEN: usize = MutableIpv6Packet::minimum_packet_size(); +pub const IPV6_HEADER_LEN: usize = 40; -/// Represents the IPv6 header. -#[derive(Clone, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct Ipv6Header { - pub version: u4, - pub traffic_class: u8, - pub flow_label: u20be, - pub payload_length: u16be, - pub next_header: IpNextLevelProtocol, - pub hop_limit: u8, - pub source: Ipv6Addr, - pub destination: Ipv6Addr, -} - -impl Ipv6Header { - /// Construct an IPv6 header from a byte slice. - pub fn from_bytes(packet: &[u8]) -> Result { - if packet.len() < IPV6_HEADER_LEN { - return Err("Packet is too small for IPv6 header".to_string()); - } - match Ipv6Packet::new(packet) { - Some(ipv6_packet) => Ok(Ipv6Header { - version: ipv6_packet.get_version(), - traffic_class: ipv6_packet.get_traffic_class(), - flow_label: ipv6_packet.get_flow_label(), - payload_length: ipv6_packet.get_payload_length(), - next_header: ipv6_packet.get_next_header(), - hop_limit: ipv6_packet.get_hop_limit(), - source: ipv6_packet.get_source(), - destination: ipv6_packet.get_destination(), - }), - None => Err("Failed to parse IPv6 packet".to_string()), - } - } - /// Construct an IPv6 header from a Ipv6Packet. - pub(crate) fn from_packet(ipv6_packet: &Ipv6Packet) -> Ipv6Header { - Ipv6Header { - version: ipv6_packet.get_version(), - traffic_class: ipv6_packet.get_traffic_class(), - flow_label: ipv6_packet.get_flow_label(), - payload_length: ipv6_packet.get_payload_length(), - next_header: ipv6_packet.get_next_header(), - hop_limit: ipv6_packet.get_hop_limit(), - source: ipv6_packet.get_source(), - destination: ipv6_packet.get_destination(), - } - } -} - -/// Represents an IPv6 Packet. -#[packet] -pub struct Ipv6 { - pub version: u4, - pub traffic_class: u8, - pub flow_label: u20be, - pub payload_length: u16be, - #[construct_with(u8)] - pub next_header: IpNextLevelProtocol, + pub version: u8, // 4 bits + pub traffic_class: u8, // 8 bits + pub flow_label: u32, // 20 bits + pub payload_length: u16, + pub next_header: IpNextProtocol, pub hop_limit: u8, - #[construct_with(u16, u16, u16, u16, u16, u16, u16, u16)] pub source: Ipv6Addr, - #[construct_with(u16, u16, u16, u16, u16, u16, u16, u16)] pub destination: Ipv6Addr, - #[length = "payload_length"] - #[payload] - pub payload: Vec, } -impl<'p> ExtensionIterable<'p> { - pub fn new(buf: &[u8]) -> ExtensionIterable { - ExtensionIterable { buf: buf } - } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Ipv6Packet { + pub header: Ipv6Header, + pub extensions: Vec, + pub payload: Bytes, } -/// Represents an IPv6 Extension. -#[packet] -pub struct Extension { - #[construct_with(u8)] - pub next_header: IpNextLevelProtocol, - pub hdr_ext_len: u8, - #[length_fn = "ipv6_extension_length"] - #[payload] - pub options: Vec, -} +impl Packet for Ipv6Packet { + type Header = Ipv6Header; + + fn from_buf(bytes: &[u8]) -> Option { + if bytes.len() < IPV6_HEADER_LEN { + return None; + } -fn ipv6_extension_length(ext: &ExtensionPacket) -> usize { - ext.get_hdr_ext_len() as usize * 8 + 8 - 2 -} + // --- Parse the header section --- + let version_traffic_flow = &bytes[..4]; + let version = version_traffic_flow[0] >> 4; + let traffic_class = ((version_traffic_flow[0] & 0x0F) << 4) | (version_traffic_flow[1] >> 4); + let flow_label = u32::from(version_traffic_flow[1] & 0x0F) << 16 + | u32::from(version_traffic_flow[2]) << 8 + | u32::from(version_traffic_flow[3]); + + let payload_length = u16::from_be_bytes([bytes[4], bytes[5]]); + let mut next_header = IpNextProtocol::new(bytes[6]); + let hop_limit = bytes[7]; + + let source = Ipv6Addr::from(<[u8; 16]>::try_from(&bytes[8..24]).ok()?); + let destination = Ipv6Addr::from(<[u8; 16]>::try_from(&bytes[24..40]).ok()?); + + let header = Ipv6Header { + version, + traffic_class, + flow_label, + payload_length, + next_header, + hop_limit, + source, + destination, + }; -/// Represents an IPv6 Hop-by-Hop Options. -pub type HopByHop = Extension; -/// A structure enabling manipulation of on the wire packets. -pub type HopByHopPacket<'p> = ExtensionPacket<'p>; -/// A structure enabling manipulation of on the wire packets. -pub type MutableHopByHopPacket<'p> = MutableExtensionPacket<'p>; - -/// Represents an IPv6 Routing Extension. -#[packet] -pub struct Routing { - #[construct_with(u8)] - pub next_header: IpNextLevelProtocol, - pub hdr_ext_len: u8, - pub routing_type: u8, - pub segments_left: u8, - #[length_fn = "routing_extension_length"] - #[payload] - pub data: Vec, -} + // --- Walk through the extension headers --- + let mut offset = IPV6_HEADER_LEN; + let mut extensions = Vec::new(); + + loop { + match next_header { + IpNextProtocol::Hopopt | IpNextProtocol::Ipv6Route | IpNextProtocol::Ipv6Frag + | IpNextProtocol::Ipv6Opts => { + if offset + 2 > bytes.len() { + return None; + } + + let nh = IpNextProtocol::new(bytes[offset]); + let ext_len = bytes[offset + 1] as usize; + + match next_header { + IpNextProtocol::Hopopt | IpNextProtocol::Ipv6Opts => { + let total_len = 8 + ext_len * 8; + if offset + total_len > bytes.len() { + return None; + } + + let data = Bytes::copy_from_slice(&bytes[offset + 2 .. offset + total_len]); + let ext = match next_header { + IpNextProtocol::Hopopt => Ipv6ExtensionHeader::HopByHop { next: nh, data }, + IpNextProtocol::Ipv6Opts => Ipv6ExtensionHeader::Destination { next: nh, data }, + _ => Ipv6ExtensionHeader::Raw { + next: nh, + raw: Bytes::copy_from_slice(&bytes[offset .. offset + total_len]), + }, + }; + + extensions.push(ext); + next_header = nh; + offset += total_len; + } + + IpNextProtocol::Ipv6Route => { + if offset + 4 > bytes.len() { + return None; + } + + let routing_type = bytes[offset + 2]; + let segments_left = bytes[offset + 3]; + let total_len = 8 + ext_len * 8; + if offset + total_len > bytes.len() { + return None; + } + + let data = Bytes::copy_from_slice(&bytes[offset + 4 .. offset + total_len]); + extensions.push(Ipv6ExtensionHeader::Routing { + next: nh, + routing_type, + segments_left, + data, + }); + + next_header = nh; + offset += total_len; + } + + IpNextProtocol::Ipv6Frag => { + if offset + 8 > bytes.len() { + return None; + } + + //let reserved = bytes[offset + 1]; + let frag_off_flags = u16::from_be_bytes([bytes[offset + 2], bytes[offset + 3]]); + let offset_val = frag_off_flags >> 3; + let more = (frag_off_flags & 0x1) != 0; + let id = u32::from_be_bytes([ + bytes[offset + 4], bytes[offset + 5], + bytes[offset + 6], bytes[offset + 7], + ]); + + extensions.push(Ipv6ExtensionHeader::Fragment { + next: nh, + offset: offset_val, + more, + id, + }); + + next_header = nh; + offset += 8; + } + + _ => break, + } + } + + _ => break, + } + } -fn routing_extension_length(ext: &RoutingPacket) -> usize { - ext.get_hdr_ext_len() as usize * 8 + 8 - 4 -} + let payload = Bytes::copy_from_slice(&bytes[offset..]); + Some(Ipv6Packet { + header, + extensions, + payload, + }) + } + fn from_bytes(bytes: Bytes) -> Option { + Self::from_buf(&bytes) + } -/// Represents an IPv6 Fragment Extension. -#[packet] -pub struct Fragment { - #[construct_with(u8)] - pub next_header: IpNextLevelProtocol, - pub reserved: u8, - pub fragment_offset_with_flags: u16be, - pub id: u32be, - #[length = "0"] - #[payload] - pub payload: Vec, -} + fn to_bytes(&self) -> Bytes { + let mut buf = BytesMut::with_capacity(self.total_len()); + + // --- 1. Basic header (first 40 bytes) --- + let vtf_1 = (self.header.version << 4) | (self.header.traffic_class >> 4); + let vtf_2 = ((self.header.traffic_class & 0x0F) << 4) | ((self.header.flow_label >> 16) as u8); + let vtf_3 = (self.header.flow_label >> 8) as u8; + let vtf_4 = self.header.flow_label as u8; + + buf.put_u8(vtf_1); + buf.put_u8(vtf_2); + buf.put_u8(vtf_3); + buf.put_u8(vtf_4); + buf.put_u16(self.header.payload_length); + // First next_header (first extension header if present) + let first_next_header = self.extensions.first() + .map(|ext| ext.next_protocol()) + .unwrap_or(self.header.next_header); + buf.put_u8(first_next_header.value()); + buf.put_u8(self.header.hop_limit); + buf.extend_from_slice(&self.header.source.octets()); + buf.extend_from_slice(&self.header.destination.octets()); + + // --- 2. Encode the extension header chain --- + for ext in &self.extensions { + match ext { + Ipv6ExtensionHeader::HopByHop { next, data } + | Ipv6ExtensionHeader::Destination { next, data } => { + let hdr_ext_len = ((data.len() + 6) / 8) as u8 - 1; + buf.put_u8(next.value()); + buf.put_u8(hdr_ext_len); + buf.extend_from_slice(data); + // Padding (8 byte alignment) + while (2 + data.len()) % 8 != 0 { + buf.put_u8(0); + } + } + + Ipv6ExtensionHeader::Routing { + next, + routing_type, + segments_left, + data, + } => { + let hdr_ext_len = ((data.len() + 4 + 6) / 8) as u8 - 1; + buf.put_u8(next.value()); + buf.put_u8(hdr_ext_len); + buf.put_u8(*routing_type); + buf.put_u8(*segments_left); + buf.extend_from_slice(data); + while (4 + data.len()) % 8 != 0 { + buf.put_u8(0); + } + } + + Ipv6ExtensionHeader::Fragment { next, offset, more, id } => { + buf.put_u8(next.value()); + buf.put_u8(0); // reserved + let offset_flags = (offset << 3) | if *more { 1 } else { 0 }; + buf.put_u16(offset_flags); + buf.put_u32(*id); + } + + Ipv6ExtensionHeader::Raw { next: _, raw } => { + // Note: assume the raw header already includes the next field + buf.extend_from_slice(&raw[..]); + } + } + } -const FRAGMENT_FLAGS_MASK: u16 = 0x03; -const FRAGMENT_FLAGS_MORE_FRAGMENTS: u16 = 0x01; -const FRAGMENT_OFFSET_MASK: u16 = !FRAGMENT_FLAGS_MASK; + // --- 3. Payload --- + buf.extend_from_slice(&self.payload); -impl<'p> FragmentPacket<'p> { - pub fn get_fragment_offset(&self) -> u16 { - self.get_fragment_offset_with_flags() & FRAGMENT_OFFSET_MASK + buf.freeze() } - pub fn is_last_fragment(&self) -> bool { - (self.get_fragment_offset_with_flags() & FRAGMENT_FLAGS_MORE_FRAGMENTS) == 0 + fn header(&self) -> Bytes { + self.to_bytes().slice(..IPV6_HEADER_LEN) } -} -impl<'p> MutableFragmentPacket<'p> { - pub fn get_fragment_offset(&self) -> u16 { - self.get_fragment_offset_with_flags() & FRAGMENT_OFFSET_MASK + fn payload(&self) -> Bytes { + self.payload.clone() } - pub fn is_last_fragment(&self) -> bool { - (self.get_fragment_offset_with_flags() & FRAGMENT_FLAGS_MORE_FRAGMENTS) == 0 + fn header_len(&self) -> usize { + IPV6_HEADER_LEN } - pub fn set_fragment_offset(&mut self, offset: u16) { - let fragment_offset_with_flags = self.get_fragment_offset_with_flags(); - - self.set_fragment_offset_with_flags( - (offset & FRAGMENT_OFFSET_MASK) | (fragment_offset_with_flags & FRAGMENT_FLAGS_MASK), - ); + fn payload_len(&self) -> usize { + self.payload.len() } - pub fn set_last_fragment(&mut self, is_last: bool) { - let fragment_offset_with_flags = self.get_fragment_offset_with_flags(); + fn total_len(&self) -> usize { + self.header_len() + self.payload_len() + } - self.set_fragment_offset_with_flags(if is_last { - fragment_offset_with_flags & !FRAGMENT_FLAGS_MORE_FRAGMENTS - } else { - fragment_offset_with_flags | FRAGMENT_FLAGS_MORE_FRAGMENTS - }); + fn into_parts(self) -> (Self::Header, Bytes) { + (self.header, self.payload) } } -/// Represents an Destination Options. -pub type Destination = Extension; -/// A structure enabling manipulation of on the wire packets. -pub type DestinationPacket<'p> = ExtensionPacket<'p>; -/// A structure enabling manipulation of on the wire packets. -pub type MutableDestinationPacket<'p> = MutableExtensionPacket<'p>; - -#[test] -fn ipv6_header_test() { - use crate::ip::IpNextLevelProtocol; - use crate::{MutablePacket, Packet, PacketSize}; - use alloc::vec; - - let mut packet = [0u8; 0x200]; - { - let mut ip_header = MutableIpv6Packet::new(&mut packet[..]).unwrap(); - ip_header.set_version(6); - assert_eq!(ip_header.get_version(), 6); - - ip_header.set_traffic_class(17); - assert_eq!(ip_header.get_traffic_class(), 17); - - ip_header.set_flow_label(0x10101); - assert_eq!(ip_header.get_flow_label(), 0x10101); - - ip_header.set_payload_length(0x0101); - assert_eq!(ip_header.get_payload_length(), 0x0101); - assert_eq!(0x0101, ip_header.payload().len()); - - ip_header.set_next_header(IpNextLevelProtocol::Hopopt); - assert_eq!(ip_header.get_next_header(), IpNextLevelProtocol::Hopopt); - - ip_header.set_hop_limit(1); - assert_eq!(ip_header.get_hop_limit(), 1); - - let source = Ipv6Addr::new(0x110, 0x1001, 0x110, 0x1001, 0x110, 0x1001, 0x110, 0x1001); - ip_header.set_source(source); - assert_eq!(ip_header.get_source(), source); - - let dest = Ipv6Addr::new(0x110, 0x1001, 0x110, 0x1001, 0x110, 0x1001, 0x110, 0x1001); - ip_header.set_destination(dest); - assert_eq!(ip_header.get_destination(), dest); - - let mut pos = { - let mut hopopt = MutableHopByHopPacket::new(ip_header.payload_mut()).unwrap(); +impl Ipv6Packet { + pub fn total_len(&self) -> usize { + IPV6_HEADER_LEN + + self.extensions.iter().map(|ext| ext.len()).sum::() + + self.payload.len() + } + pub fn get_extension(&self, kind: ExtensionHeaderType) -> Option<&Ipv6ExtensionHeader> { + self.extensions.iter().find(|ext| ext.kind() == kind) + } +} - hopopt.set_next_header(IpNextLevelProtocol::Ipv6Opts); - assert_eq!(hopopt.get_next_header(), IpNextLevelProtocol::Ipv6Opts); +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ExtensionHeaderType { + HopByHop, + Destination, + Routing, + Fragment, + Unknown(u8), +} - hopopt.set_hdr_ext_len(1); - assert_eq!(hopopt.get_hdr_ext_len(), 1); +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Ipv6ExtensionHeader { + HopByHop { next: IpNextProtocol, data: Bytes }, + Destination { next: IpNextProtocol, data: Bytes }, + Routing { next: IpNextProtocol, routing_type: u8, segments_left: u8, data: Bytes }, + Fragment { next: IpNextProtocol, offset: u16, more: bool, id: u32 }, + Raw { next: IpNextProtocol, raw: Bytes }, +} - hopopt.set_options(&[b'A'; 14][..]); - assert_eq!(hopopt.payload(), b"AAAAAAAAAAAAAA"); +impl Ipv6ExtensionHeader { + pub fn next_protocol(&self) -> IpNextProtocol { + match self { + Ipv6ExtensionHeader::HopByHop { next, .. } => *next, + Ipv6ExtensionHeader::Destination { next, .. } => *next, + Ipv6ExtensionHeader::Routing { next, .. } => *next, + Ipv6ExtensionHeader::Fragment { next, .. } => *next, + Ipv6ExtensionHeader::Raw { next, .. } => *next, + } + } + pub fn len(&self) -> usize { + match self { + Ipv6ExtensionHeader::HopByHop { data, .. } + | Ipv6ExtensionHeader::Destination { data, .. } => { + let base = 2 + data.len(); + (base + 7) / 8 * 8 // padding to multiple of 8 + } + Ipv6ExtensionHeader::Routing { data, .. } => { + let base = 4 + data.len(); + (base + 7) / 8 * 8 + } + Ipv6ExtensionHeader::Fragment { .. } => 8, + Ipv6ExtensionHeader::Raw { raw, .. } => raw.len(), + } + } + pub fn kind(&self) -> ExtensionHeaderType { + match self { + Ipv6ExtensionHeader::HopByHop { .. } => ExtensionHeaderType::HopByHop, + Ipv6ExtensionHeader::Destination { .. } => ExtensionHeaderType::Destination, + Ipv6ExtensionHeader::Routing { .. } => ExtensionHeaderType::Routing, + Ipv6ExtensionHeader::Fragment { .. } => ExtensionHeaderType::Fragment, + Ipv6ExtensionHeader::Raw { raw, .. } => { + // Even for Raw we can read the first byte to guess the kind + let kind = raw.get(0).copied().unwrap_or(0xff); + match kind { + 0 => ExtensionHeaderType::HopByHop, + 43 => ExtensionHeaderType::Routing, + 44 => ExtensionHeaderType::Fragment, + 60 => ExtensionHeaderType::Destination, + other => ExtensionHeaderType::Unknown(other), + } + } + } + } +} - hopopt.packet_size() +#[cfg(test)] +mod tests { + use super::*; + use crate::ip::IpNextProtocol; + use std::net::Ipv6Addr; + + #[test] + fn test_ipv6_basic_header_fields() { + let header = Ipv6Header { + version: 6, + traffic_class: 0xaa, + flow_label: 0x12345, + payload_length: 0, + next_header: IpNextProtocol::Udp, + hop_limit: 64, + source: Ipv6Addr::LOCALHOST, + destination: Ipv6Addr::UNSPECIFIED, }; - pos += { - let mut dstopt = - MutableDestinationPacket::new(&mut ip_header.payload_mut()[pos..]).unwrap(); - - dstopt.set_next_header(IpNextLevelProtocol::Ipv6Route); - assert_eq!(dstopt.get_next_header(), IpNextLevelProtocol::Ipv6Route); + let packet = Ipv6Packet { + header: header.clone(), + extensions: vec![], + payload: Bytes::new(), + }; - dstopt.set_hdr_ext_len(1); - assert_eq!(dstopt.get_hdr_ext_len(), 1); + assert_eq!(packet.header.version, 6); + assert_eq!(packet.header.traffic_class, 0xaa); + assert_eq!(packet.header.flow_label, 0x12345); + assert_eq!(packet.header.payload_length, 0); + assert_eq!(packet.header.next_header, IpNextProtocol::Udp); + assert_eq!(packet.header.hop_limit, 64); + assert_eq!(packet.header.source, Ipv6Addr::LOCALHOST); + assert_eq!(packet.header.destination, Ipv6Addr::UNSPECIFIED); + + let raw = packet.to_bytes(); + assert_eq!(raw.len(), IPV6_HEADER_LEN); + let reparsed = Ipv6Packet::from_bytes(raw.clone()).unwrap(); + assert_eq!(reparsed.header, packet.header); + } - dstopt.set_options(&[b'B'; 14][..]); - assert_eq!(dstopt.payload(), b"BBBBBBBBBBBBBB"); + #[test] + fn test_ipv6_from_bytes_parsing() { + use bytes::Bytes; + + let raw_bytes = Bytes::from_static(&[ + // Version(6), Traffic Class(0xa), Flow Label(0x12345) + 0x60, 0xA1, 0x23, 0x45, + // Payload Length: 8 bytes + 0x00, 0x08, + // Next Header: TCP (6) + 0x06, + // Hop Limit + 0x40, + // Source IP + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x02, 0x1a, 0x2b, 0xff, 0xfe, 0x1a, 0x2b, 0x3c, + // Destination IP + 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, + // Payload (dummy 8 bytes) + b'H', b'e', b'l', b'l', b'o', b'!', b'!', b'\n', + ]); + + let parsed = Ipv6Packet::from_bytes(raw_bytes.clone()).expect("should parse successfully"); + + assert_eq!(parsed.header.version, 6); + assert_eq!(parsed.header.traffic_class, 0xa); + assert_eq!(parsed.header.flow_label, 0x12345); + assert_eq!(parsed.header.payload_length, 8); + assert_eq!(parsed.header.next_header, IpNextProtocol::Tcp); + assert_eq!(parsed.header.hop_limit, 0x40); + assert_eq!( + parsed.header.source, + "fe80::21a:2bff:fe1a:2b3c".parse::().unwrap() + ); + assert_eq!( + parsed.header.destination, + "ff02::2".parse::().unwrap() + ); + assert_eq!(&parsed.payload[..], b"Hello!!\n"); + assert_eq!(parsed.extensions.len(), 0); + assert_eq!(parsed.to_bytes(), raw_bytes); + } - dstopt.packet_size() + #[test] + fn test_ipv6_payload_roundtrip() { + use bytes::Bytes; + + let payload = Bytes::from_static(b"HELLO_WORLDS"); + let packet = Ipv6Packet { + header: super::Ipv6Header { + version: 6, + traffic_class: 0, + flow_label: 0, + payload_length: payload.len() as u16, + next_header: IpNextProtocol::Tcp, + hop_limit: 32, + source: Ipv6Addr::LOCALHOST, + destination: Ipv6Addr::LOCALHOST, + }, + extensions: vec![], + payload: payload.clone(), }; - pos += { - let mut routing = - MutableRoutingPacket::new(&mut ip_header.payload_mut()[pos..]).unwrap(); + let raw = packet.to_bytes(); + let parsed = Ipv6Packet::from_bytes(raw.clone()).unwrap(); - routing.set_next_header(IpNextLevelProtocol::Ipv6Frag); - assert_eq!(routing.get_next_header(), IpNextLevelProtocol::Ipv6Frag); + assert_eq!(parsed.header, packet.header); + assert_eq!(parsed.payload, payload); + assert_eq!(raw.len(), packet.total_len()); + } - routing.set_hdr_ext_len(1); - assert_eq!(routing.get_hdr_ext_len(), 1); + #[test] + fn test_ipv6_truncated_packet_rejected() { + use bytes::Bytes; - routing.set_routing_type(4); - assert_eq!(routing.get_routing_type(), 4); + let short = Bytes::from_static(&[0u8; 20]); // insufficient + assert!(Ipv6Packet::from_bytes(short).is_none()); + } - routing.set_segments_left(2); - assert_eq!(routing.get_segments_left(), 2); + #[test] + fn test_ipv6_total_len_computation() { + use bytes::Bytes; - routing.set_data(&[b'C'; 12][..]); - assert_eq!(routing.payload(), b"CCCCCCCCCCCC"); + let ext = Ipv6ExtensionHeader::Fragment { + next: IpNextProtocol::Tcp, + offset: 1, + more: true, + id: 42, + }; - routing.packet_size() + let packet = Ipv6Packet { + header: Ipv6Header { + version: 6, + traffic_class: 0, + flow_label: 0, + payload_length: 8, + next_header: IpNextProtocol::Tcp, + hop_limit: 1, + source: Ipv6Addr::LOCALHOST, + destination: Ipv6Addr::LOCALHOST, + }, + extensions: vec![ext], + payload: Bytes::from_static(b"ABCDEFGH"), }; - pos += { - let mut frag = MutableFragmentPacket::new(&mut ip_header.payload_mut()[pos..]).unwrap(); + let expected_len = IPV6_HEADER_LEN + 8 + 8; // header + fragment ext + payload + assert_eq!(packet.total_len(), expected_len); + assert_eq!(packet.to_bytes().len(), expected_len); + } - frag.set_next_header(IpNextLevelProtocol::Udp); - assert_eq!(frag.get_next_header(), IpNextLevelProtocol::Udp); + #[test] + fn test_extension_kind_known_variants() { + let hop = Ipv6ExtensionHeader::HopByHop { + next: IpNextProtocol::Tcp, + data: Bytes::from_static(&[1, 2, 3, 4]), + }; + assert_eq!(hop.kind(), ExtensionHeaderType::HopByHop); - frag.set_fragment_offset(1024); - assert_eq!(frag.get_fragment_offset(), 1024); + let dst = Ipv6ExtensionHeader::Destination { + next: IpNextProtocol::Udp, + data: Bytes::from_static(&[9, 8, 7]), + }; + assert_eq!(dst.kind(), ExtensionHeaderType::Destination); - frag.set_last_fragment(false); - assert!(!frag.is_last_fragment()); + let route = Ipv6ExtensionHeader::Routing { + next: IpNextProtocol::Tcp, + routing_type: 0, + segments_left: 0, + data: Bytes::from_static(&[1, 2, 3]), + }; + assert_eq!(route.kind(), ExtensionHeaderType::Routing); - frag.set_id(1234); - assert_eq!(frag.get_id(), 1234); + let frag = Ipv6ExtensionHeader::Fragment { + next: IpNextProtocol::Udp, + offset: 0, + more: false, + id: 12345, + }; + assert_eq!(frag.kind(), ExtensionHeaderType::Fragment); + } - frag.packet_size() + #[test] + fn test_extension_kind_raw_known() { + let raw_routing = Ipv6ExtensionHeader::Raw { + next: IpNextProtocol::new(43), + raw: Bytes::from_static(&[43, 1, 2, 3]), }; + assert_eq!(raw_routing.kind(), ExtensionHeaderType::Routing); - assert_eq!( - ExtensionIterable::new(&ip_header.payload()[..pos]) - .map(|ext| ( - ext.get_next_header(), - ext.get_hdr_ext_len(), - ext.packet_size() - )) - .collect::>(), - vec![ - (IpNextLevelProtocol::Ipv6Opts, 1, 16), - (IpNextLevelProtocol::Ipv6Route, 1, 16), - (IpNextLevelProtocol::Ipv6Frag, 1, 16), - (IpNextLevelProtocol::Udp, 0, 8), - ] - ); + let raw_frag = Ipv6ExtensionHeader::Raw { + next: IpNextProtocol::new(44), + raw: Bytes::from_static(&[44, 0, 0, 0]), + }; + assert_eq!(raw_frag.kind(), ExtensionHeaderType::Fragment); } - let ref_packet = [ - 0x61, /* ver/traffic class */ - 0x11, /* traffic class/flow label */ - 0x01, 0x01, /* flow label */ - 0x01, 0x01, /* payload length */ - 0x00, /* next header */ - 0x01, /* hop limit */ - /* source ip */ - 0x01, 0x10, 0x10, 0x01, 0x01, 0x10, 0x10, 0x01, 0x01, 0x10, 0x10, 0x01, 0x01, 0x10, 0x10, - 0x01, /* dest ip */ - 0x01, 0x10, 0x10, 0x01, 0x01, 0x10, 0x10, 0x01, 0x01, 0x10, 0x10, 0x01, 0x01, 0x10, 0x10, - 0x01, /* Hop-by-Hop Options */ - 0x3c, // Next Header - 0x01, // Hdr Ext Len - b'A', b'A', b'A', b'A', b'A', b'A', b'A', b'A', b'A', b'A', b'A', b'A', b'A', b'A', - /* Destination Options */ - 0x2b, // Next Header - 0x01, // Hdr Ext Len - b'B', b'B', b'B', b'B', b'B', b'B', b'B', b'B', b'B', b'B', b'B', b'B', b'B', b'B', - /* Routing */ - 0x2c, // Next Header - 0x01, // Hdr Ext Len - 0x04, // Routing Type - 0x02, // Segments Left - b'C', b'C', b'C', b'C', b'C', b'C', b'C', b'C', b'C', b'C', b'C', b'C', - /* Fragment */ - 0x11, // Next Header - 0x00, // Reserved - 0x04, 0x01, // Fragment Offset - 0x00, 0x00, 0x04, 0xd2, // Identification - ]; - assert_eq!(&ref_packet[..], &packet[..ref_packet.len()]); + #[test] + fn test_extension_kind_raw_unknown() { + let raw_unknown = Ipv6ExtensionHeader::Raw { + next: IpNextProtocol::new(250), + raw: Bytes::from_static(&[250, 0, 1, 2]), + }; + assert_eq!(raw_unknown.kind(), ExtensionHeaderType::Unknown(250)); + } } diff --git a/nex-packet/src/lib.rs b/nex-packet/src/lib.rs index 5412cc3..a9ae9b9 100644 --- a/nex-packet/src/lib.rs +++ b/nex-packet/src/lib.rs @@ -1,35 +1,19 @@ -//! Support for packet parsing and manipulation. Enables users to work with packets at a granular level. - -#![allow(missing_docs)] -#![deny(warnings)] -#![macro_use] - -extern crate alloc; - -#[cfg(test)] -extern crate std; - -extern crate nex_core; -extern crate nex_macro; -extern crate nex_macro_helper; - -pub use nex_macro_helper::packet::*; - -pub mod arp; -pub mod dhcp; -pub mod dns; +pub mod packet; pub mod ethernet; -pub mod frame; -pub mod gre; -pub mod icmp; -pub mod icmpv6; +pub mod arp; pub mod ip; pub mod ipv4; pub mod ipv6; -pub mod sll; -pub mod sll2; +pub mod util; +pub mod icmp; +pub mod icmpv6; pub mod tcp; pub mod udp; -pub mod usbpcap; -pub mod util; pub mod vlan; +pub mod dhcp; +pub mod dns; +pub mod gre; +pub mod vxlan; +pub mod flowcontrol; +pub mod frame; +pub mod builder; diff --git a/nex-packet/src/packet.rs b/nex-packet/src/packet.rs new file mode 100644 index 0000000..93e1059 --- /dev/null +++ b/nex-packet/src/packet.rs @@ -0,0 +1,50 @@ +use bytes::{Bytes, BytesMut}; + +/// Represents a generic network packet. +pub trait Packet: Sized { + type Header; + + /// Parse from a byte slice. + fn from_buf(buf: &[u8]) -> Option; + + /// Parse from raw bytes. (with ownership) + fn from_bytes(bytes: Bytes) -> Option; + + /// Serialize into raw bytes. + fn to_bytes(&self) -> Bytes; + + /// Get the header of the packet. + fn header(&self) -> Bytes; + + /// Get the payload of the packet. + fn payload(&self) -> Bytes; + + /// Get the length of the header. + fn header_len(&self) -> usize; + + /// Get the length of the payload. + fn payload_len(&self) -> usize; + /// Get the total length of the packet (header + payload). + fn total_len(&self) -> usize; + /// Convert the packet to a mutable byte buffer. + fn to_bytes_mut(&self) -> BytesMut { + let mut buf = BytesMut::with_capacity(self.total_len()); + buf.extend_from_slice(&self.to_bytes()); + buf + } + /// Get a mutable byte buffer for the header. + fn header_mut(&self) -> BytesMut { + let mut buf = BytesMut::with_capacity(self.header_len()); + buf.extend_from_slice(&self.header()); + buf + } + /// Get a mutable byte buffer for the payload. + fn payload_mut(&self) -> BytesMut { + let mut buf = BytesMut::with_capacity(self.payload_len()); + buf.extend_from_slice(&self.payload()); + buf + } + + fn into_parts(self) -> (Self::Header, Bytes); + +} diff --git a/nex-packet/src/sll.rs b/nex-packet/src/sll.rs deleted file mode 100644 index 77cf389..0000000 --- a/nex-packet/src/sll.rs +++ /dev/null @@ -1,25 +0,0 @@ -//! A Linux cooked-mode capture (LINKTYPE_LINUX_SLL) packet abstraction. - -use alloc::vec::Vec; - -use super::ethernet::EtherType; -use nex_macro::packet; -use nex_macro_helper::types::*; - -/// Represents an SLL packet (LINKTYPE_LINUX_SLL). -#[packet] -pub struct SLL { - #[construct_with(u16)] - pub packet_type: u16be, - #[construct_with(u16)] - pub link_layer_address_type: u16be, - #[construct_with(u16)] - pub link_layer_address_len: u16be, - #[construct_with(u8, u8, u8, u8, u8, u8, u8, u8)] - #[length = "8"] - pub link_layer_address: Vec, - #[construct_with(u16)] - pub protocol: EtherType, - #[payload] - pub payload: Vec, -} diff --git a/nex-packet/src/sll2.rs b/nex-packet/src/sll2.rs deleted file mode 100644 index c9aa5f6..0000000 --- a/nex-packet/src/sll2.rs +++ /dev/null @@ -1,37 +0,0 @@ -//! A Linux cooked-mode capture v2 (LINKTYPE_LINUX_SLL2) packet abstraction. - -use alloc::vec::Vec; - -use nex_macro::packet; -use nex_macro_helper::types::*; - -use super::ethernet::EtherType; - -/// Represents an SLL2 packet (LINKTYPE_LINUX_SLL2). -#[packet] -pub struct SLL2 { - #[construct_with(u16)] - pub protocol_type: EtherType, - - #[construct_with(u16)] - pub reserved: u16be, - - #[construct_with(u32)] - pub interface_index: u32be, - - #[construct_with(u16)] - pub arphrd_type: u16be, - - #[construct_with(u8)] - pub packet_type: u8, - - #[construct_with(u8)] - pub link_layer_address_length: u8, - - #[construct_with(u8, u8, u8, u8, u8, u8, u8, u8)] - #[length = "8"] - pub link_layer_address: Vec, - - #[payload] - pub payload: Vec, -} diff --git a/nex-packet/src/tcp.rs b/nex-packet/src/tcp.rs index d6e16a7..5ba8592 100644 --- a/nex-packet/src/tcp.rs +++ b/nex-packet/src/tcp.rs @@ -1,23 +1,19 @@ //! A TCP packet abstraction. -use crate::ip::IpNextLevelProtocol; -use crate::Packet; -use crate::PrimitiveValues; - -use alloc::{vec, vec::Vec}; - -use nex_macro::packet; -use nex_macro_helper::types::*; +use crate::ip::IpNextProtocol; +use crate::packet::Packet; use crate::util::{self, Octets}; -use std::net::Ipv4Addr; +use std::net::{IpAddr, Ipv4Addr}; use std::net::Ipv6Addr; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use nex_core::bitfield::{u16be, u32be, u4}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; /// Minimum TCP Header Length -pub const TCP_HEADER_LEN: usize = MutableTcpPacket::minimum_packet_size(); +pub const TCP_HEADER_LEN: usize = 20; /// Minimum TCP Data Offset pub const TCP_MIN_DATA_OFFSET: u8 = 5; /// Maximum TCP Option Length @@ -25,170 +21,6 @@ pub const TCP_OPTION_MAX_LEN: usize = 40; /// Maximum TCP Header Length (with options) pub const TCP_HEADER_MAX_LEN: usize = TCP_HEADER_LEN + TCP_OPTION_MAX_LEN; -/// Represents the TCP option header. -#[derive(Clone, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct TcpOptionHeader { - pub kind: TcpOptionKind, - pub length: Option, - pub data: Vec, -} - -impl TcpOptionHeader { - /// Get the timestamp of the TCP option - pub fn get_timestamp(&self) -> (u32, u32) { - if self.kind == TcpOptionKind::TIMESTAMPS && self.data.len() >= 8 { - let mut my: [u8; 4] = [0; 4]; - my.copy_from_slice(&self.data[0..4]); - let mut their: [u8; 4] = [0; 4]; - their.copy_from_slice(&self.data[4..8]); - (u32::from_be_bytes(my), u32::from_be_bytes(their)) - } else { - return (0, 0); - } - } - /// Get the MSS of the TCP option - pub fn get_mss(&self) -> u16 { - if self.kind == TcpOptionKind::MSS && self.data.len() >= 2 { - let mut mss: [u8; 2] = [0; 2]; - mss.copy_from_slice(&self.data[0..2]); - u16::from_be_bytes(mss) - } else { - 0 - } - } - /// Get the WSCALE of the TCP option - pub fn get_wscale(&self) -> u8 { - if self.kind == TcpOptionKind::WSCALE && self.data.len() > 0 { - self.data[0] - } else { - 0 - } - } -} - -/// Represents the TCP header. -#[derive(Clone, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct TcpHeader { - pub source: u16be, - pub destination: u16be, - pub sequence: u32be, - pub acknowledgement: u32be, - pub data_offset: u4, - pub reserved: u4, - pub flags: u8, - pub window: u16be, - pub checksum: u16be, - pub urgent_ptr: u16be, - pub options: Vec, -} - -impl TcpHeader { - /// Construct a TCP header from a byte slice. - pub fn from_bytes(packet: &[u8]) -> Result { - if packet.len() < TCP_HEADER_LEN { - return Err("Packet is too small for TCP header".to_string()); - } - match TcpPacket::new(packet) { - Some(tcp_packet) => Ok(TcpHeader { - source: tcp_packet.get_source(), - destination: tcp_packet.get_destination(), - sequence: tcp_packet.get_sequence(), - acknowledgement: tcp_packet.get_acknowledgement(), - data_offset: tcp_packet.get_data_offset(), - reserved: tcp_packet.get_reserved(), - flags: tcp_packet.get_flags(), - window: tcp_packet.get_window(), - checksum: tcp_packet.get_checksum(), - urgent_ptr: tcp_packet.get_urgent_ptr(), - options: tcp_packet - .get_options_iter() - .map(|opt| TcpOptionHeader { - kind: opt.get_kind(), - length: opt.get_length_raw().first().cloned(), - data: opt.payload().to_vec(), - }) - .collect(), - }), - None => Err("Failed to parse TCP packet".to_string()), - } - } - /// Construct a TCP header from a TcpPacket. - pub(crate) fn from_packet(tcp_packet: &TcpPacket) -> TcpHeader { - TcpHeader { - source: tcp_packet.get_source(), - destination: tcp_packet.get_destination(), - sequence: tcp_packet.get_sequence(), - acknowledgement: tcp_packet.get_acknowledgement(), - data_offset: tcp_packet.get_data_offset(), - reserved: tcp_packet.get_reserved(), - flags: tcp_packet.get_flags(), - window: tcp_packet.get_window(), - checksum: tcp_packet.get_checksum(), - urgent_ptr: tcp_packet.get_urgent_ptr(), - options: tcp_packet - .get_options_iter() - .map(|opt| TcpOptionHeader { - kind: opt.get_kind(), - length: opt.get_length_raw().first().cloned(), - data: opt.payload().to_vec(), - }) - .collect(), - } - } -} - -/// Represents the TCP Flags -/// -#[allow(non_snake_case)] -#[allow(non_upper_case_globals)] -pub mod TcpFlags { - /// CWR – Congestion Window Reduced (CWR) flag is set by the sending - /// host to indicate that it received a TCP segment with the ECE flag set - /// and had responded in congestion control mechanism. - pub const CWR: u8 = 0b10000000; - /// ECE – ECN-Echo has a dual role, depending on the value of the - /// SYN flag. It indicates: - /// If the SYN flag is set (1), that the TCP peer is ECN capable. - /// If the SYN flag is clear (0), that a packet with Congestion Experienced - /// flag set (ECN=11) in IP header received during normal transmission. - pub const ECE: u8 = 0b01000000; - /// URG – indicates that the Urgent pointer field is significant. - pub const URG: u8 = 0b00100000; - /// ACK – indicates that the Acknowledgment field is significant. - /// All packets after the initial SYN packet sent by the client should have this flag set. - pub const ACK: u8 = 0b00010000; - /// PSH – Push function. Asks to push the buffered data to the receiving application. - pub const PSH: u8 = 0b00001000; - /// RST – Reset the connection. - pub const RST: u8 = 0b00000100; - /// SYN – Synchronize sequence numbers. Only the first packet sent from each end - /// should have this flag set. - pub const SYN: u8 = 0b00000010; - /// FIN – No more data from sender. - pub const FIN: u8 = 0b00000001; -} - -/// Represents a TCP packet. -#[packet] -pub struct Tcp { - pub source: u16be, - pub destination: u16be, - pub sequence: u32be, - pub acknowledgement: u32be, - pub data_offset: u4, - pub reserved: u4, - pub flags: u8, - pub window: u16be, - pub checksum: u16be, - pub urgent_ptr: u16be, - #[length_fn = "tcp_options_length"] - pub options: Vec, - #[payload] - pub payload: Vec, -} - /// Represents a TCP Option Kind. /// #[allow(non_camel_case_types)] @@ -286,50 +118,97 @@ impl TcpOptionKind { _ => TcpOptionKind::RESERVED(n), } } + /// Get the name of the TCP option kind. - pub fn name(&self) -> String { + pub fn name(&self) -> &'static str { match *self { - TcpOptionKind::EOL => String::from("EOL"), - TcpOptionKind::NOP => String::from("NOP"), - TcpOptionKind::MSS => String::from("MSS"), - TcpOptionKind::WSCALE => String::from("WSCALE"), - TcpOptionKind::SACK_PERMITTED => String::from("SACK_PERMITTED"), - TcpOptionKind::SACK => String::from("SACK"), - TcpOptionKind::ECHO => String::from("ECHO"), - TcpOptionKind::ECHO_REPLY => String::from("ECHO_REPLY"), - TcpOptionKind::TIMESTAMPS => String::from("TIMESTAMPS"), - TcpOptionKind::POCP => String::from("POCP"), - TcpOptionKind::POSP => String::from("POSP"), - TcpOptionKind::CC => String::from("CC"), - TcpOptionKind::CC_NEW => String::from("CC_NEW"), - TcpOptionKind::CC_ECHO => String::from("CC_ECHO"), - TcpOptionKind::ALT_CHECKSUM_REQ => String::from("ALT_CHECKSUM_REQ"), - TcpOptionKind::ALT_CHECKSUM_DATA => String::from("ALT_CHECKSUM_DATA"), - TcpOptionKind::SKEETER => String::from("SKEETER"), - TcpOptionKind::BUBBA => String::from("BUBBA"), - TcpOptionKind::TRAILER_CHECKSUM => String::from("TRAILER_CHECKSUM"), - TcpOptionKind::MD5_SIGNATURE => String::from("MD5_SIGNATURE"), - TcpOptionKind::SCPS_CAPABILITIES => String::from("SCPS_CAPABILITIES"), - TcpOptionKind::SELECTIVE_ACK => String::from("SELECTIVE_ACK"), - TcpOptionKind::RECORD_BOUNDARIES => String::from("RECORD_BOUNDARIES"), - TcpOptionKind::CORRUPTION_EXPERIENCED => String::from("CORRUPTION_EXPERIENCED"), - TcpOptionKind::SNAP => String::from("SNAP"), - TcpOptionKind::UNASSIGNED => String::from("UNASSIGNED"), - TcpOptionKind::TCP_COMPRESSION_FILTER => String::from("TCP_COMPRESSION_FILTER"), - TcpOptionKind::QUICK_START => String::from("QUICK_START"), - TcpOptionKind::USER_TIMEOUT => String::from("USER_TIMEOUT"), - TcpOptionKind::TCP_AO => String::from("TCP_AO"), - TcpOptionKind::MPTCP => String::from("MPTCP"), - TcpOptionKind::RESERVED_31 => String::from("RESERVED_31"), - TcpOptionKind::RESERVED_32 => String::from("RESERVED_32"), - TcpOptionKind::RESERVED_33 => String::from("RESERVED_33"), - TcpOptionKind::FAST_OPEN_COOKIE => String::from("FAST_OPEN_COOKIE"), - TcpOptionKind::TCP_ENO => String::from("TCP_ENO"), - TcpOptionKind::ACC_ECNO_0 => String::from("ACC_ECNO_0"), - TcpOptionKind::ACC_ECNO_1 => String::from("ACC_ECNO_1"), - TcpOptionKind::EXPERIMENT_1 => String::from("EXPERIMENT_1"), - TcpOptionKind::EXPERIMENT_2 => String::from("EXPERIMENT_2"), - TcpOptionKind::RESERVED(n) => format!("RESERVED_{}", n), + TcpOptionKind::EOL => "EOL", + TcpOptionKind::NOP => "NOP", + TcpOptionKind::MSS => "MSS", + TcpOptionKind::WSCALE => "WSCALE", + TcpOptionKind::SACK_PERMITTED => "SACK_PERMITTED", + TcpOptionKind::SACK => "SACK", + TcpOptionKind::ECHO => "ECHO", + TcpOptionKind::ECHO_REPLY => "ECHO_REPLY", + TcpOptionKind::TIMESTAMPS => "TIMESTAMPS", + TcpOptionKind::POCP => "POCP", + TcpOptionKind::POSP => "POSP", + TcpOptionKind::CC => "CC", + TcpOptionKind::CC_NEW => "CC_NEW", + TcpOptionKind::CC_ECHO => "CC_ECHO", + TcpOptionKind::ALT_CHECKSUM_REQ => "ALT_CHECKSUM_REQ", + TcpOptionKind::ALT_CHECKSUM_DATA => "ALT_CHECKSUM_DATA", + TcpOptionKind::SKEETER => "SKEETER", + TcpOptionKind::BUBBA => "BUBBA", + TcpOptionKind::TRAILER_CHECKSUM => "TRAILER_CHECKSUM", + TcpOptionKind::MD5_SIGNATURE => "MD5_SIGNATURE", + TcpOptionKind::SCPS_CAPABILITIES => "SCPS_CAPABILITIES", + TcpOptionKind::SELECTIVE_ACK => "SELECTIVE_ACK", + TcpOptionKind::RECORD_BOUNDARIES => "RECORD_BOUNDARIES", + TcpOptionKind::CORRUPTION_EXPERIENCED => "CORRUPTION_EXPERIENCED", + TcpOptionKind::SNAP => "SNAP", + TcpOptionKind::UNASSIGNED => "UNASSIGNED", + TcpOptionKind::TCP_COMPRESSION_FILTER => "TCP_COMPRESSION_FILTER", + TcpOptionKind::QUICK_START => "QUICK_START", + TcpOptionKind::USER_TIMEOUT => "USER_TIMEOUT", + TcpOptionKind::TCP_AO => "TCP_AO", + TcpOptionKind::MPTCP => "MPTCP", + TcpOptionKind::RESERVED_31 => "RESERVED_31", + TcpOptionKind::RESERVED_32 => "RESERVED_32", + TcpOptionKind::RESERVED_33 => "RESERVED_33", + TcpOptionKind::FAST_OPEN_COOKIE => "FAST_OPEN_COOKIE", + TcpOptionKind::TCP_ENO => "TCP_ENO", + TcpOptionKind::ACC_ECNO_0 => "ACC_ECNO_0", + TcpOptionKind::ACC_ECNO_1 => "ACC_ECNO_1", + TcpOptionKind::EXPERIMENT_1 => "EXPERIMENT_1", + TcpOptionKind::EXPERIMENT_2 => "EXPERIMENT_2", + TcpOptionKind::RESERVED(_) => "RESERVED", + } + } + /// Get the value of the TCP option kind. + pub fn value(&self) -> u8 { + match *self { + TcpOptionKind::EOL => 0, + TcpOptionKind::NOP => 1, + TcpOptionKind::MSS => 2, + TcpOptionKind::WSCALE => 3, + TcpOptionKind::SACK_PERMITTED => 4, + TcpOptionKind::SACK => 5, + TcpOptionKind::ECHO => 6, + TcpOptionKind::ECHO_REPLY => 7, + TcpOptionKind::TIMESTAMPS => 8, + TcpOptionKind::POCP => 9, + TcpOptionKind::POSP => 10, + TcpOptionKind::CC => 11, + TcpOptionKind::CC_NEW => 12, + TcpOptionKind::CC_ECHO => 13, + TcpOptionKind::ALT_CHECKSUM_REQ => 14, + TcpOptionKind::ALT_CHECKSUM_DATA => 15, + TcpOptionKind::SKEETER => 16, + TcpOptionKind::BUBBA => 17, + TcpOptionKind::TRAILER_CHECKSUM => 18, + TcpOptionKind::MD5_SIGNATURE => 19, + TcpOptionKind::SCPS_CAPABILITIES => 20, + TcpOptionKind::SELECTIVE_ACK => 21, + TcpOptionKind::RECORD_BOUNDARIES => 22, + TcpOptionKind::CORRUPTION_EXPERIENCED => 23, + TcpOptionKind::SNAP => 24, + TcpOptionKind::UNASSIGNED => 25, + TcpOptionKind::TCP_COMPRESSION_FILTER => 26, + TcpOptionKind::QUICK_START => 27, + TcpOptionKind::USER_TIMEOUT => 28, + TcpOptionKind::TCP_AO => 29, + TcpOptionKind::MPTCP => 30, + TcpOptionKind::RESERVED_31 => 31, + TcpOptionKind::RESERVED_32 => 32, + TcpOptionKind::RESERVED_33 => 33, + TcpOptionKind::FAST_OPEN_COOKIE => 34, + TcpOptionKind::TCP_ENO => 69, + TcpOptionKind::ACC_ECNO_0 => 172, + TcpOptionKind::ACC_ECNO_1 => 174, + TcpOptionKind::EXPERIMENT_1 => 253, + TcpOptionKind::EXPERIMENT_2 => 254, + TcpOptionKind::RESERVED(n) => n, } } /// Get size (bytes) of the TCP option. @@ -357,76 +236,94 @@ impl TcpOptionKind { } } -impl PrimitiveValues for TcpOptionKind { - type T = (u8,); - fn to_primitive_values(&self) -> (u8,) { - match *self { - TcpOptionKind::EOL => (0,), - TcpOptionKind::NOP => (1,), - TcpOptionKind::MSS => (2,), - TcpOptionKind::WSCALE => (3,), - TcpOptionKind::SACK_PERMITTED => (4,), - TcpOptionKind::SACK => (5,), - TcpOptionKind::ECHO => (6,), - TcpOptionKind::ECHO_REPLY => (7,), - TcpOptionKind::TIMESTAMPS => (8,), - TcpOptionKind::POCP => (9,), - TcpOptionKind::POSP => (10,), - TcpOptionKind::CC => (11,), - TcpOptionKind::CC_NEW => (12,), - TcpOptionKind::CC_ECHO => (13,), - TcpOptionKind::ALT_CHECKSUM_REQ => (14,), - TcpOptionKind::ALT_CHECKSUM_DATA => (15,), - TcpOptionKind::SKEETER => (16,), - TcpOptionKind::BUBBA => (17,), - TcpOptionKind::TRAILER_CHECKSUM => (18,), - TcpOptionKind::MD5_SIGNATURE => (19,), - TcpOptionKind::SCPS_CAPABILITIES => (20,), - TcpOptionKind::SELECTIVE_ACK => (21,), - TcpOptionKind::RECORD_BOUNDARIES => (22,), - TcpOptionKind::CORRUPTION_EXPERIENCED => (23,), - TcpOptionKind::SNAP => (24,), - TcpOptionKind::UNASSIGNED => (25,), - TcpOptionKind::TCP_COMPRESSION_FILTER => (26,), - TcpOptionKind::QUICK_START => (27,), - TcpOptionKind::USER_TIMEOUT => (28,), - TcpOptionKind::TCP_AO => (29,), - TcpOptionKind::MPTCP => (30,), - TcpOptionKind::RESERVED_31 => (31,), - TcpOptionKind::RESERVED_32 => (32,), - TcpOptionKind::RESERVED_33 => (33,), - TcpOptionKind::FAST_OPEN_COOKIE => (34,), - TcpOptionKind::TCP_ENO => (35,), - TcpOptionKind::ACC_ECNO_0 => (36,), - TcpOptionKind::ACC_ECNO_1 => (37,), - TcpOptionKind::EXPERIMENT_1 => (253,), - TcpOptionKind::EXPERIMENT_2 => (254,), - TcpOptionKind::RESERVED(n) => (n,), +/// Represents the TCP Flags +/// +#[allow(non_snake_case)] +#[allow(non_upper_case_globals)] +pub mod TcpFlags { + /// CWR – Congestion Window Reduced (CWR) flag is set by the sending + /// host to indicate that it received a TCP segment with the ECE flag set + /// and had responded in congestion control mechanism. + pub const CWR: u8 = 0b10000000; + /// ECE – ECN-Echo has a dual role, depending on the value of the + /// SYN flag. It indicates: + /// If the SYN flag is set (1), that the TCP peer is ECN capable. + /// If the SYN flag is clear (0), that a packet with Congestion Experienced + /// flag set (ECN=11) in IP header received during normal transmission. + pub const ECE: u8 = 0b01000000; + /// URG – indicates that the Urgent pointer field is significant. + pub const URG: u8 = 0b00100000; + /// ACK – indicates that the Acknowledgment field is significant. + /// All packets after the initial SYN packet sent by the client should have this flag set. + pub const ACK: u8 = 0b00010000; + /// PSH – Push function. Asks to push the buffered data to the receiving application. + pub const PSH: u8 = 0b00001000; + /// RST – Reset the connection. + pub const RST: u8 = 0b00000100; + /// SYN – Synchronize sequence numbers. Only the first packet sent from each end + /// should have this flag set. + pub const SYN: u8 = 0b00000010; + /// FIN – No more data from sender. + pub const FIN: u8 = 0b00000001; +} + +/// Represents the TCP option header. +#[derive(Clone, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct TcpOptionHeader { + pub kind: TcpOptionKind, + pub length: Option, + pub data: Bytes, +} + +impl TcpOptionHeader { + /// Get the timestamp of the TCP option + pub fn get_timestamp(&self) -> (u32, u32) { + if self.kind == TcpOptionKind::TIMESTAMPS && self.data.len() >= 8 { + let mut my: [u8; 4] = [0; 4]; + my.copy_from_slice(&self.data[0..4]); + let mut their: [u8; 4] = [0; 4]; + their.copy_from_slice(&self.data[4..8]); + (u32::from_be_bytes(my), u32::from_be_bytes(their)) + } else { + return (0, 0); + } + } + /// Get the MSS of the TCP option + pub fn get_mss(&self) -> u16 { + if self.kind == TcpOptionKind::MSS && self.data.len() >= 2 { + let mut mss: [u8; 2] = [0; 2]; + mss.copy_from_slice(&self.data[0..2]); + u16::from_be_bytes(mss) + } else { + 0 + } + } + /// Get the WSCALE of the TCP option + pub fn get_wscale(&self) -> u8 { + if self.kind == TcpOptionKind::WSCALE && self.data.len() > 0 { + self.data[0] + } else { + 0 } } } /// A TCP option. -#[packet] -pub struct TcpOption { - #[construct_with(u8)] +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct TcpOptionPacket { kind: TcpOptionKind, - #[length_fn = "tcp_option_length"] - // The length field is an optional field, using a Vec is a way to implement - // it - length: Vec, - #[length_fn = "tcp_option_payload_length"] - #[payload] - data: Vec, + length: Option, + data: Bytes, } -impl TcpOption { +impl TcpOptionPacket { /// NOP: This may be used to align option fields on 32-bit boundaries for better performance. pub fn nop() -> Self { - TcpOption { + TcpOptionPacket { kind: TcpOptionKind::NOP, - length: vec![], - data: vec![], + length: None, + data: Bytes::new(), } } @@ -434,37 +331,37 @@ impl TcpOption { /// packets were sent. TCP timestamps are not normally aligned to the system clock and /// start at some random value. pub fn timestamp(my: u32, their: u32) -> Self { - let mut data = vec![]; + let mut data = BytesMut::new(); data.extend_from_slice(&my.octets()[..]); data.extend_from_slice(&their.octets()[..]); - TcpOption { + TcpOptionPacket { kind: TcpOptionKind::TIMESTAMPS, - length: vec![10], - data: data, + length: Some(10), + data: data.freeze(), } } /// MSS: The maximum segment size (MSS) is the largest amount of data, specified in bytes, /// that TCP is willing to receive in a single segment. pub fn mss(val: u16) -> Self { - let mut data = vec![]; + let mut data = BytesMut::new(); data.extend_from_slice(&val.octets()[..]); - TcpOption { + TcpOptionPacket { kind: TcpOptionKind::MSS, - length: vec![4], - data: data, + length: Some(4), + data: data.freeze(), } } /// Window scale: The TCP window scale option, as defined in RFC 1323, is an option used to /// increase the maximum window size from 65,535 bytes to 1 gigabyte. pub fn wscale(val: u8) -> Self { - TcpOption { + TcpOptionPacket { kind: TcpOptionKind::WSCALE, - length: vec![3], - data: vec![val], + length: Some(3), + data: Bytes::from(vec![val]), } } @@ -472,10 +369,10 @@ impl TcpOption { /// discontinuous blocks of packets which were received correctly. This options enables use of /// SACK during negotiation. pub fn sack_perm() -> Self { - TcpOption { + TcpOptionPacket { kind: TcpOptionKind::SACK_PERMITTED, - length: vec![2], - data: vec![], + length: Some(2), + data: Bytes::new(), } } @@ -484,14 +381,14 @@ impl TcpOption { /// a number of SACK blocks, where each SACK block is conveyed by the starting and ending sequence /// numbers of a contiguous range that the receiver correctly received. pub fn selective_ack(acks: &[u32]) -> Self { - let mut data = vec![]; + let mut data = BytesMut::new(); for ack in acks { data.extend_from_slice(&ack.octets()[..]); } - TcpOption { + TcpOptionPacket { kind: TcpOptionKind::SACK, - length: vec![1 /* number */ + 1 /* length */ + data.len() as u8], - data: data, + length: Some(1 /* number */ + 1 /* length */ + data.len() as u8), + data: data.freeze(), } } /// Get the TCP option kind. @@ -500,10 +397,11 @@ impl TcpOption { } /// Get length of the TCP option. pub fn length(&self) -> u8 { - if self.length.is_empty() { - 0 + if let Some(len) = self.length { + len } else { - self.length[0] + // If length is None, it means the option has no length (like NOP). + 0 } } /// Get the timestamp of the TCP option @@ -538,33 +436,209 @@ impl TcpOption { } } -/// This function gets the 'length' of the length field of the IPv4Option packet -/// Few options (EOL, NOP) are 1 bytes long, and then have a length field equal -/// to 0. -#[inline] -fn tcp_option_length(option: &TcpOptionPacket) -> usize { - match option.get_kind() { - TcpOptionKind::EOL => 0, - TcpOptionKind::NOP => 0, - _ => 1, - } +/// Represents the TCP header. +#[derive(Clone, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct TcpHeader { + pub source: u16be, + pub destination: u16be, + pub sequence: u32be, + pub acknowledgement: u32be, + pub data_offset: u4, + pub reserved: u4, + pub flags: u8, + pub window: u16be, + pub checksum: u16be, + pub urgent_ptr: u16be, + pub options: Vec, } -fn tcp_option_payload_length(ipv4_option: &TcpOptionPacket) -> usize { - match ipv4_option.get_length_raw().first() { - Some(len) if *len >= 2 => *len as usize - 2, - _ => 0, +/// Represents a TCP packet. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct TcpPacket { + pub header: TcpHeader, + pub payload: Bytes, +} + +impl Packet for TcpPacket { + type Header = TcpHeader; + + fn from_buf(mut bytes: &[u8]) -> Option { + if bytes.len() < TCP_HEADER_LEN { + return None; + } + + let source = bytes.get_u16(); + let destination = bytes.get_u16(); + let sequence = bytes.get_u32(); + let acknowledgement = bytes.get_u32(); + + let offset_reserved = bytes.get_u8(); + let data_offset = offset_reserved >> 4; + let reserved = offset_reserved & 0x0F; + + let flags = bytes.get_u8(); + let window = bytes.get_u16(); + let checksum = bytes.get_u16(); + let urgent_ptr = bytes.get_u16(); + + let header_len = data_offset as usize * 4; + if header_len < TCP_HEADER_LEN || bytes.len() + 20 < header_len { + return None; + } + + let mut options = Vec::new(); + let options_len = header_len - TCP_HEADER_LEN; + let (mut options_bytes, rest) = bytes.split_at(options_len); + bytes = rest; + + while options_bytes.has_remaining() { + let kind = TcpOptionKind::new(options_bytes.get_u8()); + match kind { + TcpOptionKind::EOL => { + options.push(TcpOptionPacket { kind, length: None, data: Bytes::new() }); + break; + } + TcpOptionKind::NOP => { + options.push(TcpOptionPacket { kind, length: None, data: Bytes::new() }); + } + _ => { + if options_bytes.remaining() < 1 { + return None; + } + let len = options_bytes.get_u8(); + if len < 2 || (len as usize) > options_bytes.remaining() + 2 { + return None; + } + let data_len = (len - 2) as usize; + let (data_slice, rest) = options_bytes.split_at(data_len); + options_bytes = rest; + options.push(TcpOptionPacket { + kind, + length: Some(len), + data: Bytes::copy_from_slice(data_slice), + }); + } + } + } + + Some(TcpPacket { + header: TcpHeader { + source, + destination, + sequence, + acknowledgement, + data_offset: u4::from_be(data_offset), + reserved: u4::from_be(reserved), + flags, + window, + checksum, + urgent_ptr, + options, + }, + payload: Bytes::copy_from_slice(bytes), + }) + } + fn from_bytes(mut bytes: Bytes) -> Option { + Self::from_buf(&mut bytes) + } + + fn to_bytes(&self) -> Bytes { + let mut bytes = BytesMut::with_capacity(self.header_len() + self.payload.len()); + + bytes.put_u16(self.header.source); + bytes.put_u16(self.header.destination); + bytes.put_u32(self.header.sequence); + bytes.put_u32(self.header.acknowledgement); + + let offset_reserved = (self.header.data_offset.to_be() << 4) | (self.header.reserved.to_be() & 0x0F); + bytes.put_u8(offset_reserved); + + bytes.put_u8(self.header.flags); + bytes.put_u16(self.header.window); + bytes.put_u16(self.header.checksum); + bytes.put_u16(self.header.urgent_ptr); + + for option in &self.header.options { + bytes.put_u8(option.kind.value()); + if let Some(length) = option.length { + bytes.put_u8(length); + bytes.extend_from_slice(&option.data); + } + } + + // Padding to 4-byte alignment + while bytes.len() % 4 != 0 { + bytes.put_u8(0); + } + + bytes.extend_from_slice(&self.payload); + + bytes.freeze() + } + + fn header(&self) -> Bytes { + self.to_bytes().slice(..self.header_len()) + } + + fn payload(&self) -> Bytes { + self.payload.clone() + } + + fn header_len(&self) -> usize { + let base = TCP_HEADER_LEN; + let mut opt_len = 0; + + for opt in &self.header.options { + match opt.kind { + TcpOptionKind::EOL | TcpOptionKind::NOP => { + opt_len += 1; // EOL and NOP are one byte + } + _ => { + // kind(1B) + length(1B) + payload + if let Some(len) = opt.length { + opt_len += len as usize; + } else { + // Ensure at least 2 bytes (kind + length) + opt_len += 2; + } + } + } + } + + let total = base + opt_len; + // The TCP header is always rounded to a 4 byte boundary + (total + 3) & !0x03 + } + + fn payload_len(&self) -> usize { + self.payload.len() + } + + fn total_len(&self) -> usize { + self.header_len() + self.payload_len() + } + + fn into_parts(self) -> (Self::Header, Bytes) { + (self.header, self.payload) } } -#[inline] -fn tcp_options_length(tcp: &TcpPacket) -> usize { - let data_offset = tcp.get_data_offset(); +impl TcpPacket { + pub fn tcp_options_length(&self) -> usize { + if self.header.data_offset > 5 { + self.header.data_offset as usize * 4 - 20 + } else { + 0 + } + } +} - if data_offset > 5 { - data_offset as usize * 4 - 20 - } else { - 0 +pub fn checksum(packet: &TcpPacket, source: &IpAddr, destination: &IpAddr) -> u16 { + match (source, destination) { + (IpAddr::V4(src), IpAddr::V4(dst)) => ipv4_checksum(packet, src, dst), + (IpAddr::V6(src), IpAddr::V6(dst)) => ipv6_checksum(packet, src, dst), + _ => 0, // Unsupported IP version } } @@ -587,12 +661,12 @@ pub fn ipv4_checksum_adv( destination: &Ipv4Addr, ) -> u16 { util::ipv4_checksum( - packet.packet(), + &packet.to_bytes(), 8, extra_data, source, destination, - IpNextLevelProtocol::Tcp, + IpNextProtocol::Tcp, ) } @@ -615,157 +689,92 @@ pub fn ipv6_checksum_adv( destination: &Ipv6Addr, ) -> u16 { util::ipv6_checksum( - packet.packet(), + &packet.to_bytes(), 8, extra_data, source, destination, - IpNextLevelProtocol::Tcp, + IpNextProtocol::Tcp, ) } -#[test] -fn tcp_header_ipv4_test() { - use crate::ip::IpNextLevelProtocol; - use crate::ipv4::MutableIpv4Packet; - - const IPV4_HEADER_LEN: usize = 20; - const TCP_HEADER_LEN: usize = 32; - const TEST_DATA_LEN: usize = 4; - - let mut packet = [0u8; IPV4_HEADER_LEN + TCP_HEADER_LEN + TEST_DATA_LEN]; - let ipv4_source = Ipv4Addr::new(192, 168, 2, 1); - let ipv4_destination = Ipv4Addr::new(192, 168, 111, 51); - { - let mut ip_header = MutableIpv4Packet::new(&mut packet[..]).unwrap(); - ip_header.set_next_level_protocol(IpNextLevelProtocol::Tcp); - ip_header.set_source(ipv4_source); - ip_header.set_destination(ipv4_destination); +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_basic_tcp_parse() { + let ref_packet = Bytes::from_static(&[ + 0xc1, 0x67, /* source */ + 0x23, 0x28, /* destination */ + 0x90, 0x37, 0xd2, 0xb8, /* seq */ + 0x94, 0x4b, 0xb2, 0x76, /* ack */ + 0x80, 0x18, 0x0f, 0xaf, /* offset+reserved, flags, win */ + 0xc0, 0x31, /* checksum */ + 0x00, 0x00, /* urg ptr */ + 0x01, 0x01, /* NOP */ + 0x08, 0x0a, 0x2c, 0x57, 0xcd, 0xa5, 0x02, 0xa0, 0x41, 0x92, /* timestamp */ + 0x74, 0x65, 0x73, 0x74, /* payload: "test" */ + ]); + let packet = TcpPacket::from_bytes(ref_packet.clone()).unwrap(); + + assert_eq!(packet.header.source, 0xc167); + assert_eq!(packet.header.destination, 0x2328); + assert_eq!(packet.header.sequence, 0x9037d2b8); + assert_eq!(packet.header.acknowledgement, 0x944bb276); + assert_eq!(packet.header.data_offset, 8); // adjusted + assert_eq!(packet.header.reserved, 0); + assert_eq!(packet.header.flags, 0x18); // PSH + ACK + assert_eq!(packet.header.window, 0x0faf); + assert_eq!(packet.header.checksum, 0xc031); + assert_eq!(packet.header.urgent_ptr, 0x0000); + assert_eq!(packet.header.options.len(), 3); + assert_eq!(packet.header.options[0].kind, TcpOptionKind::NOP); + assert_eq!(packet.header.options[1].kind, TcpOptionKind::NOP); + assert_eq!(packet.header.options[2].kind, TcpOptionKind::TIMESTAMPS); + assert_eq!( + packet.header.options[2].get_timestamp(), + (0x2c57cda5, 0x02a04192) + ); + assert_eq!(packet.payload, Bytes::from_static(b"test")); + assert_eq!(packet.header_len(), 32); // adjusted + assert_eq!(packet.to_bytes(), ref_packet); + assert_eq!(packet.header().len(), 32); // adjusted + assert_eq!(packet.payload().len(), 4); + } + + #[test] + fn test_basic_tcp_create() { + let options = vec![ + TcpOptionPacket::nop(), + TcpOptionPacket::nop(), + TcpOptionPacket::timestamp(0x2c57cda5, 0x02a04192), + ]; + + let packet = TcpPacket { + header: TcpHeader { + source: 0xc167, + destination: 0x2328, + sequence: 0x9037d2b8, + acknowledgement: 0x944bb276, + data_offset: 8.into(), // 8 * 4 = 32 bytes + reserved: 0.into(), + flags: 0x18, // PSH + ACK + window: 0x0faf, + checksum: 0xc031, + urgent_ptr: 0x0000, + options: options.clone(), + }, + payload: Bytes::from_static(b"test"), + }; + + let bytes = packet.to_bytes(); + let parsed = TcpPacket::from_bytes(bytes.clone()).expect("Failed to parse TCP packet"); + + assert_eq!(parsed, packet); + assert_eq!(parsed.to_bytes(), bytes); + assert_eq!(parsed.header.options.len(), 3); + assert_eq!(parsed.header.options[2].get_timestamp(), (0x2c57cda5, 0x02a04192)); } - // Set data - packet[IPV4_HEADER_LEN + TCP_HEADER_LEN] = 't' as u8; - packet[IPV4_HEADER_LEN + TCP_HEADER_LEN + 1] = 'e' as u8; - packet[IPV4_HEADER_LEN + TCP_HEADER_LEN + 2] = 's' as u8; - packet[IPV4_HEADER_LEN + TCP_HEADER_LEN + 3] = 't' as u8; - - { - let mut tcp_header = MutableTcpPacket::new(&mut packet[IPV4_HEADER_LEN..]).unwrap(); - tcp_header.set_source(49511); - assert_eq!(tcp_header.get_source(), 49511); - - tcp_header.set_destination(9000); - assert_eq!(tcp_header.get_destination(), 9000); - - tcp_header.set_sequence(0x9037d2b8); - assert_eq!(tcp_header.get_sequence(), 0x9037d2b8); - - tcp_header.set_acknowledgement(0x944bb276); - assert_eq!(tcp_header.get_acknowledgement(), 0x944bb276); - - tcp_header.set_flags(TcpFlags::PSH | TcpFlags::ACK); - assert_eq!(tcp_header.get_flags(), TcpFlags::PSH | TcpFlags::ACK); - - tcp_header.set_window(4015); - assert_eq!(tcp_header.get_window(), 4015); - - tcp_header.set_data_offset(8); - assert_eq!(tcp_header.get_data_offset(), 8); - - let ts = TcpOption::timestamp(743951781, 44056978); - tcp_header.set_options(&vec![TcpOption::nop(), TcpOption::nop(), ts]); - - let checksum = ipv4_checksum(&tcp_header.to_immutable(), &ipv4_source, &ipv4_destination); - tcp_header.set_checksum(checksum); - assert_eq!(tcp_header.get_checksum(), 0xc031); - } - let ref_packet = [ - 0xc1, 0x67, /* source */ - 0x23, 0x28, /* destination */ - 0x90, 0x37, 0xd2, 0xb8, /* seq */ - 0x94, 0x4b, 0xb2, 0x76, /* ack */ - 0x80, 0x18, 0x0f, 0xaf, /* length, flags, win */ - 0xc0, 0x31, /* checksum */ - 0x00, 0x00, /* urg ptr */ - 0x01, 0x01, /* options: nop */ - 0x08, 0x0a, 0x2c, 0x57, 0xcd, 0xa5, 0x02, 0xa0, 0x41, 0x92, /* timestamp */ - 0x74, 0x65, 0x73, 0x74, /* "test" */ - ]; - assert_eq!(&ref_packet[..], &packet[20..]); -} - -#[test] -fn tcp_test_options_invalid_offset() { - let mut buf = [0; 20]; // no space for options - { - if let Some(mut tcp) = MutableTcpPacket::new(&mut buf[..]) { - tcp.set_data_offset(10); // set invalid offset - } - } - - if let Some(tcp) = TcpPacket::new(&buf[..]) { - let _options = tcp.get_options_iter(); // shouldn't crash here - } -} - -#[test] -fn tcp_test_options_vec_invalid_offset() { - let mut buf = [0; 20]; // no space for options - { - if let Some(mut tcp) = MutableTcpPacket::new(&mut buf[..]) { - tcp.set_data_offset(10); // set invalid offset - } - } - - if let Some(tcp) = TcpPacket::new(&buf[..]) { - let _options = tcp.get_options(); // shouldn't crash here - } -} - -#[test] -fn tcp_test_options_slice_invalid_offset() { - let mut buf = [0; 20]; // no space for options - { - if let Some(mut tcp) = MutableTcpPacket::new(&mut buf[..]) { - tcp.set_data_offset(10); // set invalid offset - } - } - - if let Some(tcp) = TcpPacket::new(&buf[..]) { - let _options = tcp.get_options_raw(); // shouldn't crash here - } -} - -#[test] -fn tcp_test_option_invalid_len() { - use std::println; - let mut buf = [0; 24]; - { - if let Some(mut tcp) = MutableTcpPacket::new(&mut buf[..]) { - tcp.set_data_offset(6); - } - buf[20] = 2; // option type - buf[21] = 8; // option len, not enough space for it - } - - if let Some(tcp) = TcpPacket::new(&buf[..]) { - let options = tcp.get_options_iter(); - for opt in options { - println!("{:?}", opt); - } - } -} - -#[test] -fn tcp_test_payload_slice_invalid_offset() { - let mut buf = [0; 20]; - { - if let Some(mut tcp) = MutableTcpPacket::new(&mut buf[..]) { - tcp.set_data_offset(10); // set invalid offset - } - } - - if let Some(tcp) = TcpPacket::new(&buf[..]) { - assert_eq!(tcp.payload().len(), 0); - } } diff --git a/nex-packet/src/udp.rs b/nex-packet/src/udp.rs index c6332f7..8826772 100644 --- a/nex-packet/src/udp.rs +++ b/nex-packet/src/udp.rs @@ -1,16 +1,13 @@ //! A UDP packet abstraction. -use crate::ip::IpNextLevelProtocol; -use crate::Packet; - -use alloc::vec::Vec; - -use nex_macro::packet; -use nex_macro_helper::types::*; +use crate::ip::IpNextProtocol; +use crate::packet::Packet; use crate::util; -use std::net::{Ipv4Addr, Ipv6Addr}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use nex_core::bitfield::u16be; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -27,42 +24,94 @@ pub struct UdpHeader { pub checksum: u16be, } -impl UdpHeader { - /// Construct a UDP header from a byte slice. - pub fn from_bytes(packet: &[u8]) -> Result { - if packet.len() < UDP_HEADER_LEN { - return Err("Packet is too small for UDP header".to_string()); +/// Represents a UDP Packet. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct UdpPacket { + pub header: UdpHeader, + pub payload: Bytes, +} + +impl Packet for UdpPacket { + type Header = UdpHeader; + fn from_buf(mut bytes: &[u8]) -> Option { + if bytes.len() < UDP_HEADER_LEN { + return None; } - match UdpPacket::new(packet) { - Some(udp_packet) => Ok(UdpHeader { - source: udp_packet.get_source(), - destination: udp_packet.get_destination(), - length: udp_packet.get_length(), - checksum: udp_packet.get_checksum(), - }), - None => Err("Failed to parse UDP packet".to_string()), + + let source = bytes.get_u16(); + let destination = bytes.get_u16(); + let length = bytes.get_u16(); + let checksum = bytes.get_u16(); + + if length < UDP_HEADER_LEN as u16 { + return None; } - } - /// Construct a UDP header from a UdpPacket. - pub(crate) fn from_packet(udp_packet: &UdpPacket) -> UdpHeader { - UdpHeader { - source: udp_packet.get_source(), - destination: udp_packet.get_destination(), - length: udp_packet.get_length(), - checksum: udp_packet.get_checksum(), + + let payload_len = length as usize - UDP_HEADER_LEN; + if bytes.len() < payload_len { + return None; } + + let (payload_slice, _) = bytes.split_at(payload_len); + + Some(UdpPacket { + header: UdpHeader { + source, + destination, + length, + checksum, + }, + payload: Bytes::copy_from_slice(payload_slice), + }) + } + fn from_bytes(mut bytes: Bytes) -> Option { + Self::from_buf(&mut bytes) + } + fn to_bytes(&self) -> Bytes { + let mut buf = BytesMut::with_capacity(UDP_HEADER_LEN + self.payload.len()); + buf.put_u16(self.header.source); + buf.put_u16(self.header.destination); + buf.put_u16((UDP_HEADER_LEN + self.payload.len()) as u16); + buf.put_u16(self.header.checksum); + buf.extend_from_slice(&self.payload); + buf.freeze() + } + fn header(&self) -> Bytes { + let mut buf = BytesMut::with_capacity(UDP_HEADER_LEN); + buf.put_u16(self.header.source); + buf.put_u16(self.header.destination); + buf.put_u16(self.header.length); + buf.put_u16(self.header.checksum); + buf.freeze() + } + + fn payload(&self) -> Bytes { + self.payload.clone() + } + + fn header_len(&self) -> usize { + UDP_HEADER_LEN + } + + fn payload_len(&self) -> usize { + self.payload.len() + } + + fn total_len(&self) -> usize { + self.header_len() + self.payload_len() + } + + fn into_parts(self) -> (Self::Header, Bytes) { + (self.header, self.payload) } } -/// Represents a UDP Packet. -#[packet] -pub struct Udp { - pub source: u16be, - pub destination: u16be, - pub length: u16be, - pub checksum: u16be, - #[payload] - pub payload: Vec, +pub fn checksum(packet: &UdpPacket, source: &IpAddr, destination: &IpAddr) -> u16 { + match (source, destination) { + (IpAddr::V4(src), IpAddr::V4(dst)) => ipv4_checksum(packet, src, dst), + (IpAddr::V6(src), IpAddr::V6(dst)) => ipv6_checksum(packet, src, dst), + _ => 0, // Unsupported IP version + } } /// Calculate a checksum for a packet built on IPv4. @@ -84,61 +133,15 @@ pub fn ipv4_checksum_adv( destination: &Ipv4Addr, ) -> u16be { util::ipv4_checksum( - packet.packet(), + packet.to_bytes().as_ref(), 3, extra_data, source, destination, - IpNextLevelProtocol::Udp, + IpNextProtocol::Udp, ) } -#[test] -fn udp_header_ipv4_test() { - use crate::ip::IpNextLevelProtocol; - use crate::ipv4::MutableIpv4Packet; - - let mut packet = [0u8; 20 + 8 + 4]; - let ipv4_source = Ipv4Addr::new(192, 168, 0, 1); - let ipv4_destination = Ipv4Addr::new(192, 168, 0, 199); - { - let mut ip_header = MutableIpv4Packet::new(&mut packet[..]).unwrap(); - ip_header.set_next_level_protocol(IpNextLevelProtocol::Udp); - ip_header.set_source(ipv4_source); - ip_header.set_destination(ipv4_destination); - } - - // Set data - packet[20 + 8] = 't' as u8; - packet[20 + 8 + 1] = 'e' as u8; - packet[20 + 8 + 2] = 's' as u8; - packet[20 + 8 + 3] = 't' as u8; - - { - let mut udp_header = MutableUdpPacket::new(&mut packet[20..]).unwrap(); - udp_header.set_source(12345); - assert_eq!(udp_header.get_source(), 12345); - - udp_header.set_destination(54321); - assert_eq!(udp_header.get_destination(), 54321); - - udp_header.set_length(8 + 4); - assert_eq!(udp_header.get_length(), 8 + 4); - - let checksum = ipv4_checksum(&udp_header.to_immutable(), &ipv4_source, &ipv4_destination); - udp_header.set_checksum(checksum); - assert_eq!(udp_header.get_checksum(), 0x9178); - } - - let ref_packet = [ - 0x30, 0x39, /* source */ - 0xd4, 0x31, /* destination */ - 0x00, 0x0c, /* length */ - 0x91, 0x78, /* checksum */ - ]; - assert_eq!(&ref_packet[..], &packet[20..28]); -} - /// Calculate a checksum for a packet built on IPv6. pub fn ipv6_checksum(packet: &UdpPacket, source: &Ipv6Addr, destination: &Ipv6Addr) -> u16be { ipv6_checksum_adv(packet, &[], source, destination) @@ -158,57 +161,59 @@ pub fn ipv6_checksum_adv( destination: &Ipv6Addr, ) -> u16be { util::ipv6_checksum( - packet.packet(), + packet.to_bytes().as_ref(), 3, extra_data, source, destination, - IpNextLevelProtocol::Udp, + IpNextProtocol::Udp, ) } -#[test] -fn udp_header_ipv6_test() { - use crate::ip::IpNextLevelProtocol; - use crate::ipv6::MutableIpv6Packet; - - let mut packet = [0u8; 40 + 8 + 4]; - let ipv6_source = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1); - let ipv6_destination = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1); - { - let mut ip_header = MutableIpv6Packet::new(&mut packet[..]).unwrap(); - ip_header.set_next_header(IpNextLevelProtocol::Udp); - ip_header.set_source(ipv6_source); - ip_header.set_destination(ipv6_destination); +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_basic_udp_parse() { + let raw = Bytes::from_static(&[ + 0x12, 0x34, // source + 0xab, 0xcd, // destination + 0x00, 0x0c, // length = 12 bytes (8 header + 4 payload) + 0x55, 0xaa, // checksum + b'd', b'a', b't', b'a', // payload + ]); + let packet = UdpPacket::from_bytes(raw.clone()).expect("Failed to parse UDP packet"); + + assert_eq!(packet.header.source, 0x1234); + assert_eq!(packet.header.destination, 0xabcd); + assert_eq!(packet.header.length, 12); + assert_eq!(packet.header.checksum, 0x55aa); + assert_eq!(packet.payload, Bytes::from_static(b"data")); + assert_eq!(packet.to_bytes(), raw); } - - // Set data - packet[40 + 8] = 't' as u8; - packet[40 + 8 + 1] = 'e' as u8; - packet[40 + 8 + 2] = 's' as u8; - packet[40 + 8 + 3] = 't' as u8; - - { - let mut udp_header = MutableUdpPacket::new(&mut packet[40..]).unwrap(); - udp_header.set_source(12345); - assert_eq!(udp_header.get_source(), 12345); - - udp_header.set_destination(54321); - assert_eq!(udp_header.get_destination(), 54321); - - udp_header.set_length(8 + 4); - assert_eq!(udp_header.get_length(), 8 + 4); - - let checksum = ipv6_checksum(&udp_header.to_immutable(), &ipv6_source, &ipv6_destination); - udp_header.set_checksum(checksum); - assert_eq!(udp_header.get_checksum(), 0x1390); + #[test] + fn test_basic_udp_create() { + let payload = Bytes::from_static(b"data"); + let packet = UdpPacket { + header: UdpHeader { + source: 0x1234, + destination: 0xabcd, + length: (UDP_HEADER_LEN + payload.len()) as u16, + checksum: 0x55aa, + }, + payload: payload.clone(), + }; + + let expected = Bytes::from_static(&[ + 0x12, 0x34, // source + 0xab, 0xcd, // destination + 0x00, 0x0c, // length + 0x55, 0xaa, // checksum + b'd', b'a', b't', b'a', // payload + ]); + + assert_eq!(packet.to_bytes(), expected); + assert_eq!(packet.payload(), payload); + assert_eq!(packet.header_len(), UDP_HEADER_LEN); } - - let ref_packet = [ - 0x30, 0x39, /* source */ - 0xd4, 0x31, /* destination */ - 0x00, 0x0c, /* length */ - 0x13, 0x90, /* checksum */ - ]; - assert_eq!(&ref_packet[..], &packet[40..48]); } diff --git a/nex-packet/src/usbpcap.rs b/nex-packet/src/usbpcap.rs deleted file mode 100644 index 157f9b0..0000000 --- a/nex-packet/src/usbpcap.rs +++ /dev/null @@ -1,172 +0,0 @@ -//! A USB PCAP packet abstraction. - -use alloc::vec::Vec; - -use nex_macro::Packet; -use nex_macro_helper::packet::PrimitiveValues; -use nex_macro_helper::types::{u1, u16le, u3, u32le, u4, u64le, u7}; - -/// Represents a USB PCAP function for the requested operation. -#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct UsbPcapFunction(pub u16); - -impl UsbPcapFunction { - /// Construct a new `UsbPcapFunction` instance. - pub fn new(val: u16) -> Self { - Self(val) - } -} - -impl PrimitiveValues for UsbPcapFunction { - type T = (u16,); - fn to_primitive_values(&self) -> Self::T { - (self.0,) - } -} - -/// Represents the USB status for USB requests. -#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct UsbPcapStatus(pub u32); - -impl UsbPcapStatus { - /// Construct a new `UsbPcapStatus` instance. - pub fn new(val: u32) -> Self { - Self(val) - } -} - -impl PrimitiveValues for UsbPcapStatus { - type T = (u32,); - fn to_primitive_values(&self) -> Self::T { - (self.0,) - } -} - -/// Represents a USB PCAP packet ([Link Type 249](https://www.tcpdump.org/linktypes.html)). -#[derive(Packet)] -pub struct UsbPcap { - pub header_length: u16le, - pub irp_id: u64le, - #[construct_with(u32le)] - pub status: UsbPcapStatus, - #[construct_with(u16le)] - pub function: UsbPcapFunction, - pub reserved_info: u7, - pub pdo_to_fdo: u1, - pub bus: u16le, - pub device: u16le, - pub direction: u1, - pub reserved_endpoint: u3, - pub endpoint: u4, - pub transfer: u8, - pub data_length: u32le, - #[length = "header_length - 27"] - pub header_payload: Vec, - #[length = "data_length"] - #[payload] - pub payload: Vec, -} - -#[cfg(test)] -mod tests { - use super::*; - use nex_macro_helper::packet::Packet; - - #[test] - fn usbpcap_packet_test() { - let mut packet = [0u8; 35]; - { - let mut usbpcap = MutableUsbPcapPacket::new(&mut packet[..]).unwrap(); - usbpcap.set_header_length(27); - assert_eq!(usbpcap.get_header_length(), 27); - - usbpcap.set_irp_id(0x12_34); - assert_eq!(usbpcap.get_irp_id(), 0x12_34); - - usbpcap.set_status(UsbPcapStatus(30)); - assert_eq!(usbpcap.get_status(), UsbPcapStatus(30)); - - usbpcap.set_function(UsbPcapFunction(40)); - assert_eq!(usbpcap.get_function(), UsbPcapFunction(40)); - - assert_eq!(usbpcap.get_reserved_info(), 0); - - usbpcap.set_pdo_to_fdo(1); - assert_eq!(usbpcap.get_pdo_to_fdo(), 1); - - usbpcap.set_bus(60); - assert_eq!(usbpcap.get_bus(), 60); - - usbpcap.set_device(70); - assert_eq!(usbpcap.get_device(), 70); - - usbpcap.set_direction(1); - assert_eq!(usbpcap.get_direction(), 1); - - assert_eq!(usbpcap.get_reserved_endpoint(), 0); - - usbpcap.set_endpoint(14); - assert_eq!(usbpcap.get_endpoint(), 14); - - usbpcap.set_transfer(80); - assert_eq!(usbpcap.get_transfer(), 80); - - usbpcap.set_data_length(2); - assert_eq!(usbpcap.get_data_length(), 2); - - assert_eq!(usbpcap.get_header_payload(), Vec::::new()); - - usbpcap.set_payload(&[90, 100]); - assert_eq!(usbpcap.payload(), &[90, 100]); - } - - let ref_packet = [ - 27, 0, // Header length - 0x34, 0x12, 0, 0, 0, 0, 0, 0, // IRP ID - 30, 0, 0, 0, // Status - 40, 0, // Function - 1, // Info octet - 60, 0, // Bus - 70, 0, // Device - 142, // Endpoint fields - 80, // Transfer field - 2, 0, 0, 0, // Data length field - // No header payload - 90, 100, // Payload - ]; - - assert_eq!(&ref_packet[..], &packet[0..29]); - } - - #[test] - fn usbpcap_packet_test_variable_header() { - let mut packet = [0u8; 35]; - { - let mut usbpcap = MutableUsbPcapPacket::new(&mut packet[..]).unwrap(); - usbpcap.set_header_length(28); - assert_eq!(usbpcap.get_header_length(), 28); - - usbpcap.set_header_payload(&[110]); - assert_eq!(usbpcap.get_header_payload(), &[110]); - - assert_eq!(usbpcap.payload(), Vec::::new()); - } - - let ref_packet = [ - 28, 0, // Header length - 0, 0, 0, 0, 0, 0, 0, 0, // IRP ID - 0, 0, 0, 0, // Status - 0, 0, // Function - 0, // Info - 0, 0, // Bus - 0, 0, // Device - 0, // Endpoint fields - 0, // Transfer field - 0, 0, 0, 0, // Data length field - 110, // Header payload - // No payload - ]; - - assert_eq!(&ref_packet[..], &packet[0..28]); - } -} diff --git a/nex-packet/src/util.rs b/nex-packet/src/util.rs index d255ca9..d488143 100644 --- a/nex-packet/src/util.rs +++ b/nex-packet/src/util.rs @@ -1,7 +1,7 @@ //! Utilities for working with packets, eg. checksumming. -use crate::ip::IpNextLevelProtocol; -use nex_macro_helper::types::u16be; +use crate::ip::IpNextProtocol; +use nex_core::bitfield::u16be; use core::convert::TryInto; use core::u16; @@ -87,7 +87,7 @@ pub fn ipv4_checksum( extra_data: &[u8], source: &Ipv4Addr, destination: &Ipv4Addr, - next_level_protocol: IpNextLevelProtocol, + next_level_protocol: IpNextProtocol, ) -> u16be { let mut sum = 0u32; @@ -118,7 +118,7 @@ pub fn ipv6_checksum( extra_data: &[u8], source: &Ipv6Addr, destination: &Ipv6Addr, - next_level_protocol: IpNextLevelProtocol, + next_level_protocol: IpNextProtocol, ) -> u16be { let mut sum = 0u32; @@ -171,7 +171,6 @@ fn sum_be_words(data: &[u8], skipword: usize) -> u32 { #[cfg(test)] mod tests { use super::sum_be_words; - use alloc::{vec, vec::Vec}; use core::slice; #[test] diff --git a/nex-packet/src/vlan.rs b/nex-packet/src/vlan.rs index 5bbdfaf..ed16a4b 100644 --- a/nex-packet/src/vlan.rs +++ b/nex-packet/src/vlan.rs @@ -1,114 +1,209 @@ -//! A VLAN packet abstraction. +//! A VLAN (802.1Q) packet abstraction. +//! +use crate::{ethernet::EtherType, packet::Packet}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use nex_core::bitfield::{u1, u12be}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; -use crate::ethernet::EtherType; -use crate::PrimitiveValues; +/// VLAN Header length in bytes +pub const VLAN_HEADER_LEN: usize = 4; -use alloc::vec::Vec; - -use nex_macro::packet; -use nex_macro_helper::types::*; - -/// Represents an IEEE 802.1p class of a service. -/// +/// Class of Service (IEEE 802.1p Priority Code Point) #[repr(u8)] #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum ClassOfService { - /// Background - BK = 1, - /// Best Effort - BE = 0, - /// Excellent Effort - EE = 2, - /// Critical Applications - CA = 3, - /// Video, < 100 ms latency - VI = 4, - /// Voice, < 10 ms latency - VO = 5, - /// Internetwork Control - IC = 6, - /// Network Control - NC = 7, - /// Unknown class of service - Unknown(u3), + // Background + BK = 1, + // Best Effort + BE = 0, + // Excellent Effort + EE = 2, + // Critical Applications + CA = 3, + // Video + VI = 4, + // Voice + VO = 5, + // Internetwork Control + IC = 6, + // Network Control + NC = 7, + /// Unknown Class of Service + Unknown(u8), } impl ClassOfService { - /// Constructs a new ClassOfServiceEnum from u3. - pub fn new(value: u3) -> ClassOfService { - match value { - 1 => ClassOfService::BK, + pub fn new(val: u8) -> Self { + match val { 0 => ClassOfService::BE, + 1 => ClassOfService::BK, 2 => ClassOfService::EE, 3 => ClassOfService::CA, 4 => ClassOfService::VI, 5 => ClassOfService::VO, 6 => ClassOfService::IC, 7 => ClassOfService::NC, - _ => ClassOfService::Unknown(value), + other => ClassOfService::Unknown(other), } } -} -impl PrimitiveValues for ClassOfService { - type T = (u3,); - fn to_primitive_values(&self) -> (u3,) { + pub fn value(&self) -> u8 { match *self { - ClassOfService::BK => (1,), - ClassOfService::BE => (0,), - ClassOfService::EE => (2,), - ClassOfService::CA => (3,), - ClassOfService::VI => (4,), - ClassOfService::VO => (5,), - ClassOfService::IC => (6,), - ClassOfService::NC => (7,), - ClassOfService::Unknown(n) => (n,), + ClassOfService::BK => 1, + ClassOfService::BE => 0, + ClassOfService::EE => 2, + ClassOfService::CA => 3, + ClassOfService::VI => 4, + ClassOfService::VO => 5, + ClassOfService::IC => 6, + ClassOfService::NC => 7, + ClassOfService::Unknown(v) => v, } } } -/// Represents a VLAN-tagged packet. -#[packet] -pub struct Vlan { - #[construct_with(u3)] +/// VLAN header structure +#[derive(Clone, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct VlanHeader { pub priority_code_point: ClassOfService, - pub drop_eligible_indicator: u1, - pub vlan_identifier: u12be, - #[construct_with(u16be)] + pub drop_eligible_id: u1, + pub vlan_id: u12be, pub ethertype: EtherType, - #[payload] - pub payload: Vec, +} + +/// VLAN packet +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct VlanPacket { + pub header: VlanHeader, + pub payload: Bytes, +} + +impl Packet for VlanPacket { + type Header = VlanHeader; + + fn from_buf(mut bytes: &[u8]) -> Option { + if bytes.len() < VLAN_HEADER_LEN { + return None; + } + + // VLAN TCI + let tci = bytes.get_u16(); + let raw_pcp = ((tci >> 13) & 0b111) as u8; + println!("DEBUG: tci=0x{:04x}, raw_pcp={}", tci, raw_pcp); + let pcp = ClassOfService::new(((tci >> 13) & 0b111) as u8); + let drop_eligible_id = ((tci >> 12) & 0b1) as u1; + let vlan_id = (tci & 0x0FFF) as u12be; + + // EtherType + let ethertype = EtherType::new(bytes.get_u16()); + + // Payload + Some(VlanPacket { + header: VlanHeader { + priority_code_point: pcp, + drop_eligible_id, + vlan_id, + ethertype, + }, + payload: Bytes::copy_from_slice(bytes), + }) + } + fn from_bytes(mut bytes: Bytes) -> Option { + Self::from_buf(&mut bytes) + } + + fn to_bytes(&self) -> Bytes { + let mut buf = BytesMut::with_capacity(VLAN_HEADER_LEN + self.payload.len()); + + let pcp_bits = (self.header.priority_code_point.value() as u16 & 0b111) << 13; + let dei_bits = (self.header.drop_eligible_id as u16 & 0b1) << 12; + let vlan_bits = self.header.vlan_id as u16 & 0x0FFF; + + let tci = pcp_bits | dei_bits | vlan_bits; + + buf.put_u16(tci); + buf.put_u16(self.header.ethertype.value()); + buf.extend_from_slice(&self.payload); + + buf.freeze() + } + + fn header(&self) -> Bytes { + let mut buf = BytesMut::with_capacity(VLAN_HEADER_LEN); + + let mut first = (self.header.priority_code_point.value() & 0b111) << 5; + first |= (self.header.drop_eligible_id & 0b1) << 4; + first |= ((self.header.vlan_id >> 8) & 0b0000_1111) as u8; + + let second = (self.header.vlan_id & 0xFF) as u8; + + buf.put_u8(first); + buf.put_u8(second); + buf.put_u16(self.header.ethertype.value()); + + buf.freeze() + } + + fn payload(&self) -> Bytes { + self.payload.clone() + } + + fn header_len(&self) -> usize { + VLAN_HEADER_LEN + } + + fn payload_len(&self) -> usize { + self.payload.len() + } + + fn total_len(&self) -> usize { + self.header_len() + self.payload_len() + } + + fn into_parts(self) -> (Self::Header, Bytes) { + (self.header, self.payload) + } } #[cfg(test)] mod tests { use super::*; - use crate::ethernet::EtherType; #[test] - fn vlan_packet_test() { - let mut packet = [0u8; 4]; - { - let mut vlan_header = MutableVlanPacket::new(&mut packet[..]).unwrap(); - vlan_header.set_priority_code_point(ClassOfService::BE); - assert_eq!(vlan_header.get_priority_code_point(), ClassOfService::BE); + fn test_vlan_parse() { + let raw = Bytes::from_static(&[ + 0x20, 0x00, // TCI: pcp=1 (BK), dei=0, vid=0 + 0x08, 0x00, // EtherType: IPv4 + b'x', b'y', b'z', + ]); - vlan_header.set_drop_eligible_indicator(0); - assert_eq!(vlan_header.get_drop_eligible_indicator(), 0); + let packet = VlanPacket::from_bytes(raw.clone()).unwrap(); - vlan_header.set_ethertype(EtherType::Ipv4); - assert_eq!(vlan_header.get_ethertype(), EtherType::Ipv4); + assert_eq!(packet.header.priority_code_point, ClassOfService::BK); + assert_eq!(packet.header.drop_eligible_id, 0); + assert_eq!(packet.header.vlan_id, 0x000); + assert_eq!(packet.header.ethertype, EtherType::Ipv4); + assert_eq!(packet.payload, Bytes::from_static(b"xyz")); + assert_eq!(packet.to_bytes(), raw); + } + #[test] + fn test_vlan_parse_2() { + let raw = Bytes::from_static(&[ + 0x01, 0x00, // TCI: PCP=0(BE), DEI=0, VID=0x100 + 0x08, 0x00, // EtherType: IPv4 + b'x', b'y', b'z', + ]); - vlan_header.set_vlan_identifier(0x100); - assert_eq!(vlan_header.get_vlan_identifier(), 0x100); - } + let packet = VlanPacket::from_bytes(raw.clone()).unwrap(); - let ref_packet = [ - 0x01, // PCP, DEI, and first nibble of VID - 0x00, // Remainder of VID - 0x08, // First byte of ethertype - 0x00, - ]; // Second byte of ethertype - assert_eq!(&ref_packet[..], &packet[..]); + assert_eq!(packet.header.priority_code_point, ClassOfService::BE); + assert_eq!(packet.header.drop_eligible_id, 0); + assert_eq!(packet.header.vlan_id, 0x100); + assert_eq!(packet.header.ethertype, EtherType::Ipv4); + assert_eq!(packet.payload, Bytes::from_static(b"xyz")); + assert_eq!(packet.to_bytes(), raw); } } diff --git a/nex-packet/src/vxlan.rs b/nex-packet/src/vxlan.rs new file mode 100644 index 0000000..84cd300 --- /dev/null +++ b/nex-packet/src/vxlan.rs @@ -0,0 +1,121 @@ +//! A VXLAN packet abstraction. +use bytes::{Buf, Bytes}; +use nex_core::bitfield::{self, u24be}; + +use crate::packet::Packet; + +/// Virtual eXtensible Local Area Network (VXLAN) +/// +/// See [RFC 7348](https://datatracker.ietf.org/doc/html/rfc7348) +/// +/// VXLAN Header: +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// |R|R|R|R|I|R|R|R| Reserved | +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// | VXLAN Network Identifier (VNI) | Reserved | +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +pub struct Vxlan { + pub flags: u8, + pub reserved1: u24be, + pub vni: u24be, + pub reserved2: u8, + pub payload: Bytes, +} + +impl Packet for Vxlan { + type Header = (); + + fn from_buf(mut bytes: &[u8]) -> Option { + if bytes.len() < 8 { + return None; + } + + let flags = bytes.get_u8(); + + let reserved1 = { + let b1 = bytes.get_u8(); + let b2 = bytes.get_u8(); + let b3 = bytes.get_u8(); + bitfield::utils::u24be_from_bytes([b1, b2, b3]) + }; + + let vni = { + let b1 = bytes.get_u8(); + let b2 = bytes.get_u8(); + let b3 = bytes.get_u8(); + bitfield::utils::u24be_from_bytes([b1, b2, b3]) + }; + + let reserved2 = bytes.get_u8(); + + let payload = Bytes::copy_from_slice(bytes); + + Some(Self { + flags, + reserved1, + vni, + reserved2, + payload, + }) + } + + fn from_bytes(bytes: Bytes) -> Option { + Self::from_buf(&bytes) + } + + fn to_bytes(&self) -> Bytes { + use bytes::BufMut; + let mut buf = bytes::BytesMut::with_capacity(8 + self.payload.len()); + + buf.put_u8(self.flags); + buf.put_slice(&bitfield::utils::u24be_to_bytes(self.reserved1)); + buf.put_slice(&bitfield::utils::u24be_to_bytes(self.vni)); + buf.put_u8(self.reserved2); + buf.put_slice(&self.payload); + + buf.freeze() + } + fn header(&self) -> Bytes { + use bytes::BufMut; + let mut buf = bytes::BytesMut::with_capacity(8); + + buf.put_u8(self.flags); + buf.put_slice(&self.reserved1.to_be_bytes()); + buf.put_slice(&self.vni.to_be_bytes()); + buf.put_u8(self.reserved2); + + buf.freeze() + } + + fn payload(&self) -> Bytes { + self.payload.clone() + } + + fn header_len(&self) -> usize { + 8 + } + + fn payload_len(&self) -> usize { + self.payload.len() + } + + fn total_len(&self) -> usize { + self.header_len() + self.payload_len() + } + + fn into_parts(self) -> (Self::Header, Bytes) { + ((), self.payload) + } +} + +#[test] +fn vxlan_packet_test() { + let packet = Bytes::from_static(&[ + 0x08, // I flag + 0x00, 0x00, 0x00, // Reserved + 0x12, 0x34, 0x56, // VNI + 0x00 // Reserved + ]); + let vxlan_packet = Vxlan::from_bytes(packet.clone()).unwrap(); + assert_eq!(vxlan_packet.to_bytes(), packet); +} diff --git a/nex-socket/Cargo.toml b/nex-socket/Cargo.toml index bc821fa..b12a51e 100644 --- a/nex-socket/Cargo.toml +++ b/nex-socket/Cargo.toml @@ -11,11 +11,14 @@ categories = ["network-programming"] license = "MIT" [dependencies] -async-io = "2.4" -futures-lite = "2.6" -futures-io = "0.3" -socket2 = { version = "0.5", features = ["all"] } +nex-core = { workspace = true } nex-packet = { workspace = true } +socket2 = { version = "0.5", features = ["all"] } +tokio = { version = "1", features = ["time", "sync", "net", "rt"] } +libc = { workspace = true } + +[target.'cfg(unix)'.dependencies] +nix = { version = "0.30", features = ["poll"] } [target.'cfg(windows)'.dependencies.windows-sys] version = "0.59.0" diff --git a/nex-socket/src/icmp/async_impl.rs b/nex-socket/src/icmp/async_impl.rs new file mode 100644 index 0000000..7d361b1 --- /dev/null +++ b/nex-socket/src/icmp/async_impl.rs @@ -0,0 +1,123 @@ +use crate::icmp::{IcmpConfig, IcmpKind}; +use socket2::{Domain, Protocol, Socket, Type as SockType}; +use std::io; +use std::net::{SocketAddr, UdpSocket as StdUdpSocket}; +use tokio::net::UdpSocket; + +/// Asynchronous ICMP socket built on Tokio. +#[derive(Debug)] +pub struct AsyncIcmpSocket { + inner: UdpSocket, + sock_type: SockType, + kind: IcmpKind, +} + +impl AsyncIcmpSocket { + /// Create a new asynchronous ICMP socket. + pub async fn new(config: &IcmpConfig) -> io::Result { + let (domain, proto) = match config.kind { + IcmpKind::V4 => (Domain::IPV4, Some(Protocol::ICMPV4)), + IcmpKind::V6 => (Domain::IPV6, Some(Protocol::ICMPV6)), + }; + + // Build the socket with DGRAM preferred and RAW as a fallback + let socket = match Socket::new(domain, config.sock_type_hint, proto) { + Ok(s) => s, + Err(_) => { + let alt_type = if config.sock_type_hint == SockType::DGRAM { + SockType::RAW + } else { + SockType::DGRAM + }; + Socket::new(domain, alt_type, proto)? + } + }; + + socket.set_nonblocking(true)?; + + // bind + if let Some(addr) = &config.bind { + socket.bind(&(*addr).into())?; + } + + // Linux: optional interface name + #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))] + if let Some(interface) = &config.interface { + socket.bind_device(Some(interface.as_bytes()))?; + } + + // TTL + if let Some(ttl) = config.ttl { + socket.set_ttl(ttl)?; + } + + // FreeBSD only: optional FIB support + #[cfg(target_os = "freebsd")] + if let Some(fib) = config.fib { + socket.set_fib(fib)?; + } + + let socket_type = socket.r#type()?; + + // Convert socket2::Socket into std::net::UdpSocket + #[cfg(windows)] + let std_socket = unsafe { + use std::os::windows::io::{FromRawSocket, IntoRawSocket}; + + StdUdpSocket::from_raw_socket(socket.into_raw_socket()) + }; + #[cfg(unix)] + let std_socket = unsafe { + use std::os::fd::{FromRawFd, IntoRawFd}; + + StdUdpSocket::from_raw_fd(socket.into_raw_fd()) + }; + + // std → tokio::net::UdpSocket + let inner = UdpSocket::from_std(std_socket)?; + + Ok(Self { + inner, + sock_type: socket_type, + kind: config.kind, + }) + } + + /// Send a packet asynchronously. + pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result { + self.inner.send_to(buf, target).await + } + + /// Receive a packet asynchronously. + pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.inner.recv_from(buf).await + } + + /// Retrieve the local address. + pub fn local_addr(&self) -> io::Result { + self.inner.local_addr() + } + + /// Return the socket type (DGRAM or RAW). + pub fn sock_type(&self) -> SockType { + self.sock_type + } + + /// Return the ICMP version. + pub fn kind(&self) -> IcmpKind { + self.kind + } + + /// Access the native socket for low level operations. + #[cfg(unix)] + pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd { + use std::os::fd::AsRawFd; + self.inner.as_raw_fd() + } + + #[cfg(windows)] + pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket { + use std::os::windows::io::AsRawSocket; + self.inner.as_raw_socket() + } +} diff --git a/nex-socket/src/icmp/config.rs b/nex-socket/src/icmp/config.rs new file mode 100644 index 0000000..7b083a6 --- /dev/null +++ b/nex-socket/src/icmp/config.rs @@ -0,0 +1,79 @@ +use std::net::SocketAddr; +use socket2::Type as SockType; + +/// ICMP protocol version. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum IcmpKind { + V4, + V6, +} + +/// Configuration for an ICMP socket. +#[derive(Debug, Clone)] +pub struct IcmpConfig { + pub kind: IcmpKind, + pub bind: Option, + pub ttl: Option, + pub interface: Option, + pub sock_type_hint: SockType, + pub fib: Option, +} + +impl IcmpConfig { + pub fn new(kind: IcmpKind) -> Self { + Self { + kind, + bind: None, + ttl: None, + interface: None, + sock_type_hint: SockType::DGRAM, // DGRAM preferred on Linux, RAW fallback on macOS/Windows + fib: None, // FreeBSD only + } + } + + pub fn with_bind(mut self, addr: SocketAddr) -> Self { + self.bind = Some(addr); + self + } + + pub fn with_ttl(mut self, ttl: u32) -> Self { + self.ttl = Some(ttl); + self + } + + pub fn with_interface(mut self, iface: impl Into) -> Self { + self.interface = Some(iface.into()); + self + } + + pub fn with_sock_type(mut self, ty: SockType) -> Self { + self.sock_type_hint = ty; + self + } + + pub fn with_fib(mut self, fib: u32) -> Self { + self.fib = Some(fib); + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + use socket2::Type; + #[test] + fn icmp_config_builders() { + let addr: SocketAddr = "127.0.0.1:0".parse().unwrap(); + let cfg = IcmpConfig::new(IcmpKind::V4) + .with_bind(addr) + .with_ttl(4) + .with_interface("eth0") + .with_sock_type(Type::RAW); + assert_eq!(cfg.kind, IcmpKind::V4); + assert_eq!(cfg.bind, Some(addr)); + assert_eq!(cfg.ttl, Some(4)); + assert_eq!(cfg.interface.as_deref(), Some("eth0")); + assert_eq!(cfg.sock_type_hint, Type::RAW); + } +} + diff --git a/nex-socket/src/icmp/mod.rs b/nex-socket/src/icmp/mod.rs new file mode 100644 index 0000000..043548c --- /dev/null +++ b/nex-socket/src/icmp/mod.rs @@ -0,0 +1,7 @@ +mod config; +mod async_impl; +mod sync_impl; + +pub use config::*; +pub use sync_impl::*; +pub use async_impl::*; diff --git a/nex-socket/src/icmp/sync_impl.rs b/nex-socket/src/icmp/sync_impl.rs new file mode 100644 index 0000000..a9c28ba --- /dev/null +++ b/nex-socket/src/icmp/sync_impl.rs @@ -0,0 +1,101 @@ +use crate::icmp::{IcmpConfig, IcmpKind}; +use socket2::{Domain, Protocol, Socket, Type as SockType}; +use std::io; +use std::net::{SocketAddr, UdpSocket}; + +/// Synchronous ICMP socket. +#[derive(Debug)] +pub struct IcmpSocket { + inner: UdpSocket, + sock_type: SockType, + kind: IcmpKind, +} + +impl IcmpSocket { + /// Create a new synchronous ICMP socket. + pub fn new(config: &IcmpConfig) -> io::Result { + let (domain, proto) = match config.kind { + IcmpKind::V4 => (Domain::IPV4, Some(Protocol::ICMPV4)), + IcmpKind::V6 => (Domain::IPV6, Some(Protocol::ICMPV6)), + }; + + let socket = match Socket::new(domain, config.sock_type_hint, proto) { + Ok(s) => s, + Err(_) => { + let alt_type = if config.sock_type_hint == SockType::DGRAM { + SockType::RAW + } else { + SockType::DGRAM + }; + Socket::new(domain, alt_type, proto)? + } + }; + + socket.set_nonblocking(false)?; // blocking mode for sync usage + + if let Some(addr) = &config.bind { + socket.bind(&(*addr).into())?; + } + + #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))] + if let Some(interface) = &config.interface { + socket.bind_device(Some(interface.as_bytes()))?; + } + + if let Some(ttl) = config.ttl { + socket.set_ttl(ttl)?; + } + + #[cfg(target_os = "freebsd")] + if let Some(fib) = config.fib { + socket.set_fib(fib)?; + } + + // Convert socket2::Socket into std::net::UdpSocket + let std_socket: UdpSocket = socket.into(); + + Ok(Self { + inner: std_socket, + sock_type: config.sock_type_hint, + kind: config.kind, + }) + } + + /// Send a packet. + pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result { + self.inner.send_to(buf, target) + } + + /// Receive a packet. + pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.inner.recv_from(buf) + } + + /// Retrieve the local address. + pub fn local_addr(&self) -> io::Result { + self.inner.local_addr() + } + + /// Return the socket type. + pub fn sock_type(&self) -> SockType { + self.sock_type + } + + /// Return the ICMP variant. + pub fn kind(&self) -> IcmpKind { + self.kind + } + + /// Access the underlying socket. + #[cfg(unix)] + pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd { + use std::os::fd::AsRawFd; + self.inner.as_raw_fd() + } + + #[cfg(windows)] + pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket { + use std::os::windows::io::AsRawSocket; + self.inner.as_raw_socket() + } +} diff --git a/nex-socket/src/lib.rs b/nex-socket/src/lib.rs index a6f18c5..bad1362 100644 --- a/nex-socket/src/lib.rs +++ b/nex-socket/src/lib.rs @@ -1,10 +1,9 @@ -mod socket; -mod sys; +//! Convenience sockets built on top of `socket2` and `tokio`. +//! +//! This crate provides synchronous and asynchronous helpers for TCP, UDP and +//! ICMP. The goal is to simplify lower level socket configuration across +//! platforms while still allowing direct access when needed. -pub use socket::AsyncSocket; -pub use socket::AsyncTcpStream; -pub use socket::IpVersion; -pub use socket::Socket; -pub use socket::SocketOption; -pub use socket::SocketType; -pub use sys::PacketReceiver; +pub mod icmp; +pub mod tcp; +pub mod udp; diff --git a/nex-socket/src/socket/async_impl.rs b/nex-socket/src/socket/async_impl.rs deleted file mode 100644 index a16bae8..0000000 --- a/nex-socket/src/socket/async_impl.rs +++ /dev/null @@ -1,743 +0,0 @@ -use crate::socket::to_socket_protocol; -use crate::socket::{IpVersion, SocketOption}; -use async_io::{Async, Timer}; -use futures_lite::future::FutureExt; -use socket2::{SockAddr, Socket as SystemSocket}; -use std::io::{self, Read, Write}; -use std::mem::MaybeUninit; -use std::net::{Shutdown, SocketAddr, TcpListener, TcpStream, UdpSocket}; -use std::sync::Arc; -use std::time::Duration; - -/// Async socket. Provides cross-platform async adapter for system socket. -#[derive(Clone, Debug)] -pub struct AsyncSocket { - inner: Arc>, -} - -impl AsyncSocket { - /// Constructs a new AsyncSocket. - pub fn new(socket_option: SocketOption) -> io::Result { - let socket: SystemSocket = if let Some(protocol) = socket_option.protocol { - SystemSocket::new( - socket_option.ip_version.to_domain(), - socket_option.socket_type.to_type(), - Some(to_socket_protocol(protocol)), - )? - } else { - SystemSocket::new( - socket_option.ip_version.to_domain(), - socket_option.socket_type.to_type(), - None, - )? - }; - socket.set_nonblocking(true)?; - Ok(AsyncSocket { - inner: Arc::new(Async::new(socket)?), - }) - } - /// Constructs a new AsyncSocket with async non-blocking TCP connect. - pub async fn new_with_async_connect(addr: &SocketAddr) -> io::Result { - let stream = Async::::connect(*addr).await?; - // Once the connection is established, we can turn it into a SystemSocket(socket2::Socket). - // And then we can turn it into a AsyncSocket for the rest of the operations. - let socket = SystemSocket::from(stream.into_inner()?); - socket.set_nonblocking(true)?; - Ok(AsyncSocket { - inner: Arc::new(Async::new(socket)?), - }) - } - /// Constructs a new AsyncSocket with async non-blocking TCP connect and timeout. - pub async fn new_with_async_connect_timeout( - addr: &SocketAddr, - timeout: Duration, - ) -> io::Result { - let stream = Async::::connect(*addr) - .or(async { - Timer::after(timeout).await; - Err(io::ErrorKind::TimedOut.into()) - }) - .await?; - // Once the connection is established, we can turn it into a SystemSocket(socket2::Socket). - // And then we can turn it into a AsyncSocket for the rest of the operations. - let socket = SystemSocket::from(stream.into_inner()?); - socket.set_nonblocking(true)?; - Ok(AsyncSocket { - inner: Arc::new(Async::new(socket)?), - }) - } - /// Constructs a new AsyncSocket with TCP connect. - /// If you want to async non-blocking connect, use `new_with_async_connect` instead. - pub fn new_with_connect( - socket_option: SocketOption, - addr: &SocketAddr, - ) -> io::Result { - let socket: SystemSocket = if let Some(protocol) = socket_option.protocol { - SystemSocket::new( - socket_option.ip_version.to_domain(), - socket_option.socket_type.to_type(), - Some(to_socket_protocol(protocol)), - )? - } else { - SystemSocket::new( - socket_option.ip_version.to_domain(), - socket_option.socket_type.to_type(), - None, - )? - }; - let addr: SockAddr = SockAddr::from(*addr); - socket.connect(&addr)?; - socket.set_nonblocking(true)?; - Ok(AsyncSocket { - inner: Arc::new(Async::new(socket)?), - }) - } - /// Constructs a new AsyncSocket with TCP connect and timeout. - /// If you want to async non-blocking connect, use `new_with_async_connect_timeout` instead. - pub fn new_with_connect_timeout( - socket_option: SocketOption, - addr: &SocketAddr, - timeout: Duration, - ) -> io::Result { - let socket: SystemSocket = if let Some(protocol) = socket_option.protocol { - SystemSocket::new( - socket_option.ip_version.to_domain(), - socket_option.socket_type.to_type(), - Some(to_socket_protocol(protocol)), - )? - } else { - SystemSocket::new( - socket_option.ip_version.to_domain(), - socket_option.socket_type.to_type(), - None, - )? - }; - let addr: SockAddr = SockAddr::from(*addr); - socket.connect_timeout(&addr, timeout)?; - socket.set_nonblocking(true)?; - Ok(AsyncSocket { - inner: Arc::new(Async::new(socket)?), - }) - } - /// Constructs a new AsyncSocket with listener. - pub fn new_with_listener( - socket_option: SocketOption, - addr: &SocketAddr, - ) -> io::Result { - let socket: SystemSocket = if let Some(protocol) = socket_option.protocol { - SystemSocket::new( - socket_option.ip_version.to_domain(), - socket_option.socket_type.to_type(), - Some(to_socket_protocol(protocol)), - )? - } else { - SystemSocket::new( - socket_option.ip_version.to_domain(), - socket_option.socket_type.to_type(), - None, - )? - }; - socket.set_nonblocking(true)?; - let addr: SockAddr = SockAddr::from(*addr); - socket.bind(&addr)?; - socket.listen(1024)?; - Ok(AsyncSocket { - inner: Arc::new(Async::new(socket)?), - }) - } - /// Constructs a new AsyncSocket with bind. - pub fn new_with_bind( - socket_option: SocketOption, - addr: &SocketAddr, - ) -> io::Result { - let socket: SystemSocket = if let Some(protocol) = socket_option.protocol { - SystemSocket::new( - socket_option.ip_version.to_domain(), - socket_option.socket_type.to_type(), - Some(to_socket_protocol(protocol)), - )? - } else { - SystemSocket::new( - socket_option.ip_version.to_domain(), - socket_option.socket_type.to_type(), - None, - )? - }; - socket.set_nonblocking(true)?; - let addr: SockAddr = SockAddr::from(*addr); - socket.bind(&addr)?; - Ok(AsyncSocket { - inner: Arc::new(Async::new(socket)?), - }) - } - /// Constructs a new AsyncSocket from TcpStream. - /// Async Socket does not support non-blocking connect. Use TCP Stream to connect to the target. - pub fn from_tcp_stream(tcp_stream: TcpStream) -> io::Result { - let socket = SystemSocket::from(tcp_stream); - socket.set_nonblocking(true)?; - Ok(AsyncSocket { - inner: Arc::new(Async::new(socket)?), - }) - } - /// Constructs a new AsyncSocket from TcpListener. - pub fn from_tcp_listener(tcp_listener: TcpListener) -> io::Result { - let socket = SystemSocket::from(tcp_listener); - socket.set_nonblocking(true)?; - Ok(AsyncSocket { - inner: Arc::new(Async::new(socket)?), - }) - } - /// Constructs a new AsyncSocket from UdpSocket. - pub fn from_udp_socket(udp_socket: UdpSocket) -> io::Result { - let socket = SystemSocket::from(udp_socket); - socket.set_nonblocking(true)?; - Ok(AsyncSocket { - inner: Arc::new(Async::new(socket)?), - }) - } - /// Bind socket to address. - pub async fn bind(&self, addr: SocketAddr) -> io::Result<()> { - let addr: SockAddr = SockAddr::from(addr); - //self.inner.writable().await?; - self.inner.write_with(|inner| inner.bind(&addr)).await - } - /// Send packet. - pub async fn send(&self, buf: &[u8]) -> io::Result { - loop { - self.inner.writable().await?; - match self.inner.write_with(|inner| inner.send(buf)).await { - Ok(n) => return Ok(n), - Err(_) => continue, - } - } - } - /// Send packet to target. - pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result { - let target: SockAddr = SockAddr::from(target); - loop { - self.inner.writable().await?; - match self - .inner - .write_with(|inner| inner.send_to(buf, &target)) - .await - { - Ok(n) => return Ok(n), - Err(_) => continue, - } - } - } - /// Receive packet. - pub async fn receive(&self, buf: &mut Vec) -> io::Result { - let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit]) }; - loop { - self.inner.readable().await?; - match self.inner.read_with(|inner| inner.recv(recv_buf)).await { - Ok(result) => return Ok(result), - Err(_) => continue, - } - } - } - /// Receive packet with sender address. - pub async fn receive_from(&self, buf: &mut Vec) -> io::Result<(usize, SocketAddr)> { - let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit]) }; - loop { - self.inner.readable().await?; - match self - .inner - .read_with(|inner| inner.recv_from(recv_buf)) - .await - { - Ok(result) => { - let (n, addr) = result; - match addr.as_socket() { - Some(addr) => return Ok((n, addr)), - None => continue, - } - } - Err(_) => continue, - } - } - } - /// Write data to the socket and send to the target. - /// Return how many bytes were written. - pub async fn write(&self, buf: &[u8]) -> io::Result { - loop { - self.inner.writable().await?; - match self.inner.write_with(|inner| inner.send(buf)).await { - Ok(n) => return Ok(n), - Err(_) => continue, - } - } - } - /// Write data with timeout. - /// Return how many bytes were written. - pub async fn write_timeout(&self, buf: &[u8], timeout: Duration) -> io::Result { - loop { - self.inner.writable().await?; - match self - .inner - .write_with(|inner| { - match inner.set_write_timeout(Some(timeout)) { - Ok(_) => {} - Err(e) => return Err(e), - } - inner.send(buf) - }) - .await - { - Ok(n) => return Ok(n), - Err(_) => continue, - } - } - } - /// Read data from the socket. - /// Return how many bytes were read. - pub async fn read(&self, buf: &mut Vec) -> io::Result { - let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit]) }; - loop { - self.inner.readable().await?; - match self.inner.read_with(|inner| inner.recv(recv_buf)).await { - Ok(result) => return Ok(result), - Err(_) => continue, - } - } - } - /// Read data with timeout. - /// Return how many bytes were read. - pub async fn read_timeout(&self, buf: &mut Vec, timeout: Duration) -> io::Result { - let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit]) }; - loop { - self.inner.readable().await?; - match self - .inner - .read_with(|inner| { - match inner.set_read_timeout(Some(timeout)) { - Ok(_) => {} - Err(e) => return Err(e), - } - inner.recv(recv_buf) - }) - .await - { - Ok(result) => return Ok(result), - Err(_) => continue, - } - } - } - /// Get TTL or Hop Limit. - pub async fn ttl(&self, ip_version: IpVersion) -> io::Result { - match ip_version { - IpVersion::V4 => self.inner.read_with(|inner| inner.ttl()).await, - IpVersion::V6 => self.inner.read_with(|inner| inner.unicast_hops_v6()).await, - } - } - /// Set TTL or Hop Limit. - pub async fn set_ttl(&self, ttl: u32, ip_version: IpVersion) -> io::Result<()> { - match ip_version { - IpVersion::V4 => self.inner.write_with(|inner| inner.set_ttl(ttl)).await, - IpVersion::V6 => { - self.inner - .write_with(|inner| inner.set_unicast_hops_v6(ttl)) - .await - } - } - } - /// Get the value of the IP_TOS option for this socket. - pub async fn tos(&self) -> io::Result { - self.inner.read_with(|inner| inner.tos()).await - } - /// Set the value of the IP_TOS option for this socket. - pub async fn set_tos(&self, tos: u32) -> io::Result<()> { - self.inner.write_with(|inner| inner.set_tos(tos)).await - } - /// Get the value of the IP_RECVTOS option for this socket. - pub async fn receive_tos(&self) -> io::Result { - self.inner.read_with(|inner| inner.recv_tos()).await - } - /// Set the value of the IP_RECVTOS option for this socket. - pub async fn set_receive_tos(&self, receive_tos: bool) -> io::Result<()> { - self.inner - .write_with(|inner| inner.set_recv_tos(receive_tos)) - .await - } - /// Initiate TCP connection. - pub async fn connect(&mut self, addr: &SocketAddr) -> io::Result<()> { - let addr: SockAddr = SockAddr::from(*addr); - self.inner.write_with(|inner| inner.connect(&addr)).await - } - /// Initiate a connection on this socket to the specified address, only only waiting for a certain period of time for the connection to be established. - /// The non-blocking state of the socket is overridden by this function. - pub async fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> { - let addr: SockAddr = SockAddr::from(*addr); - self.inner - .write_with(|inner| inner.connect_timeout(&addr, timeout)) - .await - } - /// Listen TCP connection. - pub async fn listen(&self, backlog: i32) -> io::Result<()> { - self.inner.write_with(|inner| inner.listen(backlog)).await - } - /// Accept TCP connection. - pub async fn accept(&self) -> io::Result<(AsyncSocket, SocketAddr)> { - match self.inner.read_with(|inner| inner.accept()).await { - Ok((socket, addr)) => { - let socket = AsyncSocket { - inner: Arc::new(Async::new(socket)?), - }; - Ok((socket, addr.as_socket().unwrap())) - } - Err(e) => Err(e), - } - } - /// Get local address. - pub async fn local_addr(&self) -> io::Result { - match self.inner.read_with(|inner| inner.local_addr()).await { - Ok(addr) => Ok(addr.as_socket().unwrap()), - Err(e) => Err(e), - } - } - /// Get peer address. - pub async fn peer_addr(&self) -> io::Result { - match self.inner.read_with(|inner| inner.peer_addr()).await { - Ok(addr) => Ok(addr.as_socket().unwrap()), - Err(e) => Err(e), - } - } - /// Get type of the socket. - pub async fn socket_type(&self) -> io::Result { - match self.inner.read_with(|inner| inner.r#type()).await { - Ok(socktype) => Ok(crate::socket::SocketType::from_type(socktype)), - Err(e) => Err(e), - } - } - /// Create a new socket with the same configuration and bound to the same address. - pub async fn try_clone(&self) -> io::Result { - match self.inner.read_with(|inner| inner.try_clone()).await { - Ok(socket) => Ok(AsyncSocket { - inner: Arc::new(Async::new(socket)?), - }), - Err(e) => Err(e), - } - } - - /// Returns true if this socket is set to nonblocking mode, false otherwise. - #[cfg(not(target_os = "windows"))] - pub async fn is_nonblocking(&self) -> io::Result { - self.inner.read_with(|inner| inner.nonblocking()).await - } - /// Set non-blocking mode. - pub async fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { - self.inner - .write_with(|inner| inner.set_nonblocking(nonblocking)) - .await - } - /// Shutdown TCP connection. - pub async fn shutdown(&self, how: Shutdown) -> io::Result<()> { - self.inner.write_with(|inner| inner.shutdown(how)).await - } - /// Get the value of the SO_BROADCAST option for this socket. - pub async fn is_broadcast(&self) -> io::Result { - self.inner.read_with(|inner| inner.broadcast()).await - } - /// Set the value of the `SO_BROADCAST` option for this socket. - /// - /// When enabled, this socket is allowed to send packets to a broadcast address. - pub async fn set_broadcast(&self, broadcast: bool) -> io::Result<()> { - self.inner - .write_with(|inner| inner.set_broadcast(broadcast)) - .await - } - /// Get the value of the `SO_ERROR` option on this socket. - pub async fn get_error(&self) -> io::Result> { - self.inner.read_with(|inner| inner.take_error()).await - } - /// Get the value of the `SO_KEEPALIVE` option on this socket. - pub async fn is_keepalive(&self) -> io::Result { - self.inner.read_with(|inner| inner.keepalive()).await - } - /// Set value for the `SO_KEEPALIVE` option on this socket. - /// - /// Enable sending of keep-alive messages on connection-oriented sockets. - pub async fn set_keepalive(&self, keepalive: bool) -> io::Result<()> { - self.inner - .write_with(|inner| inner.set_keepalive(keepalive)) - .await - } - /// Get the value of the SO_LINGER option on this socket. - pub async fn linger(&self) -> io::Result> { - self.inner.read_with(|inner| inner.linger()).await - } - /// Set value for the SO_LINGER option on this socket. - pub async fn set_linger(&self, dur: Option) -> io::Result<()> { - self.inner.write_with(|inner| inner.set_linger(dur)).await - } - /// Get the value of the `SO_RCVBUF` option on this socket. - pub async fn receive_buffer_size(&self) -> io::Result { - self.inner.read_with(|inner| inner.recv_buffer_size()).await - } - /// Set value for the `SO_RCVBUF` option on this socket. - /// - /// Changes the size of the operating system's receive buffer associated with the socket. - pub async fn set_receive_buffer_size(&self, size: usize) -> io::Result<()> { - self.inner - .write_with(|inner| inner.set_recv_buffer_size(size)) - .await - } - /// Get value for the SO_RCVTIMEO option on this socket. - pub async fn receive_timeout(&self) -> io::Result> { - self.inner.read_with(|inner| inner.read_timeout()).await - } - /// Set value for the `SO_RCVTIMEO` option on this socket. - pub async fn set_receive_timeout(&self, duration: Option) -> io::Result<()> { - self.inner - .write_with(|inner| inner.set_read_timeout(duration)) - .await - } - /// Get value for the `SO_REUSEADDR` option on this socket. - pub async fn reuse_address(&self) -> io::Result { - self.inner.read_with(|inner| inner.reuse_address()).await - } - /// Set value for the `SO_REUSEADDR` option on this socket. - /// - /// This indicates that futher calls to `bind` may allow reuse of local addresses. - pub async fn set_reuse_address(&self, reuse: bool) -> io::Result<()> { - self.inner - .write_with(|inner| inner.set_reuse_address(reuse)) - .await - } - /// Get value for the `SO_SNDBUF` option on this socket. - pub async fn send_buffer_size(&self) -> io::Result { - self.inner.read_with(|inner| inner.send_buffer_size()).await - } - /// Set value for the `SO_SNDBUF` option on this socket. - /// - /// Changes the size of the operating system's send buffer associated with the socket. - pub async fn set_send_buffer_size(&self, size: usize) -> io::Result<()> { - self.inner - .write_with(|inner| inner.set_send_buffer_size(size)) - .await - } - /// Get value for the `SO_SNDTIMEO` option on this socket. - pub async fn send_timeout(&self) -> io::Result> { - self.inner.read_with(|inner| inner.write_timeout()).await - } - /// Set value for the `SO_SNDTIMEO` option on this socket. - /// - /// If `timeout` is `None`, then `write` and `send` calls will block indefinitely. - pub async fn set_send_timeout(&self, duration: Option) -> io::Result<()> { - self.inner - .write_with(|inner| inner.set_write_timeout(duration)) - .await - } - /// Get the value of the IP_HDRINCL option on this socket. - pub async fn is_ip_header_included(&self) -> io::Result { - self.inner.read_with(|inner| inner.header_included_v4()).await - } - /// Set the value of the `IP_HDRINCL` option on this socket. - pub async fn set_ip_header_included(&self, include: bool) -> io::Result<()> { - self.inner - .write_with(|inner| inner.set_header_included_v4(include)) - .await - } - /// Get the value of the TCP_NODELAY option on this socket. - pub async fn is_nodelay(&self) -> io::Result { - self.inner.read_with(|inner| inner.nodelay()).await - } - /// Set the value of the `TCP_NODELAY` option on this socket. - /// - /// If set, segments are always sent as soon as possible, even if there is only a small amount of data. - pub async fn set_nodelay(&self, nodelay: bool) -> io::Result<()> { - self.inner - .write_with(|inner| inner.set_nodelay(nodelay)) - .await - } - /// Get TCP Stream - /// This function will consume the socket and return a new std::net::TcpStream. - pub fn into_tcp_stream(&self) -> io::Result { - let socket = Arc::try_unwrap(self.inner.clone()) - .map_err(|_| io::Error::new(io::ErrorKind::Other, "Failed to unwrap Arc"))? - .into_inner()?; - let tcp_stream = TcpStream::from(socket); - Ok(tcp_stream) - } - /// Get TCP Listener - /// This function will consume the socket and return a new std::net::TcpListener. - pub fn into_tcp_listener(&self) -> io::Result { - let socket = Arc::try_unwrap(self.inner.clone()) - .map_err(|_| io::Error::new(io::ErrorKind::Other, "Failed to unwrap Arc"))? - .into_inner()?; - let tcp_listener = TcpListener::from(socket); - Ok(tcp_listener) - } - /// Get UDP Socket - /// This function will consume the socket and return a new std::net::UdpSocket. - pub fn into_udp_socket(&self) -> io::Result { - let socket = Arc::try_unwrap(self.inner.clone()) - .map_err(|_| io::Error::new(io::ErrorKind::Other, "Failed to unwrap Arc"))? - .into_inner()?; - let udp_socket = UdpSocket::from(socket); - Ok(udp_socket) - } -} - -/// Async TCP Stream. -#[derive(Clone, Debug)] -pub struct AsyncTcpStream { - inner: Arc>, -} - -impl AsyncTcpStream { - /// Connect to a remote address. - pub async fn connect(addr: SocketAddr) -> io::Result { - let stream = Async::::connect(addr).await?; - Ok(AsyncTcpStream { - inner: Arc::new(stream), - }) - } - - /// Connect to a remote address with timeout. - pub async fn connect_timeout(addr: &SocketAddr, timeout: Duration) -> io::Result { - let stream = Async::::connect(*addr) - .or(async { - Timer::after(timeout).await; - Err(std::io::ErrorKind::TimedOut.into()) - }) - .await?; - Ok(AsyncTcpStream { - inner: Arc::new(stream), - }) - } - - /// Get local address. - pub async fn local_addr(&self) -> io::Result { - self.inner.read_with(|inner| inner.local_addr()).await - } - - /// Get peer address. - pub async fn peer_addr(&self) -> io::Result { - self.inner.read_with(|inner| inner.peer_addr()).await - } - - /// Write data to the socket. - pub async fn write(&self, buf: &[u8]) -> io::Result { - self.inner.write_with(|mut inner| inner.write(buf)).await - } - - /// Attempts to write an entire buffer into this writer. - pub async fn write_all(&self, buf: &[u8]) -> io::Result<()> { - self.inner - .write_with(|mut inner| inner.write_all(buf)) - .await - } - - /// Read data from the socket. - pub async fn read(&self, buf: &mut [u8]) -> io::Result { - self.inner.read_with(|mut inner| inner.read(buf)).await - } - - /// Read all bytes until EOF in this source, placing them into buf. - pub async fn read_to_end(&self, buf: &mut Vec) -> io::Result { - self.inner - .read_with(|mut inner| inner.read_to_end(buf)) - .await - } - - /// Read all bytes until EOF in this source, placing them into buf. - /// This ignore io::Error on read_to_end because it is expected when reading response. - /// If no response is received, and io::Error is occurred, return Err. - pub async fn read_to_end_timeout( - &self, - buf: &mut Vec, - timeout: Duration, - ) -> io::Result { - let mut io_error: io::Error = io::Error::new(io::ErrorKind::Other, "No response"); - match self - .read_to_end(buf) - .or(async { - Timer::after(timeout).await; - Err(std::io::ErrorKind::TimedOut.into()) - }) - .await - { - Ok(_) => {} - Err(e) => { - io_error = e; - } - } - if buf.is_empty() { - Err(io_error) - } else { - Ok(buf.len()) - } - } - - /// Shutdown the socket. - pub async fn shutdown(&self, how: Shutdown) -> io::Result<()> { - self.inner.write_with(|inner| inner.shutdown(how)).await - } - - /// Get the value of the `SO_ERROR` option on this socket. - pub async fn take_error(&self) -> io::Result> { - self.inner.read_with(|inner| inner.take_error()).await - } - /// Creates a new independently owned handle to the underlying socket. - pub async fn try_clone(&self) -> io::Result { - let stream = self.inner.read_with(|inner| inner.try_clone()).await?; - Ok(AsyncTcpStream { - inner: Arc::new(Async::new(stream)?), - }) - } - - /// Sets the read timeout to the timeout specified. - pub async fn set_read_timeout(&self, dur: Option) -> io::Result<()> { - self.inner - .write_with(|inner| inner.set_read_timeout(dur)) - .await - } - - /// Sets the write timeout to the timeout specified. - pub async fn set_write_timeout(&self, dur: Option) -> io::Result<()> { - self.inner - .write_with(|inner| inner.set_write_timeout(dur)) - .await - } - - /// Gets the read timeout of this socket. - pub async fn read_timeout(&self) -> io::Result> { - self.inner.read_with(|inner| inner.read_timeout()).await - } - - /// Gets the write timeout of this socket. - pub async fn write_timeout(&self) -> io::Result> { - self.inner.read_with(|inner| inner.write_timeout()).await - } - - /// Sets the value of the `TCP_NODELAY` option on this socket. - pub async fn set_nodelay(&self, nodelay: bool) -> io::Result<()> { - self.inner - .write_with(|inner| inner.set_nodelay(nodelay)) - .await - } - - /// Gets the value of the `TCP_NODELAY` option on this socket. - pub async fn nodelay(&self) -> io::Result { - self.inner.read_with(|inner| inner.nodelay()).await - } - - /// Sets the value for the IP_TTL option on this socket. - pub async fn set_ttl(&self, ttl: u32) -> io::Result<()> { - self.inner.write_with(|inner| inner.set_ttl(ttl)).await - } - - /// Gets the value of the IP_TTL option on this socket. - pub async fn ttl(&self) -> io::Result { - self.inner.read_with(|inner| inner.ttl()).await - } - - /// Moves this TCP stream into or out of nonblocking mode. - pub async fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { - self.inner - .write_with(|inner| inner.set_nonblocking(nonblocking)) - .await - } -} diff --git a/nex-socket/src/socket/mod.rs b/nex-socket/src/socket/mod.rs deleted file mode 100644 index 3410cb2..0000000 --- a/nex-socket/src/socket/mod.rs +++ /dev/null @@ -1,120 +0,0 @@ -mod async_impl; -mod sync_impl; - -use nex_packet::ip::IpNextLevelProtocol; -use socket2::{Domain, Type}; - -use crate::sys; - -pub use async_impl::*; -pub use sync_impl::*; - -/// IP version. IPv4 or IPv6. -#[derive(Clone, Debug)] -pub enum IpVersion { - V4, - V6, -} - -impl IpVersion { - /// IP Version number as u8. - pub fn version_u8(&self) -> u8 { - match self { - IpVersion::V4 => 4, - IpVersion::V6 => 6, - } - } - /// Return true if IP version is IPv4. - pub fn is_ipv4(&self) -> bool { - match self { - IpVersion::V4 => true, - IpVersion::V6 => false, - } - } - /// Return true if IP version is IPv6. - pub fn is_ipv6(&self) -> bool { - match self { - IpVersion::V4 => false, - IpVersion::V6 => true, - } - } - pub(crate) fn to_domain(&self) -> Domain { - match self { - IpVersion::V4 => Domain::IPV4, - IpVersion::V6 => Domain::IPV6, - } - } -} - -/// Socket type -#[derive(Clone, Debug)] -pub enum SocketType { - /// Raw socket - Raw, - /// Datagram socket. Usualy used for UDP. - Datagram, - /// Stream socket. Used for TCP. - Stream, -} - -impl SocketType { - pub(crate) fn to_type(&self) -> Type { - match self { - SocketType::Raw => Type::RAW, - SocketType::Datagram => Type::DGRAM, - SocketType::Stream => Type::STREAM, - } - } - pub(crate) fn from_type(t: Type) -> SocketType { - match t { - Type::RAW => SocketType::Raw, - Type::DGRAM => SocketType::Datagram, - Type::STREAM => SocketType::Stream, - _ => SocketType::Stream, - } - } -} - -/// Socket option. -#[derive(Clone, Debug)] -pub struct SocketOption { - /// IP version - pub ip_version: IpVersion, - /// Socket type - pub socket_type: SocketType, - /// Protocol. TCP, UDP, ICMP, etc. - pub protocol: Option, - /// Non-blocking mode - pub non_blocking: bool, -} - -impl SocketOption { - /// Constructs a new SocketOption. - pub fn new( - ip_version: IpVersion, - socket_type: SocketType, - protocol: Option, - ) -> SocketOption { - SocketOption { - ip_version, - socket_type, - protocol, - non_blocking: false, - } - } - /// Check socket option. - /// Return Ok(()) if socket option is valid. - pub fn is_valid(&self) -> Result<(), String> { - sys::check_socket_option(self.clone()) - } -} - -fn to_socket_protocol(protocol: IpNextLevelProtocol) -> socket2::Protocol { - match protocol { - IpNextLevelProtocol::Tcp => socket2::Protocol::TCP, - IpNextLevelProtocol::Udp => socket2::Protocol::UDP, - IpNextLevelProtocol::Icmp => socket2::Protocol::ICMPV4, - IpNextLevelProtocol::Icmpv6 => socket2::Protocol::ICMPV6, - _ => socket2::Protocol::TCP, - } -} diff --git a/nex-socket/src/socket/sync_impl.rs b/nex-socket/src/socket/sync_impl.rs deleted file mode 100644 index 90b0a9f..0000000 --- a/nex-socket/src/socket/sync_impl.rs +++ /dev/null @@ -1,389 +0,0 @@ -use crate::socket::to_socket_protocol; -use crate::socket::{IpVersion, SocketOption}; -use socket2::{SockAddr, Socket as SystemSocket}; -use std::io; -use std::mem::MaybeUninit; -use std::net::{Shutdown, SocketAddr, TcpListener, TcpStream, UdpSocket}; -use std::sync::Arc; -use std::time::Duration; - -/// Socket. Provides cross-platform adapter for system socket. -#[derive(Clone, Debug)] -pub struct Socket { - inner: Arc, -} - -impl Socket { - /// Constructs a new Socket. - pub fn new(socket_option: SocketOption) -> io::Result { - let socket: SystemSocket = if let Some(protocol) = socket_option.protocol { - SystemSocket::new( - socket_option.ip_version.to_domain(), - socket_option.socket_type.to_type(), - Some(to_socket_protocol(protocol)), - )? - } else { - SystemSocket::new( - socket_option.ip_version.to_domain(), - socket_option.socket_type.to_type(), - None, - )? - }; - if socket_option.non_blocking { - socket.set_nonblocking(true)?; - } - Ok(Socket { - inner: Arc::new(socket), - }) - } - /// Bind socket to address. - pub fn bind(&self, addr: SocketAddr) -> io::Result<()> { - let addr: SockAddr = SockAddr::from(addr); - self.inner.bind(&addr) - } - /// Send packet. - pub fn send(&self, buf: &[u8]) -> io::Result { - match self.inner.send(buf) { - Ok(n) => Ok(n), - Err(e) => Err(e), - } - } - /// Send packet to target. - pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result { - let target: SockAddr = SockAddr::from(target); - match self.inner.send_to(buf, &target) { - Ok(n) => Ok(n), - Err(e) => Err(e), - } - } - /// Receive packet. - pub fn receive(&self, buf: &mut Vec) -> io::Result { - let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit]) }; - match self.inner.recv(recv_buf) { - Ok(result) => Ok(result), - Err(e) => Err(e), - } - } - /// Receive packet with sender address. - pub fn receive_from(&self, buf: &mut Vec) -> io::Result<(usize, SocketAddr)> { - let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit]) }; - match self.inner.recv_from(recv_buf) { - Ok(result) => { - let (n, addr) = result; - match addr.as_socket() { - Some(addr) => return Ok((n, addr)), - None => { - return Err(io::Error::new( - io::ErrorKind::Other, - "Invalid socket address", - )) - } - } - } - Err(e) => Err(e), - } - } - /// Write data to the socket and send to the target. - /// Return how many bytes were written. - pub fn write(&self, buf: &[u8]) -> io::Result { - match self.inner.send(buf) { - Ok(n) => Ok(n), - Err(e) => Err(e), - } - } - /// Attempts to write an entire buffer into this writer. - pub fn write_all(&self, buf: &[u8]) -> io::Result<()> { - let mut offset = 0; - while offset < buf.len() { - match self.inner.send(&buf[offset..]) { - Ok(n) => offset += n, - Err(e) => return Err(e), - } - } - Ok(()) - } - /// Read data from the socket. - /// Return how many bytes were read. - pub fn read(&self, buf: &mut Vec) -> io::Result { - let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit]) }; - match self.inner.recv(recv_buf) { - Ok(result) => Ok(result), - Err(e) => Err(e), - } - } - /// Read all bytes until EOF in this source, placing them into buf. - pub fn read_to_end(&self, buf: &mut Vec) -> io::Result { - let mut total = 0; - loop { - let mut recv_buf = Vec::new(); - match self.receive(&mut recv_buf) { - Ok(n) => { - if n == 0 { - break; - } - total += n; - buf.extend_from_slice(&recv_buf[..n]); - } - Err(e) => return Err(e), - } - } - Ok(total) - } - /// Read all bytes until EOF in this source, placing them into buf. - /// This ignore io::Error on read_to_end because it is expected when reading response. - /// If no response is received, and io::Error is occurred, return Err. - pub fn read_to_end_timeout(&self, buf: &mut Vec, timeout: Duration) -> io::Result { - // Set timeout - self.inner.set_read_timeout(Some(timeout))?; - let mut total = 0; - loop { - let mut recv_buf = Vec::new(); - match self.receive(&mut recv_buf) { - Ok(n) => { - if n == 0 { - return Ok(total); - } - total += n; - buf.extend_from_slice(&recv_buf[..n]); - } - Err(e) => { - if e.kind() == io::ErrorKind::WouldBlock { - return Ok(total); - } - return Err(e); - } - } - } - } - /// Get TTL or Hop Limit. - pub fn ttl(&self, ip_version: IpVersion) -> io::Result { - match ip_version { - IpVersion::V4 => self.inner.ttl(), - IpVersion::V6 => self.inner.unicast_hops_v6(), - } - } - /// Set TTL or Hop Limit. - pub fn set_ttl(&self, ttl: u32, ip_version: IpVersion) -> io::Result<()> { - match ip_version { - IpVersion::V4 => self.inner.set_ttl(ttl), - IpVersion::V6 => self.inner.set_unicast_hops_v6(ttl), - } - } - /// Get the value of the IP_TOS option for this socket. - pub fn tos(&self) -> io::Result { - self.inner.tos() - } - /// Set the value of the IP_TOS option for this socket. - pub fn set_tos(&self, tos: u32) -> io::Result<()> { - self.inner.set_tos(tos) - } - /// Get the value of the IP_RECVTOS option for this socket. - pub fn receive_tos(&self) -> io::Result { - self.inner.recv_tos() - } - /// Set the value of the IP_RECVTOS option for this socket. - pub fn set_receive_tos(&self, receive_tos: bool) -> io::Result<()> { - self.inner.set_recv_tos(receive_tos) - } - /// Initiate TCP connection. - pub fn connect(&self, addr: &SocketAddr) -> io::Result<()> { - let addr: SockAddr = SockAddr::from(*addr); - self.inner.connect(&addr) - } - /// Initiate a connection on this socket to the specified address, only only waiting for a certain period of time for the connection to be established. - /// The non-blocking state of the socket is overridden by this function. - pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> { - let addr: SockAddr = SockAddr::from(*addr); - self.inner.connect_timeout(&addr, timeout) - } - /// Listen TCP connection. - pub fn listen(&self, backlog: i32) -> io::Result<()> { - self.inner.listen(backlog) - } - /// Accept TCP connection. - pub fn accept(&self) -> io::Result<(Socket, SocketAddr)> { - match self.inner.accept() { - Ok((socket, addr)) => Ok(( - Socket { - inner: Arc::new(socket), - }, - addr.as_socket().unwrap(), - )), - Err(e) => Err(e), - } - } - /// Get local address. - pub fn local_addr(&self) -> io::Result { - match self.inner.local_addr() { - Ok(addr) => Ok(addr.as_socket().unwrap()), - Err(e) => Err(e), - } - } - /// Get peer address. - pub fn peer_addr(&self) -> io::Result { - match self.inner.peer_addr() { - Ok(addr) => Ok(addr.as_socket().unwrap()), - Err(e) => Err(e), - } - } - /// Get type of the socket. - pub fn socket_type(&self) -> io::Result { - match self.inner.r#type() { - Ok(socktype) => Ok(crate::socket::SocketType::from_type(socktype)), - Err(e) => Err(e), - } - } - /// Create a new socket with the same configuration and bound to the same address. - pub fn try_clone(&self) -> io::Result { - match self.inner.try_clone() { - Ok(socket) => Ok(Socket { - inner: Arc::new(socket), - }), - Err(e) => Err(e), - } - } - /// Returns true if this socket is set to nonblocking mode, false otherwise. - #[cfg(not(target_os = "windows"))] - pub fn is_nonblocking(&self) -> io::Result { - self.inner.nonblocking() - } - /// Set non-blocking mode. - pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { - self.inner.set_nonblocking(nonblocking) - } - /// Shutdown TCP connection. - pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { - self.inner.shutdown(how) - } - /// Get the value of the SO_BROADCAST option for this socket. - pub fn is_broadcast(&self) -> io::Result { - self.inner.broadcast() - } - /// Set the value of the `SO_BROADCAST` option for this socket. - /// - /// When enabled, this socket is allowed to send packets to a broadcast address. - pub fn set_broadcast(&self, broadcast: bool) -> io::Result<()> { - self.inner.set_broadcast(broadcast) - } - /// Get the value of the `SO_ERROR` option on this socket. - pub fn get_error(&self) -> io::Result> { - self.inner.take_error() - } - /// Get the value of the `SO_KEEPALIVE` option on this socket. - pub fn keepalive(&self) -> io::Result { - self.inner.keepalive() - } - /// Set value for the `SO_KEEPALIVE` option on this socket. - /// - /// Enable sending of keep-alive messages on connection-oriented sockets. - pub fn set_keepalive(&self, keepalive: bool) -> io::Result<()> { - self.inner.set_keepalive(keepalive) - } - /// Get the value of the SO_LINGER option on this socket. - pub fn linger(&self) -> io::Result> { - self.inner.linger() - } - /// Set value for the SO_LINGER option on this socket. - pub fn set_linger(&self, dur: Option) -> io::Result<()> { - self.inner.set_linger(dur) - } - /// Get the value of the `SO_RCVBUF` option on this socket. - pub fn receive_buffer_size(&self) -> io::Result { - self.inner.recv_buffer_size() - } - /// Set value for the `SO_RCVBUF` option on this socket. - /// - /// Changes the size of the operating system's receive buffer associated with the socket. - pub fn set_receive_buffer_size(&self, size: usize) -> io::Result<()> { - self.inner.set_recv_buffer_size(size) - } - /// Get value for the SO_RCVTIMEO option on this socket. - pub fn receive_timeout(&self) -> io::Result> { - self.inner.read_timeout() - } - /// Set value for the `SO_RCVTIMEO` option on this socket. - pub fn set_receive_timeout(&self, duration: Option) -> io::Result<()> { - self.inner.set_read_timeout(duration) - } - /// Get value for the `SO_REUSEADDR` option on this socket. - pub fn reuse_address(&self) -> io::Result { - self.inner.reuse_address() - } - /// Set value for the `SO_REUSEADDR` option on this socket. - /// - /// This indicates that futher calls to `bind` may allow reuse of local addresses. - pub fn set_reuse_address(&self, reuse: bool) -> io::Result<()> { - self.inner.set_reuse_address(reuse) - } - /// Get value for the `SO_SNDBUF` option on this socket. - pub fn send_buffer_size(&self) -> io::Result { - self.inner.send_buffer_size() - } - /// Set value for the `SO_SNDBUF` option on this socket. - /// - /// Changes the size of the operating system's send buffer associated with the socket. - pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> { - self.inner.set_send_buffer_size(size) - } - /// Get value for the `SO_SNDTIMEO` option on this socket. - pub fn send_timeout(&self) -> io::Result> { - self.inner.write_timeout() - } - /// Set value for the `SO_SNDTIMEO` option on this socket. - /// - /// If `timeout` is `None`, then `write` and `send` calls will block indefinitely. - pub fn set_send_timeout(&self, duration: Option) -> io::Result<()> { - self.inner.set_write_timeout(duration) - } - /// Get the value of the IP_HDRINCL option on this socket. - pub fn is_ip_header_included(&self) -> io::Result { - self.inner.header_included_v4() - } - /// Set the value of the `IP_HDRINCL` option on this socket. - pub fn set_ip_header_included(&self, include: bool) -> io::Result<()> { - self.inner.set_header_included_v4(include) - } - /// Get the value of the TCP_NODELAY option on this socket. - pub fn nodelay(&self) -> io::Result { - self.inner.nodelay() - } - /// Set the value of the `TCP_NODELAY` option on this socket. - /// - /// If set, segments are always sent as soon as possible, even if there is only a small amount of data. - pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> { - self.inner.set_nodelay(nodelay) - } - /// Get TCP Stream - /// This function will consume the socket and return a new std::net::TcpStream. - pub fn into_tcp_stream(self) -> io::Result { - match Arc::try_unwrap(self.inner) { - Ok(socket) => Ok(socket.into()), - Err(_) => Err(io::Error::new( - io::ErrorKind::Other, - "Failed to unwrap socket", - )), - } - } - /// Get TCP Listener - /// This function will consume the socket and return a new std::net::TcpListener. - pub fn into_tcp_listener(self) -> io::Result { - match Arc::try_unwrap(self.inner) { - Ok(socket) => Ok(socket.into()), - Err(_) => Err(io::Error::new( - io::ErrorKind::Other, - "Failed to unwrap socket", - )), - } - } - /// Get UDP Socket - /// This function will consume the socket and return a new std::net::UdpSocket. - pub fn into_udp_socket(self) -> io::Result { - match Arc::try_unwrap(self.inner) { - Ok(socket) => Ok(socket.into()), - Err(_) => Err(io::Error::new( - io::ErrorKind::Other, - "Failed to unwrap socket", - )), - } - } -} diff --git a/nex-socket/src/sys/mod.rs b/nex-socket/src/sys/mod.rs deleted file mode 100644 index 39fe3c6..0000000 --- a/nex-socket/src/sys/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -#[cfg(not(target_os = "windows"))] -mod unix; -#[cfg(not(target_os = "windows"))] -pub use unix::*; - -#[cfg(target_os = "windows")] -mod windows; -#[cfg(target_os = "windows")] -pub use windows::*; diff --git a/nex-socket/src/sys/unix.rs b/nex-socket/src/sys/unix.rs deleted file mode 100644 index 84ac1b3..0000000 --- a/nex-socket/src/sys/unix.rs +++ /dev/null @@ -1,108 +0,0 @@ -use std::{io, mem::MaybeUninit, net::SocketAddr, time::Duration}; - -use nex_packet::ip::IpNextLevelProtocol; -use socket2::{Domain, Protocol, Socket as SystemSocket, Type}; - -use crate::socket::{IpVersion, SocketOption, SocketType}; - -pub(crate) fn check_socket_option(socket_option: SocketOption) -> Result<(), String> { - match socket_option.ip_version { - IpVersion::V4 => match socket_option.socket_type { - SocketType::Raw => match socket_option.protocol { - Some(IpNextLevelProtocol::Icmp) => Ok(()), - Some(IpNextLevelProtocol::Tcp) => Ok(()), - Some(IpNextLevelProtocol::Udp) => Ok(()), - _ => Err(String::from("Invalid protocol")), - }, - SocketType::Datagram => match socket_option.protocol { - Some(IpNextLevelProtocol::Icmp) => Ok(()), - Some(IpNextLevelProtocol::Udp) => Ok(()), - _ => Err(String::from("Invalid protocol")), - }, - SocketType::Stream => match socket_option.protocol { - Some(IpNextLevelProtocol::Tcp) => Ok(()), - _ => Err(String::from("Invalid protocol")), - }, - }, - IpVersion::V6 => match socket_option.socket_type { - SocketType::Raw => match socket_option.protocol { - Some(IpNextLevelProtocol::Icmpv6) => Ok(()), - Some(IpNextLevelProtocol::Tcp) => Ok(()), - Some(IpNextLevelProtocol::Udp) => Ok(()), - _ => Err(String::from("Invalid protocol")), - }, - SocketType::Datagram => match socket_option.protocol { - Some(IpNextLevelProtocol::Icmpv6) => Ok(()), - Some(IpNextLevelProtocol::Udp) => Ok(()), - _ => Err(String::from("Invalid protocol")), - }, - SocketType::Stream => match socket_option.protocol { - Some(IpNextLevelProtocol::Tcp) => Ok(()), - _ => Err(String::from("Invalid protocol")), - }, - }, - } -} - -/// Receive all IPv4 or IPv6 packets passing through a network interface. -pub struct PacketReceiver { - inner: SystemSocket, -} - -impl PacketReceiver { - /// Constructs a new PacketReceiver. - pub fn new( - _socket_addr: SocketAddr, - ip_version: IpVersion, - protocol: Option, - timeout: Option, - ) -> io::Result { - let socket = match ip_version { - IpVersion::V4 => match protocol { - Some(IpNextLevelProtocol::Icmp) => { - SystemSocket::new(Domain::IPV4, Type::RAW, Some(Protocol::ICMPV4))? - } - Some(IpNextLevelProtocol::Tcp) => { - SystemSocket::new(Domain::IPV4, Type::RAW, Some(Protocol::TCP))? - } - Some(IpNextLevelProtocol::Udp) => { - SystemSocket::new(Domain::IPV4, Type::RAW, Some(Protocol::UDP))? - } - _ => SystemSocket::new(Domain::IPV4, Type::RAW, None)?, - }, - IpVersion::V6 => match protocol { - Some(IpNextLevelProtocol::Icmpv6) => { - SystemSocket::new(Domain::IPV6, Type::RAW, Some(Protocol::ICMPV6))? - } - Some(IpNextLevelProtocol::Tcp) => { - SystemSocket::new(Domain::IPV6, Type::RAW, Some(Protocol::TCP))? - } - Some(IpNextLevelProtocol::Udp) => { - SystemSocket::new(Domain::IPV6, Type::RAW, Some(Protocol::UDP))? - } - _ => SystemSocket::new(Domain::IPV6, Type::RAW, None)?, - }, - }; - if let Some(timeout) = timeout { - socket.set_read_timeout(Some(timeout))?; - } - //socket.bind(&socket_addr.into())?; - Ok(PacketReceiver { inner: socket }) - } - /// Receive packet without source address. - pub fn receive_from(&self, buf: &mut Vec) -> io::Result<(usize, SocketAddr)> { - let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit]) }; - match self.inner.recv_from(recv_buf) { - Ok((packet_len, addr)) => match addr.as_socket() { - Some(socket_addr) => { - return Ok((packet_len, socket_addr)); - } - None => Err(io::Error::new( - io::ErrorKind::Other, - "Invalid socket address", - )), - }, - Err(e) => Err(e), - } - } -} diff --git a/nex-socket/src/sys/windows.rs b/nex-socket/src/sys/windows.rs deleted file mode 100644 index aee2e07..0000000 --- a/nex-socket/src/sys/windows.rs +++ /dev/null @@ -1,285 +0,0 @@ -use socket2::SockAddr; -use std::cmp::min; -use std::io; -use std::mem::{self, MaybeUninit}; -use std::net::{SocketAddr, UdpSocket}; -use std::ptr; -use std::sync::Once; -use std::time::Duration; - -#[allow(non_camel_case_types)] -type c_int = i32; - -#[allow(non_camel_case_types)] -type c_long = i32; - -type DWORD = u32; -use windows_sys::Win32::Networking::WinSock::SIO_RCVALL; -use windows_sys::Win32::System::Threading::INFINITE; - -#[allow(non_camel_case_types)] -type u_long = u32; - -use windows_sys::Win32::Networking::WinSock::{self as sock, SOCKET, WSA_FLAG_NO_HANDLE_INHERIT}; -use windows_sys::Win32::Networking::WinSock::{ - AF_INET, AF_INET6, IPPROTO_ICMP, IPPROTO_ICMPV6, IPPROTO_IP, IPPROTO_IPV6, IPPROTO_TCP, - IPPROTO_UDP, -}; - -pub(crate) const NO_INHERIT: c_int = 1 << (c_int::BITS - 1); -pub(crate) const MAX_BUF_LEN: usize = ::max_value() as usize; - -use crate::socket::{IpVersion, SocketOption, SocketType}; -use nex_packet::ip::IpNextLevelProtocol; - -pub fn check_socket_option(socket_option: SocketOption) -> Result<(), String> { - match socket_option.ip_version { - IpVersion::V4 => { - match socket_option.socket_type { - SocketType::Raw => { - match socket_option.protocol { - Some(IpNextLevelProtocol::Icmp) => Ok(()), - Some(IpNextLevelProtocol::Tcp) => Err(String::from("TCP is not supported on IPv4 raw socket on Windows(Due to Winsock2 limitation))")), - Some(IpNextLevelProtocol::Udp) => Ok(()), - _ => Err(String::from("Invalid protocol")), - } - } - SocketType::Datagram => { - match socket_option.protocol { - Some(IpNextLevelProtocol::Icmp) => Ok(()), - Some(IpNextLevelProtocol::Udp) => Ok(()), - _ => Err(String::from("Invalid protocol")), - } - } - SocketType::Stream => { - match socket_option.protocol { - Some(IpNextLevelProtocol::Tcp) => Ok(()), - _ => Err(String::from("Invalid protocol")), - } - } - } - } - IpVersion::V6 => { - match socket_option.socket_type { - SocketType::Raw => { - match socket_option.protocol { - Some(IpNextLevelProtocol::Icmpv6) => Ok(()), - Some(IpNextLevelProtocol::Tcp) => Err(String::from("TCP is not supported on IPv6 raw socket on Windows(Due to Winsock2 limitation))")), - Some(IpNextLevelProtocol::Udp) => Ok(()), - _ => Err(String::from("Invalid protocol")), - } - } - SocketType::Datagram => { - match socket_option.protocol { - Some(IpNextLevelProtocol::Icmpv6) => Ok(()), - Some(IpNextLevelProtocol::Udp) => Ok(()), - _ => Err(String::from("Invalid protocol")), - } - } - SocketType::Stream => { - match socket_option.protocol { - Some(IpNextLevelProtocol::Tcp) => Ok(()), - _ => Err(String::from("Invalid protocol")), - } - } - } - } - } -} - -macro_rules! syscall { - ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{ - #[allow(unused_unsafe)] - let res = unsafe { windows_sys::Win32::Networking::WinSock::$fn($($arg, )*) }; - if $err_test(&res, &$err_value) { - Err(io::Error::last_os_error()) - } else { - Ok(res) - } - }}; -} - -pub(crate) fn init_socket() { - static INIT: Once = Once::new(); - INIT.call_once(|| { - let _ = UdpSocket::bind("127.0.0.1:34254"); - }); -} - -pub(crate) fn ioctlsocket(socket: SOCKET, cmd: c_long, payload: &mut u_long) -> io::Result<()> { - syscall!( - ioctlsocket(socket, cmd, payload), - PartialEq::eq, - sock::SOCKET_ERROR - ) - .map(|_| ()) -} - -pub(crate) fn create_socket(family: c_int, mut ty: c_int, protocol: c_int) -> io::Result { - init_socket(); - let flags = if ty & NO_INHERIT != 0 { - ty = ty & !NO_INHERIT; - WSA_FLAG_NO_HANDLE_INHERIT - } else { - 0 - }; - syscall!( - WSASocketW( - family, - ty, - protocol, - ptr::null_mut(), - 0, - sock::WSA_FLAG_OVERLAPPED | flags, - ), - PartialEq::eq, - sock::INVALID_SOCKET - ) -} - -pub(crate) fn bind(socket: SOCKET, addr: &SockAddr) -> io::Result<()> { - // Convert the SockAddr reference to a raw pointer, which is required by the Windows API, - // and pass it to the `bind` function to associate the socket with the provided address. - // This is necessary for compatibility with the `windows-sys` crate. - let sockaddr = addr.as_ptr() as *const _; - syscall!(bind(socket, sockaddr, addr.len()), PartialEq::ne, 0).map(|_| ()) -} - -#[allow(dead_code)] -pub(crate) fn set_nonblocking(socket: SOCKET, nonblocking: bool) -> io::Result<()> { - let mut nonblocking = nonblocking as u_long; - ioctlsocket(socket, sock::FIONBIO, &mut nonblocking) -} - -pub(crate) fn set_promiscuous(socket: SOCKET, promiscuous: bool) -> io::Result<()> { - let mut promiscuous = promiscuous as u_long; - ioctlsocket(socket, SIO_RCVALL as i32, &mut promiscuous) -} - -pub(crate) unsafe fn setsockopt( - socket: SOCKET, - level: c_int, - optname: i32, - optval: T, -) -> io::Result<()> { - syscall!( - setsockopt( - socket, - level as i32, - optname, - (&optval as *const T).cast(), - mem::size_of::() as c_int, - ), - PartialEq::eq, - sock::SOCKET_ERROR - ) - .map(|_| ()) -} - -pub(crate) fn into_ms(duration: Option) -> DWORD { - duration - .map(|duration| min(duration.as_millis(), INFINITE as u128) as DWORD) - .unwrap_or(0) -} - -pub(crate) fn set_timeout_opt( - fd: SOCKET, - level: c_int, - optname: c_int, - duration: Option, -) -> io::Result<()> { - let duration = into_ms(duration); - unsafe { setsockopt(fd, level, optname, duration) } -} - -pub(crate) fn recv_from( - socket: SOCKET, - buf: &mut [MaybeUninit], - flags: c_int, -) -> io::Result<(usize, SockAddr)> { - unsafe { - SockAddr::try_init(|storage, addrlen| { - let res = syscall!( - recvfrom( - socket, - buf.as_mut_ptr().cast(), - min(buf.len(), MAX_BUF_LEN) as c_int, - flags, - storage.cast(), - addrlen, - ), - PartialEq::eq, - sock::SOCKET_ERROR - ); - match res { - Ok(n) => Ok(n as usize), - Err(ref err) if err.raw_os_error() == Some(sock::WSAESHUTDOWN as i32) => Ok(0), - Err(err) => Err(err), - } - }) - } -} - -/// Receive all IPv4 or IPv6 packets passing through a network interface. -pub struct PacketReceiver { - inner: SOCKET, -} - -impl PacketReceiver { - pub fn new( - socket_addr: SocketAddr, - ip_version: IpVersion, - protocol: Option, - timeout: Option, - ) -> io::Result { - let socket = match ip_version { - IpVersion::V4 => match protocol { - Some(IpNextLevelProtocol::Icmp) => { - create_socket(AF_INET as i32, sock::SOCK_RAW, IPPROTO_ICMP)? - } - Some(IpNextLevelProtocol::Tcp) => { - create_socket(AF_INET as i32, sock::SOCK_RAW, IPPROTO_TCP)? - } - Some(IpNextLevelProtocol::Udp) => { - create_socket(AF_INET as i32, sock::SOCK_RAW, IPPROTO_UDP)? - } - _ => create_socket(AF_INET as i32, sock::SOCK_RAW, IPPROTO_IP)?, - }, - IpVersion::V6 => match protocol { - Some(IpNextLevelProtocol::Icmpv6) => { - create_socket(AF_INET6 as i32, sock::SOCK_RAW, IPPROTO_ICMPV6)? - } - Some(IpNextLevelProtocol::Tcp) => { - create_socket(AF_INET6 as i32, sock::SOCK_RAW, IPPROTO_TCP)? - } - Some(IpNextLevelProtocol::Udp) => { - create_socket(AF_INET6 as i32, sock::SOCK_RAW, IPPROTO_UDP)? - } - _ => create_socket(AF_INET6 as i32, sock::SOCK_RAW, IPPROTO_IPV6)?, - }, - }; - let sock_addr = SockAddr::from(socket_addr); - bind(socket, &sock_addr)?; - set_promiscuous(socket, true)?; - set_timeout_opt(socket, sock::SOL_SOCKET, sock::SO_RCVTIMEO, timeout)?; - Ok(PacketReceiver { inner: socket }) - } - pub fn bind(&self, addr: &SockAddr) -> io::Result<()> { - bind(self.inner, addr) - } - pub fn receive_from(&self, buf: &mut Vec) -> io::Result<(usize, SocketAddr)> { - let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit]) }; - match recv_from(self.inner, recv_buf, 0) { - Ok((n, addr)) => match addr.as_socket() { - Some(socket_addr) => { - return Ok((n, socket_addr)); - } - None => Err(io::Error::new( - io::ErrorKind::Other, - "Invalid socket address", - )), - }, - Err(e) => Err(e), - } - } -} diff --git a/nex-socket/src/tcp/async_impl.rs b/nex-socket/src/tcp/async_impl.rs new file mode 100644 index 0000000..164d91c --- /dev/null +++ b/nex-socket/src/tcp/async_impl.rs @@ -0,0 +1,191 @@ +use crate::tcp::TcpConfig; +use socket2::{Domain, Protocol, Socket, Type as SockType}; +use std::io; +use std::net::{SocketAddr, TcpStream as StdTcpStream, TcpListener as StdTcpListener}; +use std::time::Duration; +use tokio::net::{TcpListener, TcpStream}; + +/// Asynchronous TCP socket built on top of Tokio. +#[derive(Debug)] +pub struct AsyncTcpSocket { + socket: Socket, +} + +impl AsyncTcpSocket { + /// Create a socket from the given configuration without connecting. + pub fn from_config(config: &TcpConfig) -> io::Result { + let socket = Socket::new(config.domain, config.sock_type, Some(Protocol::TCP))?; + + if let Some(flag) = config.reuseaddr { + socket.set_reuse_address(flag)?; + } + if let Some(flag) = config.nodelay { + socket.set_nodelay(flag)?; + } + if let Some(ttl) = config.ttl { + socket.set_ttl(ttl)?; + } + + #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))] + if let Some(iface) = &config.bind_device { + socket.bind_device(Some(iface.as_bytes()))?; + } + + if let Some(addr) = config.bind_addr { + socket.bind(&addr.into())?; + } + + socket.set_nonblocking(true)?; + + Ok(Self { socket }) + } + + /// Create a socket of arbitrary type (STREAM or RAW). + pub fn new(domain: Domain, sock_type: SockType) -> io::Result { + let socket = Socket::new(domain, sock_type, Some(Protocol::TCP))?; + socket.set_nonblocking(true)?; + Ok(Self { socket }) + } + + /// Convenience constructor for an IPv4 STREAM socket. + pub fn v4_stream() -> io::Result { + Self::new(Domain::IPV4, SockType::STREAM) + } + + /// Convenience constructor for an IPv6 STREAM socket. + pub fn v6_stream() -> io::Result { + Self::new(Domain::IPV6, SockType::STREAM) + } + + /// IPv4 RAW TCP. Requires administrator privileges. + pub fn raw_v4() -> io::Result { + Self::new(Domain::IPV4, SockType::RAW) + } + + /// IPv6 RAW TCP. Requires administrator privileges. + pub fn raw_v6() -> io::Result { + Self::new(Domain::IPV6, SockType::RAW) + } + + /// Connect to the target asynchronously. + pub async fn connect(self, target: SocketAddr) -> io::Result { + // call connect + match self.socket.connect(&target.into()) { + Ok(_) => { + // connection completed immediately (rare case) + let std_stream: StdTcpStream = self.socket.into(); + return TcpStream::from_std(std_stream); + } + Err(e) if e.kind() == io::ErrorKind::WouldBlock || e.raw_os_error() == Some(libc::EINPROGRESS) => { + // wait until writable + let std_stream: StdTcpStream = self.socket.into(); + let stream = TcpStream::from_std(std_stream)?; + stream.writable().await?; + + // check the final connection state with SO_ERROR + if let Some(err) = stream.take_error()? { + return Err(err); + } + + return Ok(stream); + } + Err(e) => { + println!("Failed to connect: {}", e); + return Err(e); + } + } + } + + /// Connect with a timeout to the target address. + pub async fn connect_timeout(self, target: SocketAddr, timeout: Duration) -> io::Result { + match tokio::time::timeout(timeout, self.connect(target)).await { + Ok(result) => result, + Err(_) => Err(io::Error::new(io::ErrorKind::TimedOut, "connection timed out")), + } + } + + /// Start listening for incoming connections. + pub fn listen(self, backlog: i32) -> io::Result { + self.socket.listen(backlog)?; + + let std_listener: StdTcpListener = self.socket.into(); + TcpListener::from_std(std_listener) + } + + /// Send a raw TCP packet. Requires `SockType::RAW`. + pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result { + self.socket.send_to(buf, &target.into()) + } + + /// Receive a raw TCP packet. Requires `SockType::RAW`. + pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + // Safety: `MaybeUninit` has the same memory layout as `u8`. + let buf_maybe = unsafe { + std::slice::from_raw_parts_mut( + buf.as_mut_ptr() as *mut std::mem::MaybeUninit, + buf.len(), + ) + }; + + let (n, addr) = self.socket.recv_from(buf_maybe)?; + let addr = addr.as_socket().ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "invalid address format") + })?; + + Ok((n, addr)) + } + + // --- option helpers --- + + pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> { + self.socket.set_reuse_address(on) + } + + pub fn set_nodelay(&self, on: bool) -> io::Result<()> { + self.socket.set_nodelay(on) + } + + pub fn set_linger(&self, dur: Option) -> io::Result<()> { + self.socket.set_linger(dur) + } + + pub fn set_ttl(&self, ttl: u32) -> io::Result<()> { + self.socket.set_ttl(ttl) + } + + pub fn set_bind_device(&self, iface: &str) -> io::Result<()> { + #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))] + return self.socket.bind_device(Some(iface.as_bytes())); + + #[cfg(not(any(target_os = "linux", target_os = "android", target_os = "fuchsia")))] + { + let _ = iface; + Err(io::Error::new(io::ErrorKind::Unsupported, "bind_device not supported on this OS")) + } + } + + /// Retrieve the local address of the socket. + pub fn local_addr(&self) -> io::Result { + self.socket.local_addr()?.as_socket().ok_or_else(|| { + io::Error::new(io::ErrorKind::Other, "Failed to get socket address") + }) + } + + /// Convert the internal socket into a Tokio `TcpStream`. + pub fn into_tokio_stream(self) -> io::Result { + let std_stream: StdTcpStream = self.socket.into(); + TcpStream::from_std(std_stream) + } + + #[cfg(unix)] + pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd { + use std::os::fd::AsRawFd; + self.socket.as_raw_fd() + } + + #[cfg(windows)] + pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket { + use std::os::windows::io::AsRawSocket; + self.socket.as_raw_socket() + } +} diff --git a/nex-socket/src/tcp/config.rs b/nex-socket/src/tcp/config.rs new file mode 100644 index 0000000..8946994 --- /dev/null +++ b/nex-socket/src/tcp/config.rs @@ -0,0 +1,122 @@ +use socket2::{Domain, Type as SockType}; +use std::net::SocketAddr; +use std::time::Duration; + +/// Configuration options for a TCP socket. +#[derive(Debug, Clone)] +pub struct TcpConfig { + pub domain: Domain, + pub sock_type: SockType, + pub bind_addr: Option, + pub nonblocking: bool, + pub reuseaddr: Option, + pub nodelay: Option, + pub linger: Option, + pub ttl: Option, + pub bind_device: Option, +} + +impl TcpConfig { + /// Create a STREAM socket for IPv4. + pub fn v4_stream() -> Self { + Self { + domain: Domain::IPV4, + sock_type: SockType::STREAM, + bind_addr: None, + nonblocking: false, + reuseaddr: None, + nodelay: None, + linger: None, + ttl: None, + bind_device: None, + } + } + + /// Create a RAW socket. Requires administrator privileges. + pub fn raw_v4() -> Self { + Self { + domain: Domain::IPV4, + sock_type: SockType::RAW, + ..Self::v4_stream() + } + } + + /// Create a STREAM socket for IPv6. + pub fn v6_stream() -> Self { + Self { + domain: Domain::IPV6, + sock_type: SockType::STREAM, + ..Self::v4_stream() + } + } + + /// Create a RAW socket for IPv6. Requires administrator privileges. + pub fn raw_v6() -> Self { + Self { + domain: Domain::IPV6, + sock_type: SockType::RAW, + ..Self::v4_stream() + } + } + + // --- chainable modifiers --- + + pub fn with_bind(mut self, addr: SocketAddr) -> Self { + self.bind_addr = Some(addr); + self + } + + pub fn with_nonblocking(mut self, flag: bool) -> Self { + self.nonblocking = flag; + self + } + + pub fn with_reuseaddr(mut self, flag: bool) -> Self { + self.reuseaddr = Some(flag); + self + } + + pub fn with_nodelay(mut self, flag: bool) -> Self { + self.nodelay = Some(flag); + self + } + + pub fn with_linger(mut self, dur: Duration) -> Self { + self.linger = Some(dur); + self + } + + pub fn with_ttl(mut self, ttl: u32) -> Self { + self.ttl = Some(ttl); + self + } + + pub fn with_bind_device(mut self, iface: impl Into) -> Self { + self.bind_device = Some(iface.into()); + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn tcp_config_builders() { + let addr: SocketAddr = "127.0.0.1:80".parse().unwrap(); + let cfg = TcpConfig::v4_stream() + .with_bind(addr) + .with_nonblocking(true) + .with_reuseaddr(true) + .with_nodelay(true) + .with_ttl(10); + + assert_eq!(cfg.domain, Domain::IPV4); + assert_eq!(cfg.sock_type, SockType::STREAM); + assert_eq!(cfg.bind_addr, Some(addr)); + assert!(cfg.nonblocking); + assert_eq!(cfg.reuseaddr, Some(true)); + assert_eq!(cfg.nodelay, Some(true)); + assert_eq!(cfg.ttl, Some(10)); + } +} diff --git a/nex-socket/src/tcp/mod.rs b/nex-socket/src/tcp/mod.rs new file mode 100644 index 0000000..043548c --- /dev/null +++ b/nex-socket/src/tcp/mod.rs @@ -0,0 +1,7 @@ +mod config; +mod async_impl; +mod sync_impl; + +pub use config::*; +pub use sync_impl::*; +pub use async_impl::*; diff --git a/nex-socket/src/tcp/sync_impl.rs b/nex-socket/src/tcp/sync_impl.rs new file mode 100644 index 0000000..308bb82 --- /dev/null +++ b/nex-socket/src/tcp/sync_impl.rs @@ -0,0 +1,280 @@ +use socket2::{Domain, Protocol, Socket, Type as SockType}; +use std::io; +use std::net::{SocketAddr, TcpStream, TcpListener}; +use std::time::Duration; + +use crate::tcp::TcpConfig; + +#[cfg(unix)] +use std::os::fd::AsRawFd; + +#[cfg(unix)] +use nix::poll::{poll, PollFd, PollFlags}; + +/// Low level synchronous TCP socket. +#[derive(Debug)] +pub struct TcpSocket { + socket: Socket, +} + +impl TcpSocket { + /// Build a socket according to `TcpSocketConfig`. + pub fn from_config(config: &TcpConfig) -> io::Result { + let socket = Socket::new(config.domain, config.sock_type, Some(Protocol::TCP))?; + + // Apply all configuration options + if let Some(flag) = config.reuseaddr { + socket.set_reuse_address(flag)?; + } + if let Some(flag) = config.nodelay { + socket.set_nodelay(flag)?; + } + if let Some(dur) = config.linger { + socket.set_linger(Some(dur))?; + } + if let Some(ttl) = config.ttl { + socket.set_ttl(ttl)?; + } + + #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))] + if let Some(iface) = &config.bind_device { + socket.bind_device(Some(iface.as_bytes()))?; + } + + // Bind to the specified address if provided + if let Some(addr) = config.bind_addr { + socket.bind(&addr.into())?; + } + + // Set non blocking mode + socket.set_nonblocking(config.nonblocking)?; + + Ok(Self { socket }) + } + + /// Create a socket of arbitrary type (STREAM or RAW). + pub fn new(domain: Domain, sock_type: SockType) -> io::Result { + let socket = Socket::new(domain, sock_type, Some(Protocol::TCP))?; + socket.set_nonblocking(false)?; + Ok(Self { socket }) + } + + /// Convenience constructor for an IPv4 STREAM socket. + pub fn v4_stream() -> io::Result { + Self::new(Domain::IPV4, SockType::STREAM) + } + + /// Convenience constructor for an IPv6 STREAM socket. + pub fn v6_stream() -> io::Result { + Self::new(Domain::IPV6, SockType::STREAM) + } + + /// IPv4 RAW TCP. Requires administrator privileges. + pub fn raw_v4() -> io::Result { + Self::new(Domain::IPV4, SockType::RAW) + } + + /// IPv6 RAW TCP. Requires administrator privileges. + pub fn raw_v6() -> io::Result { + Self::new(Domain::IPV6, SockType::RAW) + } + + // --- socket operations --- + + pub fn bind(&self, addr: SocketAddr) -> io::Result<()> { + self.socket.bind(&addr.into()) + } + + pub fn connect(&self, addr: SocketAddr) -> io::Result<()> { + self.socket.connect(&addr.into()) + } + + #[cfg(unix)] + pub fn connect_timeout(&self, target: SocketAddr, timeout: Duration) -> io::Result { + let raw_fd = self.socket.as_raw_fd(); + self.socket.set_nonblocking(true)?; + + // Try to connect first + match self.socket.connect(&target.into()) { + Ok(_) => { /* succeeded immediately */ } + Err(err) if err.kind() == io::ErrorKind::WouldBlock || err.raw_os_error() == Some(libc::EINPROGRESS) => { + // Continue waiting + } + Err(e) => return Err(e), + } + + // Wait for the connection using poll + let timeout_ms = timeout.as_millis() as i32; + use std::os::unix::io::BorrowedFd; + // Safety: raw_fd is valid for the lifetime of this scope + let mut fds = [PollFd::new(unsafe { BorrowedFd::borrow_raw(raw_fd) }, PollFlags::POLLOUT)]; + let n = poll(&mut fds, Some(timeout_ms as u16))?; + + if n == 0 { + return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out")); + } + + // Check the result with `SO_ERROR` + let err: i32 = self.socket.take_error()?.map(|e| e.raw_os_error().unwrap_or(0)).unwrap_or(0); + if err != 0 { + return Err(io::Error::from_raw_os_error(err)); + } + + self.socket.set_nonblocking(false)?; + + match self.socket.try_clone() { + Ok(cloned_socket) => { + // Convert the socket into a `std::net::TcpStream` + let std_stream: TcpStream = cloned_socket.into(); + Ok(std_stream) + } + Err(e) => Err(e), + } + } + + #[cfg(windows)] + pub fn connect_timeout(&self, target: SocketAddr, timeout: Duration) -> io::Result { + use std::os::windows::io::AsRawSocket; + use windows_sys::Win32::Networking::WinSock::{ + WSAPOLLFD, WSAPoll, POLLWRNORM, SOCKET_ERROR, SO_ERROR, SOL_SOCKET, + getsockopt, SOCKET, + }; + use std::mem::size_of; + + let sock = self.socket.as_raw_socket() as SOCKET; + self.socket.set_nonblocking(true)?; + + // Start connect + match self.socket.connect(&target.into()) { + Ok(_) => { /* connection succeeded immediately */ } + Err(e) if e.kind() == io::ErrorKind::WouldBlock || e.raw_os_error() == Some(10035) /* WSAEWOULDBLOCK */ => {} + Err(e) => return Err(e), + } + + // Wait using WSAPoll until writable + let mut fds = [WSAPOLLFD { + fd: sock, + events: POLLWRNORM, + revents: 0, + }]; + + let timeout_ms = timeout.as_millis().clamp(0, i32::MAX as u128) as i32; + let result = unsafe { WSAPoll(fds.as_mut_ptr(), fds.len() as u32, timeout_ms) }; + if result == SOCKET_ERROR { + return Err(io::Error::last_os_error()); + } else if result == 0 { + return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out")); + } + + // Check for errors via `SO_ERROR` + let mut so_error: i32 = 0; + let mut optlen = size_of::() as i32; + let ret = unsafe { + getsockopt( + sock, + SOL_SOCKET as i32, + SO_ERROR as i32, + &mut so_error as *mut _ as *mut _, + &mut optlen, + ) + }; + + if ret == SOCKET_ERROR || so_error != 0 { + return Err(io::Error::from_raw_os_error(so_error)); + } + + self.socket.set_nonblocking(false)?; + + let std_stream: TcpStream = self.socket.try_clone()?.into(); + Ok(std_stream) + } + + pub fn listen(&self, backlog: i32) -> io::Result<()> { + self.socket.listen(backlog) + } + + pub fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> { + let (stream, addr) = self.socket.accept()?; + Ok((stream.into(), addr.as_socket().unwrap())) + } + + pub fn to_tcp_stream(self) -> io::Result { + Ok(self.socket.into()) + } + + pub fn to_tcp_listener(self) -> io::Result { + Ok(self.socket.into()) + } + + /// Send a raw packet (for RAW TCP use). + pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result { + self.socket.send_to(buf, &target.into()) + } + + /// Receive a raw packet (for RAW TCP use). + pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + // Safety: `MaybeUninit` is layout-compatible with `u8`. + let buf_maybe = unsafe { + std::slice::from_raw_parts_mut( + buf.as_mut_ptr() as *mut std::mem::MaybeUninit, + buf.len(), + ) + }; + + let (n, addr) = self.socket.recv_from(buf_maybe)?; + let addr = addr.as_socket().ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "invalid address format") + })?; + + Ok((n, addr)) + } + + // --- option helpers --- + + pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> { + self.socket.set_reuse_address(on) + } + + pub fn set_nodelay(&self, on: bool) -> io::Result<()> { + self.socket.set_nodelay(on) + } + + pub fn set_linger(&self, dur: Option) -> io::Result<()> { + self.socket.set_linger(dur) + } + + pub fn set_ttl(&self, ttl: u32) -> io::Result<()> { + self.socket.set_ttl(ttl) + } + + pub fn set_bind_device(&self, iface: &str) -> io::Result<()> { + #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))] + return self.socket.bind_device(Some(iface.as_bytes())); + + #[cfg(not(any(target_os = "linux", target_os = "android", target_os = "fuchsia")))] + { + let _ = iface; + Err(io::Error::new(io::ErrorKind::Unsupported, "bind_device not supported on this OS")) + } + } + + // --- information helpers --- + + pub fn local_addr(&self) -> io::Result { + self.socket.local_addr()?.as_socket().ok_or_else(|| { + io::Error::new(io::ErrorKind::Other, "Failed to retrieve local address") + }) + } + + #[cfg(unix)] + pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd { + use std::os::fd::AsRawFd; + self.socket.as_raw_fd() + } + + #[cfg(windows)] + pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket { + use std::os::windows::io::AsRawSocket; + self.socket.as_raw_socket() + } +} diff --git a/nex-socket/src/udp/async_impl.rs b/nex-socket/src/udp/async_impl.rs new file mode 100644 index 0000000..f685845 --- /dev/null +++ b/nex-socket/src/udp/async_impl.rs @@ -0,0 +1,115 @@ +use crate::udp::UdpConfig; +use socket2::{Domain, Protocol, Socket, Type as SockType}; +use std::io; +use std::net::{SocketAddr, UdpSocket as StdUdpSocket}; +use tokio::net::UdpSocket; + +/// Asynchronous UDP socket built on top of Tokio. +#[derive(Debug)] +pub struct AsyncUdpSocket { + socket: Socket, +} + +impl AsyncUdpSocket { + /// Create an asynchronous UDP socket from the given configuration. + pub fn from_config(config: &UdpConfig) -> io::Result { + // Determine address family from the bind address + let domain = match config.bind_addr { + Some(SocketAddr::V4(_)) => Domain::IPV4, + Some(SocketAddr::V6(_)) => Domain::IPV6, + None => Domain::IPV4, // default + }; + + let socket = Socket::new(domain, SockType::DGRAM, Some(Protocol::UDP))?; + + if let Some(flag) = config.reuseaddr { + socket.set_reuse_address(flag)?; + } + + if let Some(flag) = config.broadcast { + socket.set_broadcast(flag)?; + } + + if let Some(ttl) = config.ttl { + socket.set_ttl(ttl)?; + } + + #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))] + if let Some(iface) = &config.bind_device { + socket.bind_device(Some(iface.as_bytes()))?; + } + + if let Some(addr) = config.bind_addr { + socket.bind(&addr.into())?; + } + + socket.set_nonblocking(true)?; + + Ok(Self { socket }) + } + + /// Create a socket of arbitrary type (DGRAM or RAW). + pub fn new(domain: Domain, sock_type: SockType) -> io::Result { + let socket = Socket::new(domain, sock_type, Some(Protocol::UDP))?; + socket.set_nonblocking(true)?; + Ok(Self { socket }) + } + + /// Convenience constructor for IPv4 DGRAM. + pub fn v4_dgram() -> io::Result { + Self::new(Domain::IPV4, SockType::DGRAM) + } + + /// Convenience constructor for IPv6 DGRAM. + pub fn v6_dgram() -> io::Result { + Self::new(Domain::IPV6, SockType::DGRAM) + } + + /// IPv4 RAW UDP. Requires administrator privileges. + pub fn raw_v4() -> io::Result { + Self::new(Domain::IPV4, SockType::RAW) + } + + /// IPv6 RAW UDP. Requires administrator privileges. + pub fn raw_v6() -> io::Result { + Self::new(Domain::IPV6, SockType::RAW) + } + + /// Send data asynchronously. + pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result { + let std_udp: StdUdpSocket = self.socket.try_clone()?.into(); + let udp_socket = UdpSocket::from_std(std_udp)?; + udp_socket.send_to(buf, target).await + } + + /// Receive data asynchronously. + pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + let std_udp: StdUdpSocket = self.socket.try_clone()?.into(); + let udp_socket = UdpSocket::from_std(std_udp)?; + udp_socket.recv_from(buf).await + } + + /// Retrieve the local socket address. + pub fn local_addr(&self) -> io::Result { + self.socket.local_addr()?.as_socket().ok_or_else(|| { + io::Error::new(io::ErrorKind::Other, "Failed to get socket address") + }) + } + + pub fn into_tokio_socket(self) -> io::Result { + let std_socket: StdUdpSocket = self.socket.into(); + UdpSocket::from_std(std_socket) + } + + #[cfg(unix)] + pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd { + use std::os::fd::AsRawFd; + self.socket.as_raw_fd() + } + + #[cfg(windows)] + pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket { + use std::os::windows::io::AsRawSocket; + self.socket.as_raw_socket() + } +} diff --git a/nex-socket/src/udp/config.rs b/nex-socket/src/udp/config.rs new file mode 100644 index 0000000..ce1a008 --- /dev/null +++ b/nex-socket/src/udp/config.rs @@ -0,0 +1,47 @@ +use std::net::SocketAddr; + +/// Configuration options for a UDP socket. +#[derive(Debug, Clone)] +pub struct UdpConfig { + /// Address to bind. If `None`, the operating system chooses the address. + pub bind_addr: Option, + + /// Enable address reuse (`SO_REUSEADDR`). + pub reuseaddr: Option, + + /// Allow broadcast (`SO_BROADCAST`). + pub broadcast: Option, + + /// Time to live value. + pub ttl: Option, + + /// Bind to a specific interface (Linux only). + pub bind_device: Option, +} + +impl Default for UdpConfig { + fn default() -> Self { + Self { + bind_addr: None, + reuseaddr: None, + broadcast: None, + ttl: None, + bind_device: None, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn udp_config_default_values() { + let cfg = UdpConfig::default(); + assert!(cfg.bind_addr.is_none()); + assert!(cfg.reuseaddr.is_none()); + assert!(cfg.broadcast.is_none()); + assert!(cfg.ttl.is_none()); + assert!(cfg.bind_device.is_none()); + } +} diff --git a/nex-socket/src/udp/mod.rs b/nex-socket/src/udp/mod.rs new file mode 100644 index 0000000..043548c --- /dev/null +++ b/nex-socket/src/udp/mod.rs @@ -0,0 +1,7 @@ +mod config; +mod async_impl; +mod sync_impl; + +pub use config::*; +pub use sync_impl::*; +pub use async_impl::*; diff --git a/nex-socket/src/udp/sync_impl.rs b/nex-socket/src/udp/sync_impl.rs new file mode 100644 index 0000000..7ddd8ff --- /dev/null +++ b/nex-socket/src/udp/sync_impl.rs @@ -0,0 +1,135 @@ +use crate::udp::UdpConfig; +use socket2::{Domain, Protocol, Socket, Type as SockType}; +use std::io; +use std::net::{SocketAddr, UdpSocket as StdUdpSocket}; + +/// Synchronous low level UDP socket. +#[derive(Debug)] +pub struct UdpSocket { + socket: Socket, +} + +impl UdpSocket { + /// Create a socket from the provided configuration. + pub fn from_config(config: &UdpConfig) -> io::Result { + // Determine address family from the bind address + let domain = match config.bind_addr { + Some(SocketAddr::V4(_)) => Domain::IPV4, + Some(SocketAddr::V6(_)) => Domain::IPV6, + None => Domain::IPV4, // default + }; + + let socket = Socket::new(domain, SockType::DGRAM, Some(Protocol::UDP))?; + + if let Some(flag) = config.reuseaddr { + socket.set_reuse_address(flag)?; + } + + if let Some(flag) = config.broadcast { + socket.set_broadcast(flag)?; + } + + if let Some(ttl) = config.ttl { + socket.set_ttl(ttl)?; + } + + #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))] + if let Some(iface) = &config.bind_device { + socket.bind_device(Some(iface.as_bytes()))?; + } + + if let Some(addr) = config.bind_addr { + socket.bind(&addr.into())?; + } + + socket.set_nonblocking(false)?; // blocking mode for sync usage + Ok(Self { socket }) + } + + /// Create a socket of arbitrary type (DGRAM or RAW). + pub fn new(domain: Domain, sock_type: SockType) -> io::Result { + let socket = Socket::new(domain, sock_type, Some(Protocol::UDP))?; + socket.set_nonblocking(false)?; + Ok(Self { socket }) + } + + /// Convenience constructor for IPv4 DGRAM. + pub fn v4_dgram() -> io::Result { + Self::new(Domain::IPV4, SockType::DGRAM) + } + + /// Convenience constructor for IPv6 DGRAM. + pub fn v6_dgram() -> io::Result { + Self::new(Domain::IPV6, SockType::DGRAM) + } + + /// IPv4 RAW UDP. Requires administrator privileges. + pub fn raw_v4() -> io::Result { + Self::new(Domain::IPV4, SockType::RAW) + } + + /// IPv6 RAW UDP. Requires administrator privileges. + pub fn raw_v6() -> io::Result { + Self::new(Domain::IPV6, SockType::RAW) + } + + /// Send data. + pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result { + self.socket.send_to(buf, &target.into()) + } + + /// Receive data. + pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + // Safety: `MaybeUninit` has the same layout as `u8`. + let buf_maybe = unsafe { + std::slice::from_raw_parts_mut( + buf.as_mut_ptr() as *mut std::mem::MaybeUninit, + buf.len(), + ) + }; + + let (n, addr) = self.socket.recv_from(buf_maybe)?; + let addr = addr.as_socket().ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "invalid address format") + })?; + + Ok((n, addr)) + } + + /// Retrieve the local socket address. + pub fn local_addr(&self) -> io::Result { + self.socket.local_addr()?.as_socket().ok_or_else(|| { + io::Error::new(io::ErrorKind::Other, "Failed to get socket address") + }) + } + + /// Convert into a raw `std::net::UdpSocket`. + pub fn to_std(self) -> io::Result { + Ok(self.socket.into()) + } + + #[cfg(unix)] + pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd { + use std::os::fd::AsRawFd; + self.socket.as_raw_fd() + } + + #[cfg(windows)] + pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket { + use std::os::windows::io::AsRawSocket; + self.socket.as_raw_socket() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn create_v4_socket() { + let sock = UdpSocket::v4_dgram().expect("create socket"); + let addr = sock.local_addr().expect("addr"); + assert!(addr.is_ipv4()); + } +} + diff --git a/nex-sys/src/lib.rs b/nex-sys/src/lib.rs index f788202..4784d0d 100644 --- a/nex-sys/src/lib.rs +++ b/nex-sys/src/lib.rs @@ -67,4 +67,4 @@ pub fn recv_from( } else { Ok(len as usize) } -} +} \ No newline at end of file diff --git a/nex-sys/src/unix.rs b/nex-sys/src/unix.rs index 9dac910..ba2240d 100644 --- a/nex-sys/src/unix.rs +++ b/nex-sys/src/unix.rs @@ -146,3 +146,29 @@ where fn errno() -> i32 { io::Error::last_os_error().raw_os_error().unwrap() } + +#[cfg(test)] +mod tests { + use super::*; + use std::time::Duration; + + #[test] + fn test_timeval_round_trip() { + let dur = Duration::new(1, 500_000_000); + let tv = duration_to_timeval(dur); + assert_eq!(timeval_to_duration(tv), dur); + } + + #[test] + fn test_timespec_round_trip() { + let dur = Duration::new(2, 123_456_789); + let ts = duration_to_timespec(dur); + assert_eq!(timespec_to_duration(ts), dur); + } + + #[test] + fn test_ipv4_addr_int() { + let addr = InAddr { s_addr: u32::from_be(0x7f000001) }; + assert_eq!(ipv4_addr_int(addr), 0x7f000001); + } +} diff --git a/nex/Cargo.toml b/nex/Cargo.toml index 94141d9..17b74fe 100644 --- a/nex/Cargo.toml +++ b/nex/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "nex" version.workspace = true -edition = "2021" +edition.workspace = true authors.workspace = true description = "Cross-platform networking library in Rust" repository = "https://github.com/shellrow/nex" @@ -15,12 +15,14 @@ nex-core = { workspace = true } nex-packet = { workspace = true } nex-datalink = { workspace = true } nex-socket = { workspace = true } -nex-packet-builder = { workspace = true } [dev-dependencies] +bytes = { workspace = true } serde_json = "1.0" +rand = { workspace = true } async-io = "2.4" futures = "0.3" +tokio = { version = "1", features = ["rt", "rt-multi-thread", "signal", "macros"] } [features] pcap = ["nex-datalink/pcap"] @@ -30,6 +32,10 @@ serde = ["nex-core/serde", "nex-packet/serde", "nex-datalink/serde"] name = "dump" path = "../examples/dump.rs" +[[example]] +name = "parse_frame" +path = "../examples/parse_frame.rs" + [[example]] name = "arp" path = "../examples/arp.rs" @@ -38,10 +44,6 @@ path = "../examples/arp.rs" name = "ndp" path = "../examples/ndp.rs" -[[example]] -name = "parse_frame" -path = "../examples/parse_frame.rs" - [[example]] name = "icmp_ping" path = "../examples/icmp_ping.rs" @@ -55,22 +57,21 @@ name = "udp_ping" path = "../examples/udp_ping.rs" [[example]] -name = "list_interfaces" -path = "../examples/list_interfaces.rs" +name = "icmp_socket" +path = "../examples/icmp_socket.rs" [[example]] -name = "serialize" -path = "../examples/serialize.rs" -required-features = ["serde"] +name = "tcp_socket" +path = "../examples/tcp_socket.rs" [[example]] -name = "tcp_stream" -path = "../examples/tcp_stream.rs" +name = "udp_socket" +path = "../examples/udp_socket.rs" [[example]] -name = "async_tcp_connect" -path = "../examples/async_tcp_connect.rs" +name = "async_icmp_socket" +path = "../examples/async_icmp_socket.rs" [[example]] -name = "async_tcp_stream" -path = "../examples/async_tcp_stream.rs" +name = "async_tcp_socket" +path = "../examples/async_tcp_socket.rs" diff --git a/nex/src/lib.rs b/nex/src/lib.rs index 45dc783..f6f5ce5 100644 --- a/nex/src/lib.rs +++ b/nex/src/lib.rs @@ -1,3 +1,8 @@ +//! Entry point for the nex-next collection of crates. +//! +//! This crate re-exports the core modules so applications can simply depend on +//! `nex` and gain access to packet parsing, datalink channels and socket helpers. +//! It is intended to be a convenient facade for the underlying crates. /// Provides core network types and functionality. pub mod net { pub use nex_core::*; @@ -17,10 +22,3 @@ pub mod packet { pub mod socket { pub use nex_socket::*; } - -/// Utilities designed to work with packets through high-level APIs. -pub mod util { - pub mod packet_builder { - pub use nex_packet_builder::*; - } -} diff --git a/scripts/build-all.ps1 b/scripts/build-all.ps1 new file mode 100644 index 0000000..7c5581b --- /dev/null +++ b/scripts/build-all.ps1 @@ -0,0 +1,20 @@ +# target platforms +$targets = @( + "x86_64-unknown-linux-gnu", + "aarch64-unknown-linux-gnu", + "x86_64-unknown-freebsd", + "aarch64-linux-android", + "x86_64-linux-android" +) + +# cross build +foreach ($target in $targets) { + Write-Host "==> Building for $target..." + $result = & cross build --target $target + Write-Host "✅ Success: $target" + if ($LASTEXITCODE -ne 0) { + Write-Error "❌ Build failed for $target" + exit 1 + } +} +Write-Host "✅ All builds succeeded." diff --git a/scripts/build-all.sh b/scripts/build-all.sh new file mode 100755 index 0000000..8ef7b24 --- /dev/null +++ b/scripts/build-all.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash +set -e + +# target platforms +TARGETS=( + x86_64-unknown-linux-gnu + aarch64-unknown-linux-gnu + x86_64-unknown-freebsd + aarch64-linux-android + x86_64-linux-android +) + +# cross build +for target in "${TARGETS[@]}"; do + echo "==> Building for $target..." + if cross build --target "$target"; then + echo "✅ Success: $target" + else + echo "❌ Failed: $target" + exit 1 + fi +done + +echo "" +echo "✅ All builds succeeded!"