diff --git a/bindings/csharp/HelloWorld/Program.cs b/bindings/csharp/HelloWorld/Program.cs index 997ac93f..4cd1b0f2 100644 --- a/bindings/csharp/HelloWorld/Program.cs +++ b/bindings/csharp/HelloWorld/Program.cs @@ -71,7 +71,7 @@ static void Main(string[] args) { // Client MachnetFlow_t flow = new MachnetFlow_t(); - ret = MachnetShim.machnet_connect(channel_ctx, options.LocalIp, options.RemoteIp, kHelloWorldPort, ref flow); + ret = MachnetShim.machnet_connect(channel_ctx, options.LocalIp, options.RemoteIp, kHelloWorldPort, ref flow, 0); CustomCheck(ret == 0, "machnet_connect()"); string msg = "Hello World!"; @@ -84,7 +84,7 @@ static void Main(string[] args) else { Console.WriteLine("Waiting for message from client"); - ret = MachnetShim.machnet_listen(channel_ctx, options.LocalIp, kHelloWorldPort); + ret = MachnetShim.machnet_listen(channel_ctx, options.LocalIp, kHelloWorldPort, 0); CustomCheck(ret == 0, "machnet_listen()"); while (true) diff --git a/bindings/csharp/HelloWorld/machnet_shim.cs b/bindings/csharp/HelloWorld/machnet_shim.cs index 634ac648..3253d5ff 100644 --- a/bindings/csharp/HelloWorld/machnet_shim.cs +++ b/bindings/csharp/HelloWorld/machnet_shim.cs @@ -21,10 +21,10 @@ public static class MachnetShim public static extern IntPtr machnet_attach(); [DllImport(libmachnet_shim_location, CallingConvention = CallingConvention.Cdecl)] - public static extern int machnet_listen(IntPtr channel_ctx, string local_ip, UInt16 port); + public static extern int machnet_listen(IntPtr channel_ctx, string local_ip, UInt16 port, int protocol); [DllImport(libmachnet_shim_location, CallingConvention = CallingConvention.Cdecl)] - public static extern int machnet_connect(IntPtr channel_ctx, string local_ip, string remote_ip, UInt16 port, ref MachnetFlow_t flow); + public static extern int machnet_connect(IntPtr channel_ctx, string local_ip, string remote_ip, UInt16 port, ref MachnetFlow_t flow, int protocol); [DllImport(libmachnet_shim_location, CallingConvention = 
CallingConvention.Cdecl)] public static extern int machnet_send(IntPtr channel_ctx, MachnetFlow_t flow, byte[] data, IntPtr dataSize); diff --git a/bindings/go/machnet/conversion.h b/bindings/go/machnet/conversion.h index 9463e3f6..d2d9f2cf 100644 --- a/bindings/go/machnet/conversion.h +++ b/bindings/go/machnet/conversion.h @@ -45,13 +45,13 @@ MachnetFlow_t __machnet_recvmsg_go(const MachnetChannelCtx_t* ctx, int __machnet_connect_go(MachnetChannelCtx_t* ctx, uint32_t local_ip, uint32_t remote_ip, uint16_t remote_port, - MachnetFlow_t* flow) { - return machnet_connect(ctx, local_ip, remote_ip, remote_port, flow); + MachnetFlow_t* flow, int protocol) { + return machnet_connect(ctx, local_ip, remote_ip, remote_port, flow, protocol); } int __machnet_listen_go(MachnetChannelCtx_t* ctx, uint32_t local_ip, - uint16_t port) { - return machnet_listen(ctx, local_ip, port); + uint16_t port, int protocol) { + return machnet_listen(ctx, local_ip, port, protocol); } MachnetFlow_t* __machnet_init_flow() { diff --git a/bindings/go/machnet/machnet.go b/bindings/go/machnet/machnet.go index 3607352a..d8c68d6a 100644 --- a/bindings/go/machnet/machnet.go +++ b/bindings/go/machnet/machnet.go @@ -69,21 +69,21 @@ func Attach() *MachnetChannelCtx { } // Connect to the remote host and port. 
-func Connect(ctx *MachnetChannelCtx, local_ip string, remote_ip string, remote_port uint) (int, MachnetFlow) { +func Connect(ctx *MachnetChannelCtx, local_ip string, remote_ip string, remote_port uint, protocol int) (int, MachnetFlow) { // Initialize the flow var flow_ptr *C.MachnetFlow_t = C.__machnet_init_flow() local_ip_int := ipv4_str_to_uint32(local_ip) remote_ip_int := ipv4_str_to_uint32(remote_ip) - ret := C.__machnet_connect_go((*C.MachnetChannelCtx_t)(ctx), (C.uint)(local_ip_int), (C.uint)(remote_ip_int), C.ushort(remote_port), flow_ptr) + ret := C.__machnet_connect_go((*C.MachnetChannelCtx_t)(ctx), (C.uint)(local_ip_int), (C.uint)(remote_ip_int), C.ushort(remote_port), flow_ptr, C.int(protocol)) return (int)(ret), convert_net_flow_go(flow_ptr) } // Listen on the local host and port. -func Listen(ctx *MachnetChannelCtx, local_ip string, local_port uint) int { +func Listen(ctx *MachnetChannelCtx, local_ip string, local_port uint, protocol int) int { local_ip_int := ipv4_str_to_uint32(local_ip) - ret := C.__machnet_listen_go((*C.MachnetChannelCtx_t)(ctx), (C.uint)(local_ip_int), C.ushort(local_port)) + ret := C.__machnet_listen_go((*C.MachnetChannelCtx_t)(ctx), (C.uint)(local_ip_int), C.ushort(local_port), C.int(protocol)) return (int)(ret) } diff --git a/bindings/go/msg_gen/main.go b/bindings/go/msg_gen/main.go index 6dfb9009..2ebdf583 100644 --- a/bindings/go/msg_gen/main.go +++ b/bindings/go/msg_gen/main.go @@ -324,14 +324,14 @@ func main() { remote_ip, _ := jsonparser.GetString(json_bytes, "hosts_config", remote_hostname, "ipv4_addr") // Initiate connection to the remote host. - ret, flow = machnet.Connect(channel_ctx, local_ip, remote_ip, remote_port) + ret, flow = machnet.Connect(channel_ctx, local_ip, remote_ip, remote_port, 0) if ret != 0 { glog.Fatal("Failed to connect to remote host.") } glog.Info("[CONNECTED] [", local_ip, " <-> ", remote_ip, ":", remote_port, "]") } else { // Listen for incoming connections. 
- ret = machnet.Listen(channel_ctx, local_ip, remote_port) + ret = machnet.Listen(channel_ctx, local_ip, remote_port, 0) if ret != 0 { glog.Fatal("Failed to listen for incoming connections.") } diff --git a/bindings/js/hello_world.js b/bindings/js/hello_world.js index a7887c93..ac433d22 100644 --- a/bindings/js/hello_world.js +++ b/bindings/js/hello_world.js @@ -49,7 +49,7 @@ if (options.remote_ip) { const flow = new MachnetFlow_t(); ret = machnet_shim.machnet_connect( channel_ctx, ref.allocCString(options.local_ip), - ref.allocCString(options.remote_ip), kHelloWorldPort, flow.ref()); + ref.allocCString(options.remote_ip), kHelloWorldPort, flow.ref(), 0); customCheck(ret === 0, 'machnet_connect()'); const msg = 'Hello World!'; @@ -64,7 +64,7 @@ if (options.remote_ip) { // Server console.log('Waiting for message from client'); ret = machnet_shim.machnet_listen( - channel_ctx, ref.allocCString(options.local_ip), kHelloWorldPort); + channel_ctx, ref.allocCString(options.local_ip), kHelloWorldPort, 0); customCheck(ret === 0, 'machnet_listen()'); function receive_message() { diff --git a/bindings/js/latency.js b/bindings/js/latency.js index 52a68255..d8858ba7 100644 --- a/bindings/js/latency.js +++ b/bindings/js/latency.js @@ -51,11 +51,11 @@ const latencies_us = []; const tx_flow = new MachnetFlow_t(); var ret = machnet_shim.machnet_connect( channel_ctx, ref.allocCString(options.local_ip), - ref.allocCString(options.remote_ip), kHelloWorldPort, tx_flow.ref()); + ref.allocCString(options.remote_ip), kHelloWorldPort, tx_flow.ref(), 0); customCheck(ret === 0, 'machnet_connect()'); ret = machnet_shim.machnet_listen( - channel_ctx, ref.allocCString(options.local_ip), kHelloWorldPort); + channel_ctx, ref.allocCString(options.local_ip), kHelloWorldPort, 0); customCheck(ret === 0, 'machnet_listen()'); if (options.is_client) { diff --git a/bindings/js/machnet_shim.js b/bindings/js/machnet_shim.js index 3f41d078..fe4a5661 100644 --- a/bindings/js/machnet_shim.js +++ 
b/bindings/js/machnet_shim.js @@ -23,8 +23,8 @@ console.log('Loading libmachnet_shim'); var machnet_shim = ffi.Library(libmachnet_shim_location, { 'machnet_init': ['int', []], 'machnet_attach': ['pointer', []], - 'machnet_listen': ['int', [voidPtr, charPtr, uint16]], - 'machnet_connect': ['int', [voidPtr, charPtr, charPtr, uint16, MachnetFlowPtr]], + 'machnet_listen': ['int', [voidPtr, charPtr, uint16, 'int']], + 'machnet_connect': ['int', [voidPtr, charPtr, charPtr, uint16, MachnetFlowPtr, 'int']], 'machnet_send': ['int', [voidPtr, MachnetFlow_t, voidPtr, size_t]], 'machnet_recv': ['int', [voidPtr, voidPtr, size_t, MachnetFlowPtr]] }); diff --git a/bindings/js/rocksdb_client.js b/bindings/js/rocksdb_client.js index d8e77ca3..5e37c1a3 100644 --- a/bindings/js/rocksdb_client.js +++ b/bindings/js/rocksdb_client.js @@ -72,14 +72,16 @@ async function machnetTransportClientAsync() { ref.allocCString(options.local_ip), ref.allocCString(options.remote_ip), kRocksDbServerPort, - tx_flow.ref() + tx_flow.ref(), + 0 ); customCheck(ret === 0, "machnet_connect()"); ret = machnet_shim.machnet_listen( channel_ctx, ref.allocCString(options.local_ip), - kRocksDbServerPort + kRocksDbServerPort, + 0 ); customCheck(ret === 0, "machnet_listen()"); diff --git a/bindings/rust/resources/machnet.h b/bindings/rust/resources/machnet.h index 831e2798..3f028152 100644 --- a/bindings/rust/resources/machnet.h +++ b/bindings/rust/resources/machnet.h @@ -87,27 +87,30 @@ void *machnet_attach(); /** * @brief Listens for incoming messages on a specific IP and port. - * @param[in] channel The channel associated to the listener. - * @param[in] ip The local IP address to listen on. + * @param[in] channel_ctx The channel associated to the listener. + * @param[in] local_ip The local IP address to listen on. * @param[in] port The local port to listen on. + * @param[in] protocol MACHNET_PROTO_UDP (0) or MACHNET_PROTO_TCP (1). * @return 0 on success, -1 on failure. 
*/ -int machnet_listen(void *channel_ctx, const char *local_ip, uint16_t port); +int machnet_listen(void *channel_ctx, const char *local_ip, uint16_t port, + int protocol); /** * @brief Creates a new connection to a remote peer. - * @param[in] channel The channel associated with the connection. + * @param[in] channel_ctx The channel associated with the connection. * @param[in] local_ip The local IP address. * @param[in] remote_ip The remote IP address. * @param[in] remote_port The remote port. * @param[out] flow A pointer to a `MachnetFlow_t` structure that will be * filled by the function upon success. + * @param[in] protocol MACHNET_PROTO_UDP (0) or MACHNET_PROTO_TCP (1). * @return 0 on success, -1 on failure. `flow` is filled with the flow * information on success. */ int machnet_connect(void *channel_ctx, const char *local_ip, const char *remote_ip, uint16_t remote_port, - MachnetFlow_t *flow); + MachnetFlow_t *flow, int protocol); /** * Enqueue one message for transmission to a remote peer over the network. 
diff --git a/bindings/rust/src/bindings.rs b/bindings/rust/src/bindings.rs index acba1561..39a5982a 100644 --- a/bindings/rust/src/bindings.rs +++ b/bindings/rust/src/bindings.rs @@ -529,6 +529,7 @@ extern "C" { channel_ctx: *mut ::std::os::raw::c_void, local_ip: *const ::std::os::raw::c_char, port: u16, + protocol: ::std::os::raw::c_int, ) -> ::std::os::raw::c_int; } extern "C" { @@ -539,6 +540,7 @@ extern "C" { remote_ip: *const ::std::os::raw::c_char, remote_port: u16, flow: *mut MachnetFlow_t, + protocol: ::std::os::raw::c_int, ) -> ::std::os::raw::c_int; } extern "C" { diff --git a/bindings/rust/src/lib.rs b/bindings/rust/src/lib.rs index 44fb2e4f..c487fc15 100644 --- a/bindings/rust/src/lib.rs +++ b/bindings/rust/src/lib.rs @@ -151,7 +151,7 @@ pub fn machnet_attach<'a>() -> Option> { /// let remote_ip = "192.168.1.3"; /// let remote_port = 8080; /// -/// match machnet_connect(&mut channel, local_ip, remote_ip, remote_port) { +/// match machnet_connect(&mut channel, local_ip, remote_ip, remote_port, 0) { /// Some(flow) => { /// // Connection was successful, use `flow` here /// } @@ -166,6 +166,7 @@ pub fn machnet_connect( local_ip: &str, remote_ip: &str, remote_port: u16, + protocol: i32, ) -> Option { unsafe { let channel_ptr = channel.get_ptr(); @@ -180,6 +181,7 @@ pub fn machnet_connect( remote_ip_cstr.as_ptr(), remote_port, flow_ptr, // &mut flow, + protocol, ); match res { @@ -226,11 +228,11 @@ pub fn machnet_connect( /// } /// ``` /// -pub fn machnet_listen(channel: &mut MachnetChannel, local_ip: &str, local_port: u16) -> i32 { +pub fn machnet_listen(channel: &mut MachnetChannel, local_ip: &str, local_port: u16, protocol: i32) -> i32 { unsafe { let channel_ptr = channel.get_ptr(); let local_ip_cstr = CString::new(local_ip).unwrap(); - bindings::machnet_listen(channel_ptr, local_ip_cstr.as_ptr(), local_port) + bindings::machnet_listen(channel_ptr, local_ip_cstr.as_ptr(), local_port, protocol) } } diff --git a/docs/TCP_EXPERIMENT_README.md 
b/docs/TCP_EXPERIMENT_README.md new file mode 100644 index 00000000..88451337 --- /dev/null +++ b/docs/TCP_EXPERIMENT_README.md @@ -0,0 +1,222 @@ +# Running TCP Experiments with Machnet + +This guide shows how to run TCP-based applications with Machnet. The key idea: +**a standard Linux TCP client (using POSIX sockets) can talk to a Machnet +server running a DPDK kernel-bypass TCP stack** — getting sub-100μs tail +latency on the server side without any changes to the client. + +## Wire Protocol + +Machnet TCP uses a simple framing format. Every message on the wire is: + +``` +[4-byte big-endian length prefix][payload bytes] +``` + +Any client that sends/receives with this framing can interoperate. + +**Prerequisites:** Two machines with Machnet already running on at least one of +them. Follow the main [README.md](../README.md) for NIC setup and starting the +Machnet process. The client VM needs nothing special — any Linux machine works. + +--- + +## Experiment 1: Machnet Server + Standard TCP Client + +The server uses DPDK kernel-bypass, the client uses regular Linux sockets. + +### Run the Machnet TCP server (msg_gen) + +On the server VM (Machnet process already running per the main README): + +```bash +MSG_GEN="docker run -v /var/run/machnet:/var/run/machnet \ + ghcr.io/microsoft/machnet/machnet:latest \ + release_build/src/apps/msg_gen/msg_gen" + +# Start TCP server listening on port 888 +${MSG_GEN} --local_ip ${MACHNET_IP} --local_port 888 --protocol tcp --msg_size 64 +``` + +You should see: +``` +Using TCP transport +Starting in server mode, response size 64 +[LISTENING] [10.0.0.2:888] +``` + +### Run the standard TCP client (tcp_msg_gen) + +On the **client VM** (no Machnet needed): + +```bash +# Build tcp_msg_gen — pure POSIX sockets, no DPDK dependency +cd machnet/build +cmake .. 
&& ninja -j$(nproc) tcp_msg_gen + +# Connect to the Machnet server +./src/apps/tcp_msg_gen/tcp_msg_gen \ + --local_ip=0.0.0.0 \ + --remote_ip=${SERVER_IP} \ + --remote_port=888 \ + --msg_size=64 \ + --msg_window=8 +``` + +You should see throughput stats printed every second: +``` +[CLIENT] Connecting to 10.0.0.2:888, msg_size=64, window=8 +[CLIENT] Connected. +TX/RX (msg/sec, Gbps): (142.3K/142.3K, 0.073/0.073) +TX/RX (msg/sec, Gbps): (145.1K/145.1K, 0.074/0.074) +``` + +### (Optional) Use a Python client instead + +You can also connect from any language. Here's a minimal Python client: + +```python +#!/usr/bin/env python3 +"""Minimal TCP client that speaks Machnet's wire protocol.""" +import socket, struct, time + +SERVER_IP = "10.0.0.2"  # replace with the Machnet server's IP +SERVER_PORT = 888 +MSG_SIZE = 64 + +sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) +sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) +sock.connect((SERVER_IP, SERVER_PORT)) + +payload = bytes(MSG_SIZE) +count = 0 +t0 = time.time() + +while True: + # Send: 4-byte BE length prefix + payload + sock.sendall(struct.pack('>I', len(payload)) + payload) + + # Recv: read 4-byte prefix, then payload + hdr = sock.recv(4, socket.MSG_WAITALL) + if len(hdr) < 4: + break + msg_len = struct.unpack('>I', hdr)[0] + data = sock.recv(msg_len, socket.MSG_WAITALL) + + count += 1 + elapsed = time.time() - t0 + if elapsed >= 1.0: + print(f"{count/elapsed:.0f} msg/sec") + count = 0 + t0 = time.time() +``` + +--- + +## Experiment 2: Machnet-to-Machnet TCP + +Both sides run Machnet (kernel-bypass on both endpoints). 
+ +### Server VM + +```bash +${MSG_GEN} --local_ip ${MACHNET_IP} --local_port 888 --protocol tcp --msg_size 64 +``` + +### Client VM (Machnet also running) + +```bash +${MSG_GEN} --local_ip ${CLIENT_IP} \ + --remote_ip ${SERVER_IP} \ + --remote_port 888 \ + --protocol tcp \ + --msg_size 64 \ + --msg_window 8 +``` + +--- + +## Experiment 3: Standard TCP Server + Machnet Client + +A plain Linux TCP server paired with a Machnet kernel-bypass client. + +### Server VM (no Machnet needed) + +```bash +# Build tcp_msg_gen (same build as Experiment 1 client) +./src/apps/tcp_msg_gen/tcp_msg_gen \ + --local_ip=0.0.0.0 \ + --local_port=888 \ + --msg_size=64 +``` + +### Client VM (running Machnet) + +```bash +${MSG_GEN} --local_ip ${MACHNET_IP} \ + --remote_ip ${SERVER_IP} \ + --remote_port 888 \ + --protocol tcp \ + --msg_size 64 \ + --msg_window 8 +``` + +--- + +## Experiment Summary + +| Experiment | Client | Server | What it shows | +|---|---|---|---| +| 1 | `tcp_msg_gen` (POSIX sockets) | `msg_gen --protocol tcp` (Machnet) | Standard TCP client → kernel-bypass server | +| 2 | `msg_gen --protocol tcp` (Machnet) | `msg_gen --protocol tcp` (Machnet) | Full kernel-bypass on both sides | +| 3 | `msg_gen --protocol tcp` (Machnet) | `tcp_msg_gen` (POSIX sockets) | Kernel-bypass client → standard TCP server | + +--- + +## Configuration Reference + +### msg_gen flags (Machnet TCP) + +| Flag | Default | Description | +|---|---|---| +| `--local_ip` | (required) | IP of the local Machnet interface | +| `--remote_ip` | (empty = server) | IP of the remote peer (set to run as client) | +| `--local_port` | 888 | Local port to listen on | +| `--remote_port` | 888 | Remote port to connect to | +| `--protocol` | `udp` | Transport protocol: `udp` or `tcp` | +| `--msg_size` | 64 | Message payload size in bytes (minimum 9) | +| `--msg_window` | 8 | Number of messages in flight (client only) | +| `--msg_nr` | unlimited | Total number of messages to send | +| `--verify` | false | Verify payload correctness | 
+ +### tcp_msg_gen flags (standard POSIX TCP) + +| Flag | Default | Description | +|---|---|---| +| `--local_ip` | `0.0.0.0` | Local bind address | +| `--local_port` | 888 | Local port (server mode) | +| `--remote_ip` | (empty = server) | Remote address (set to run as client) | +| `--remote_port` | 888 | Remote port (client mode) | +| `--msg_size` | 64 | Message payload size in bytes (minimum 8) | +| `--msg_window` | 8 | Number of messages in flight (client only) | + +--- + +## Troubleshooting + +**`tcp_msg_gen` fails to connect:** +- Verify the Machnet server is running and listening (`[LISTENING]` in the log). +- Check that firewall rules allow traffic on port 888 between the two VMs. +- Ensure the Machnet NIC is bound to DPDK (it won't respond to `ping`). + +**Low throughput:** +- Increase `--msg_window` for more pipelining (try 16, 32, 64). +- Increase `--msg_size` if your workload has larger messages. +- Ensure `TCP_NODELAY` is enabled on the client (both `tcp_msg_gen` and + the Python example do this). + +**Connection reset / drops:** +- Check that both sides use the same `--msg_size` — the app header + (`window_slot`) must be present in every message. +- Ensure hugepages are configured on the Machnet VM + (`echo 2048 | sudo tee /sys/kernel/mm/hugepages/hugepages-2048kB/nr_hugepages`). diff --git a/docs/tests/TCP_FLOW_TESTS_README.md b/docs/tests/TCP_FLOW_TESTS_README.md new file mode 100644 index 00000000..69ead452 --- /dev/null +++ b/docs/tests/TCP_FLOW_TESTS_README.md @@ -0,0 +1,60 @@ +# TCP Flow Unit Tests + +## What + +`tcp_flow_test.cc` is a comprehensive unit test suite for the `TcpFlow` class (`src/include/tcp_flow.h`), which implements Machnet's TCP transport path. The suite contains about **50 tests** (broken down by category in the table below) covering the full TCP state machine, data path, and edge cases that arise when interoperating with the Linux kernel's TCP stack. + +## Why + +We added TCP support to Machnet alongside the existing UDP-based Machnet protocol. 
The `TcpFlow` class is a complex piece of code — it implements a TCP state machine, message framing/deframing over a byte stream, MSS negotiation, retransmission overlap handling, and interfaces with the shared-memory channel system. Any regression in this code could silently corrupt data or hang connections. + +These tests exist to: + +1. **Catch regressions** — the TCP code touches many subsystems (packet construction, channel MsgBuf allocation, sequence number tracking). A change in any of them could break TCP without affecting UDP. +2. **Validate bug fixes** — several bugs were found and fixed during development (Ethernet padding causing wrong payload length, IP `total_length` vs `packet->length()` mismatch, invalid TCP header lengths). Each fix has a corresponding test to prevent re-introduction. +3. **Document behavior** — the tests serve as executable documentation of how the TCP state machine behaves in specific scenarios (retransmitted SYNs in ESTABLISHED, piggybacked data on handshake ACK, partial retransmission overlaps). + +## Test Categories + +| Category | # Tests | Description | +|---|---|---| +| Sequence number helpers | 1 | Verifies wrap-around-safe comparisons (`SeqLt`, `SeqGt`, etc.) 
including the `0xFFFFFFFF → 0` boundary | +| Construction & initial state | 1 | Checks default values: state=CLOSED, ISN initialization, default MSS | +| Active open (client handshake) | 3 | `InitiateHandshake` → SYN_SENT, full 3-way handshake with callback, wrong-ACK rejection | +| Passive open (server handshake) | 4 | `StartPassiveOpen` → LISTEN, full handshake, MSS option parsing from SYN, piggybacked data on completing ACK | +| RST handling | 2 | RST received in ESTABLISHED and SYN_SENT states | +| TCP option parsing | 4 | MSS-only option, NOP-padded options, no options (default MSS), zero-MSS fallback | +| Header validation | 1 | Rejects TCP headers with length < 20 bytes | +| TX path (OutputMessage) | 3 | Small message framing, not-established guard, multi-segment when payload > MSS | +| RX deframing (ConsumePayload) | 3 | Single complete message, length prefix split across calls, two back-to-back messages | +| Established data handling | 2 | Data packet delivery to channel, retransmitted SYN re-ACK in ESTABLISHED | +| Retransmission overlap | 4 | Exact in-order delivery, partial overlap (kernel retransmit), pure duplicate, empty payload | +| FIN handling | 3 | Clean FIN → CLOSE_WAIT, FIN with trailing data, retransmitted FIN | +| Shutdown (active close) | 3 | From ESTABLISHED → FIN_WAIT_1, CLOSE_WAIT → LAST_ACK, SYN_SENT → RST + CLOSED | +| Graceful close (full teardown) | 2 | FIN_WAIT_1 → FIN_WAIT_2 → TIME_WAIT, passive close via LAST_ACK → CLOSED | +| PeriodicCheck (timers) | 4 | CLOSED returns false, TIME_WAIT countdown, SYN retransmit on RTO, max retransmissions exceeded | +| AdvanceSndUna | 4 | Normal ACK advancement, all-data-acked (RTO disarm), stale ACK ignored, future ACK ignored | +| StateToString | 1 | All 10 TCP states map to correct strings | +| Match | 2 | Correct 4-tuple match, wrong source address rejection | +| Ethernet padding resilience | 1 | Verifies IP `total_length` is used (not `packet->length()`) for payload size | +| IP length 
validation | 1 | Rejects packets where IP `total_length` < IP header + TCP header | + +## How to Build and Run + +```bash +cd build +cmake .. # Only needed once or after adding new test files +ninja -j$(nproc) tcp_flow_test +sudo ./src/tests/tcp_flow_test +``` + +`sudo` is required because DPDK needs access to hugepages. The test uses `--vdev=net_null0,copy=1 --no-pci` so no physical NIC is needed. + +## Test Infrastructure + +- **Framework**: Google Test (gtest) +- **DPDK**: Initialized with a `net_null` virtual device — no hardware required +- **White-box testing**: Uses `#define private public` (applied only to `tcp_flow.h`) to access internal state for assertions +- **Fixture**: `TcpFlowTest` with a static `PmdPort`/`TxRing` (shared across tests to avoid DPDK port teardown issues) and per-test `Channel` + `PacketPool` +- **Packet crafting**: `MakePacket()` helper builds Ethernet + IPv4 + TCP packets with correct headers, optional TCP options, and optional payload +- **Auto-discovery**: CMake globs `*_test.cc` under `src/`, so the test is automatically picked up without modifying `CMakeLists.txt` diff --git a/examples/hello_world.cc b/examples/hello_world.cc index ed347352..dcbcaa85 100644 --- a/examples/hello_world.cc +++ b/examples/hello_world.cc @@ -34,7 +34,7 @@ int main(int argc, char *argv[]) { void *channel = machnet_attach(); assert_with_msg(channel != nullptr, "machnet_attach() failed"); - ret = machnet_listen(channel, FLAGS_local.c_str(), kPort); + ret = machnet_listen(channel, FLAGS_local.c_str(), kPort, MACHNET_PROTO_UDP); assert_with_msg(ret == 0, "machnet_listen() failed"); printf("Listening on %s:%d\n", FLAGS_local.c_str(), kPort); @@ -44,7 +44,7 @@ int main(int argc, char *argv[]) { MachnetFlow flow; std::string msg = "Hello World!"; ret = machnet_connect(channel, FLAGS_local.c_str(), FLAGS_remote.c_str(), - kPort, &flow); + kPort, &flow, MACHNET_PROTO_UDP); assert_with_msg(ret == 0, "machnet_connect() failed"); const int ret = 
machnet_send(channel, flow, msg.data(), msg.size()); diff --git a/examples/rust/src/main.rs b/examples/rust/src/main.rs index e67dc6e8..e1b2769e 100644 --- a/examples/rust/src/main.rs +++ b/examples/rust/src/main.rs @@ -428,7 +428,7 @@ fn main() { let datapath_thread = match msg_gen.remote_ip.is_empty() { true => { // Server-mode - let ret = machnet_listen(&mut channel_ctx, &msg_gen.local_ip, msg_gen.local_port); + let ret = machnet_listen(&mut channel_ctx, &msg_gen.local_ip, msg_gen.local_port, 0); assert_eq!(ret, 0, "Failed to listen on local port."); info!("[LISTENING] [{}:{}]", msg_gen.local_ip, msg_gen.local_port); @@ -441,6 +441,7 @@ fn main() { &msg_gen.local_ip, &msg_gen.remote_ip, msg_gen.remote_port, + 0, ) .expect("Failed to connect to remote host."); diff --git a/examples/sync.sh b/examples/sync.sh new file mode 100755 index 00000000..f6913c0c --- /dev/null +++ b/examples/sync.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash +# Sync the machnet workspace to the remote machine, preserving the same path. 
+ +REMOTE_USER="sarsanaee" +REMOTE_HOST="asas-westus2-vm-0" +LOCAL_PATH="/home/sarsanaee/machnet/" +REMOTE_PATH="${REMOTE_USER}@${REMOTE_HOST}:/home/sarsanaee/machnet/" + +rsync -avz --progress \ + --exclude 'build/' \ + --exclude '.git/' \ + --exclude 'third_party/' \ + "$LOCAL_PATH" "$REMOTE_PATH" diff --git a/src/apps/CMakeLists.txt b/src/apps/CMakeLists.txt index a4b3c39e..b990cc99 100644 --- a/src/apps/CMakeLists.txt +++ b/src/apps/CMakeLists.txt @@ -3,3 +3,4 @@ include_directories(../modules) add_subdirectory(machnet) add_subdirectory(msg_gen) add_subdirectory(rocksdb_server) +add_subdirectory(tcp_msg_gen) diff --git a/src/apps/msg_gen/main.cc b/src/apps/msg_gen/main.cc index 1e98ccb1..07be5e22 100644 --- a/src/apps/msg_gen/main.cc +++ b/src/apps/msg_gen/main.cc @@ -30,6 +30,7 @@ DEFINE_uint32(msg_size, 64, "Size of the message (request/response) to send."); DEFINE_uint32(msg_window, 8, "Maximum number of messages in flight."); DEFINE_uint64(msg_nr, UINT64_MAX, "Number of messages to send."); DEFINE_bool(verify, false, "Verify payload of received messages."); +DEFINE_string(protocol, "udp", "Transport protocol: 'udp' or 'tcp'."); static volatile int g_keep_running = 1; @@ -326,6 +327,16 @@ int main(int argc, char *argv[]) { FLAGS_logtostderr = 1; CHECK_GT(FLAGS_msg_size, sizeof(app_hdr_t)) << "Message size too small"; + + int proto = MACHNET_PROTO_UDP; + if (FLAGS_protocol == "tcp") { + proto = MACHNET_PROTO_TCP; + LOG(INFO) << "Using TCP transport"; + } else { + CHECK(FLAGS_protocol == "udp") << "Unknown protocol: " << FLAGS_protocol; + LOG(INFO) << "Using UDP transport"; + } + if (FLAGS_remote_ip == "") { LOG(INFO) << "Starting in server mode, response size " << FLAGS_msg_size; } else { @@ -342,7 +353,8 @@ int main(int argc, char *argv[]) { // Client-mode int ret = machnet_connect(channel_ctx, FLAGS_local_ip.c_str(), - FLAGS_remote_ip.c_str(), FLAGS_remote_port, &flow); + FLAGS_remote_ip.c_str(), FLAGS_remote_port, &flow, + proto); CHECK(ret == 0) << 
"Failed to connect to remote host. machnet_connect() " "error: " << strerror(ret); @@ -353,7 +365,8 @@ int main(int argc, char *argv[]) { datapath_thread = std::thread(ClientLoop, channel_ctx, &flow); } else { int ret = - machnet_listen(channel_ctx, FLAGS_local_ip.c_str(), FLAGS_local_port); + machnet_listen(channel_ctx, FLAGS_local_ip.c_str(), FLAGS_local_port, + proto); CHECK(ret == 0) << "Failed to listen on local port. machnet_listen() error: " << strerror(ret); diff --git a/src/apps/rocksdb_server/rocksdb_server.cc b/src/apps/rocksdb_server/rocksdb_server.cc index 1bd3332d..9a2089b9 100644 --- a/src/apps/rocksdb_server/rocksdb_server.cc +++ b/src/apps/rocksdb_server/rocksdb_server.cc @@ -31,7 +31,7 @@ void MachnetTransportServer(rocksdb::DB *db) { void *channel = machnet_attach(); CHECK(channel != nullptr) << "machnet_attach() failed"; - ret = machnet_listen(channel, FLAGS_local.c_str(), kPort); + ret = machnet_listen(channel, FLAGS_local.c_str(), kPort, MACHNET_PROTO_UDP); CHECK_EQ(ret, 0) << "machnet_listen() failed"; // Handle client requests diff --git a/src/apps/tcp_msg_gen/CMakeLists.txt b/src/apps/tcp_msg_gen/CMakeLists.txt new file mode 100644 index 00000000..7c73d84b --- /dev/null +++ b/src/apps/tcp_msg_gen/CMakeLists.txt @@ -0,0 +1,3 @@ +set(target_name tcp_msg_gen) +add_executable(${target_name} main.cc) +target_compile_features(${target_name} PRIVATE cxx_std_17) diff --git a/src/apps/tcp_msg_gen/main.cc b/src/apps/tcp_msg_gen/main.cc new file mode 100644 index 00000000..62dd7f41 --- /dev/null +++ b/src/apps/tcp_msg_gen/main.cc @@ -0,0 +1,322 @@ +/** + * @file tcp_msg_gen.cc + * @brief Standalone Linux TCP message generator that speaks the same + * wire protocol as Machnet's msg_gen running with --protocol=tcp. + * + * Wire format (Machnet TCP framing): + * [4-byte big-endian message length][payload] + * + * The payload starts with an 8-byte app_hdr_t (window_slot), followed by + * filler bytes up to --msg_size. 
+ * + * Usage: + * Server (pairs with Machnet msg_gen client): + * ./tcp_msg_gen --local_ip=0.0.0.0 --local_port=888 --msg_size=64 + * + * Client (pairs with Machnet msg_gen server): + * ./tcp_msg_gen --local_ip=0.0.0.0 --remote_ip=10.0.0.2 \ + * --remote_port=888 --msg_size=64 --msg_window=8 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +// ──────── Configuration (mirrors msg_gen flags) ──────── + +static const char *g_local_ip = "0.0.0.0"; +static const char *g_remote_ip = nullptr; +static uint16_t g_local_port = 888; +static uint16_t g_remote_port = 888; +static uint32_t g_msg_size = 64; +static uint32_t g_msg_window = 8; + +static volatile int g_keep_running = 1; +void sig_handler(int) { g_keep_running = 0; } + +// ──────── app_hdr_t must match msg_gen exactly ──────── + +struct app_hdr_t { + uint64_t window_slot; +}; + +// ──────── Helpers ──────── + +/// Read exactly `len` bytes (blocking). +static bool read_exact(int fd, void *buf, size_t len) { + uint8_t *p = static_cast(buf); + while (len > 0) { + ssize_t n = read(fd, p, len); + if (n <= 0) return false; + p += n; + len -= n; + } + return true; +} + +/// Write exactly `len` bytes (blocking). +static bool write_exact(int fd, const void *buf, size_t len) { + const uint8_t *p = static_cast(buf); + while (len > 0) { + ssize_t n = write(fd, p, len); + if (n <= 0) return false; + p += n; + len -= n; + } + return true; +} + +/// Receive one framed message: 4-byte BE length prefix + payload. +/// Returns payload size, or -1 on error/disconnect. 
+static ssize_t recv_framed(int fd, void *buf, size_t buf_cap) { + uint32_t net_len; + if (!read_exact(fd, &net_len, 4)) return -1; + uint32_t msg_len = be32toh(net_len); + if (msg_len > buf_cap) { + fprintf(stderr, "Message too large: %u > %zu\n", msg_len, buf_cap); + return -1; + } + if (!read_exact(fd, buf, msg_len)) return -1; + return static_cast(msg_len); +} + +/// Send one framed message: 4-byte BE length prefix + payload. +static bool send_framed(int fd, const void *buf, size_t len) { + uint32_t net_len = htobe32(static_cast(len)); + if (!write_exact(fd, &net_len, 4)) return false; + if (!write_exact(fd, buf, len)) return false; + return true; +} + +// ──────── Stats ──────── + +struct Stats { + uint64_t tx_count = 0, tx_bytes = 0; + uint64_t rx_count = 0, rx_bytes = 0; +}; + +static void report_stats(Stats &cur, Stats &prev, + std::chrono::steady_clock::time_point &last_time) { + auto now = std::chrono::steady_clock::now(); + double elapsed = + std::chrono::duration(now - last_time).count(); + if (elapsed < 1.0) return; + + double tx_kmps = (cur.tx_count - prev.tx_count) / (1000 * elapsed); + double rx_kmps = (cur.rx_count - prev.rx_count) / (1000 * elapsed); + double tx_gbps = + ((cur.tx_bytes - prev.tx_bytes) * 8) / (elapsed * 1E9); + double rx_gbps = + ((cur.rx_bytes - prev.rx_bytes) * 8) / (elapsed * 1E9); + + std::cout << "TX/RX (msg/sec, Gbps): (" << std::fixed + << std::setprecision(1) << tx_kmps << "K/" << rx_kmps << "K" + << std::fixed << std::setprecision(3) << ", " << tx_gbps << "/" + << rx_gbps << ")" << std::endl; + + prev = cur; + last_time = now; +} + +// ──────── Server mode ──────── +// Listens, accepts one connection, echoes back every message. 
+ +static int run_server() { + int listenfd = socket(AF_INET, SOCK_STREAM, 0); + if (listenfd < 0) { perror("socket"); return 1; } + + int opt = 1; + setsockopt(listenfd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)); + setsockopt(listenfd, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt)); + + struct sockaddr_in addr{}; + addr.sin_family = AF_INET; + addr.sin_port = htons(g_local_port); + inet_pton(AF_INET, g_local_ip, &addr.sin_addr); + + if (bind(listenfd, (struct sockaddr *)&addr, sizeof(addr)) < 0) { + perror("bind"); return 1; + } + if (listen(listenfd, 1) < 0) { perror("listen"); return 1; } + + std::cout << "[SERVER] Listening on " << g_local_ip << ":" << g_local_port + << ", msg_size=" << g_msg_size << std::endl; + + struct sockaddr_in peer{}; + socklen_t peer_len = sizeof(peer); + int connfd = accept(listenfd, (struct sockaddr *)&peer, &peer_len); + if (connfd < 0) { perror("accept"); return 1; } + + opt = 1; + setsockopt(connfd, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt)); + char peer_ip[INET_ADDRSTRLEN]; + inet_ntop(AF_INET, &peer.sin_addr, peer_ip, sizeof(peer_ip)); + std::cout << "[SERVER] Accepted connection from " << peer_ip << ":" + << ntohs(peer.sin_port) << std::endl; + + std::vector rx_buf(8 * 1024 * 1024); + std::vector tx_buf(g_msg_size); + std::iota(tx_buf.begin(), tx_buf.end(), 0); + + Stats cur{}, prev{}; + auto last_time = std::chrono::steady_clock::now(); + + while (g_keep_running) { + ssize_t rx_size = recv_framed(connfd, rx_buf.data(), rx_buf.size()); + if (rx_size <= 0) break; + cur.rx_count++; + cur.rx_bytes += rx_size; + + // Copy the app header (window_slot) into the response. + const app_hdr_t *req = + reinterpret_cast(rx_buf.data()); + app_hdr_t *resp = reinterpret_cast(tx_buf.data()); + resp->window_slot = req->window_slot; + + if (!send_framed(connfd, tx_buf.data(), g_msg_size)) break; + cur.tx_count++; + cur.tx_bytes += g_msg_size; + + report_stats(cur, prev, last_time); + } + + std::cout << "[SERVER] Done. 
TX=" << cur.tx_count << " RX=" << cur.rx_count + << std::endl; + close(connfd); + close(listenfd); + return 0; +} + +// ──────── Client mode ──────── +// Connects to a remote server, sends a window of messages, then ping-pongs. + +static int run_client() { + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd < 0) { perror("socket"); return 1; } + + int opt = 1; + setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt)); + + struct sockaddr_in addr{}; + addr.sin_family = AF_INET; + addr.sin_port = htons(g_remote_port); + inet_pton(AF_INET, g_remote_ip, &addr.sin_addr); + + std::cout << "[CLIENT] Connecting to " << g_remote_ip << ":" + << g_remote_port << ", msg_size=" << g_msg_size + << ", window=" << g_msg_window << std::endl; + + if (connect(fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) { + perror("connect"); return 1; + } + std::cout << "[CLIENT] Connected." << std::endl; + + std::vector tx_buf(g_msg_size); + std::vector rx_buf(8 * 1024 * 1024); + std::iota(tx_buf.begin(), tx_buf.end(), 0); + + Stats cur{}, prev{}; + auto last_time = std::chrono::steady_clock::now(); + + // Send initial window. + for (uint32_t i = 0; i < g_msg_window; i++) { + app_hdr_t *hdr = reinterpret_cast(tx_buf.data()); + hdr->window_slot = i; + if (!send_framed(fd, tx_buf.data(), g_msg_size)) { + fprintf(stderr, "Failed to send initial window slot %u\n", i); + return 1; + } + cur.tx_count++; + cur.tx_bytes += g_msg_size; + } + + // Steady-state: receive one, send one. + while (g_keep_running) { + ssize_t rx_size = recv_framed(fd, rx_buf.data(), rx_buf.size()); + if (rx_size <= 0) break; + cur.rx_count++; + cur.rx_bytes += rx_size; + + const app_hdr_t *resp = + reinterpret_cast(rx_buf.data()); + uint64_t slot = resp->window_slot; + + // Send next message for this slot. 
+ app_hdr_t *req = reinterpret_cast(tx_buf.data()); + req->window_slot = slot; + if (!send_framed(fd, tx_buf.data(), g_msg_size)) break; + cur.tx_count++; + cur.tx_bytes += g_msg_size; + + report_stats(cur, prev, last_time); + } + + std::cout << "[CLIENT] Done. TX=" << cur.tx_count << " RX=" << cur.rx_count + << std::endl; + close(fd); + return 0; +} + +// ──────── Main ──────── + +static void usage(const char *prog) { + fprintf(stderr, + "Usage: %s --local_ip=IP [--local_port=PORT] [--remote_ip=IP] " + "[--remote_port=PORT] [--msg_size=N] [--msg_window=N]\n" + "\n" + " If --remote_ip is given, runs as client; otherwise as server.\n", + prog); +} + +int main(int argc, char *argv[]) { + signal(SIGINT, sig_handler); + signal(SIGPIPE, SIG_IGN); + + // Minimal flag parsing (--key=value style). + for (int i = 1; i < argc; i++) { + std::string arg(argv[i]); + auto eq = arg.find('='); + if (eq == std::string::npos) { usage(argv[0]); return 1; } + std::string key = arg.substr(0, eq); + std::string val = arg.substr(eq + 1); + + if (key == "--local_ip") g_local_ip = strdup(val.c_str()); + else if (key == "--local_port") g_local_port = atoi(val.c_str()); + else if (key == "--remote_ip") g_remote_ip = strdup(val.c_str()); + else if (key == "--remote_port") g_remote_port = atoi(val.c_str()); + else if (key == "--msg_size") g_msg_size = atoi(val.c_str()); + else if (key == "--msg_window") g_msg_window = atoi(val.c_str()); + else { fprintf(stderr, "Unknown flag: %s\n", key.c_str()); return 1; } + } + + if (g_msg_size < sizeof(app_hdr_t)) { + fprintf(stderr, "msg_size must be >= %zu\n", sizeof(app_hdr_t)); + return 1; + } + + if (g_remote_ip != nullptr) { + return run_client(); + } else { + return run_server(); + } +} diff --git a/src/core/drivers/shm/channel.cc b/src/core/drivers/shm/channel.cc index 3e9e82e5..d3c511ff 100644 --- a/src/core/drivers/shm/channel.cc +++ b/src/core/drivers/shm/channel.cc @@ -6,6 +6,7 @@ #include #include #include +#include namespace juggler { 
namespace shm { @@ -171,5 +172,10 @@ void Channel::RemoveFlow( active_flows_.erase(flow_it); } +void Channel::RemoveTcpFlow( + const std::list>::const_iterator &flow_it) { + active_tcp_flows_.erase(flow_it); +} + } // namespace shm } // namespace juggler diff --git a/src/core/net/tcp.cc b/src/core/net/tcp.cc new file mode 100644 index 00000000..e3fa5829 --- /dev/null +++ b/src/core/net/tcp.cc @@ -0,0 +1,17 @@ +#include + +namespace juggler { +namespace net { + +std::string Tcp::ToString() const { + return juggler::utils::Format( + "[TCP: src_port %u, dst_port %u, seq %u, ack %u, flags 0x%02x, " + "win %u]", + static_cast(src_port.port.value()), + static_cast(dst_port.port.value()), seq_num.value(), + ack_num.value(), static_cast(flags), + static_cast(window.value())); +} + +} // namespace net +} // namespace juggler diff --git a/src/core/tcp_flow_test.cc b/src/core/tcp_flow_test.cc new file mode 100644 index 00000000..871d9cd7 --- /dev/null +++ b/src/core/tcp_flow_test.cc @@ -0,0 +1,1266 @@ +/** + * @file tcp_flow_test.cc + * + * Unit tests for the TcpFlow class (src/include/tcp_flow.h). + * + * These tests exercise: + * - TCP state machine transitions (handshake, shutdown, RST) + * - MSS option parsing from TCP headers + * - Data TX path (OutputMessage with length-prefix framing) + * - Data RX path (ConsumePayload de-framing into MsgBufs) + * - Retransmission overlap handling (ProcessInOrderPayload) + * - PeriodicCheck (RTO, TIME_WAIT countdown) + * - Sequence number comparison helpers + */ + +// White-box testing: access private members of TcpFlow. +// The define must come AFTER all system/STL/gtest includes. 
+ +#include +#include + +#include +#include +#include +#include +#include + +#include "channel.h" +#include "channel_msgbuf.h" +#include "common.h" +#include "dpdk.h" +#include "ether.h" +#include "ipv4.h" +#include "machnet.h" +#include "machnet_common.h" +#include "packet.h" +#include "packet_pool.h" +#include "pmd.h" +#include "tcp.h" +#include "utils.h" + +// Access private members of TcpFlow. +#define private public +#define protected public + +#include "tcp_flow.h" + +#undef private +#undef protected + +namespace juggler { +namespace net { +namespace flow { + +// ───────────────── Test Fixture ───────────────── + +class TcpFlowTest : public ::testing::Test { + protected: + static constexpr uint32_t kChannelRingSize = 1024; + + // PmdPort and its TxRing are expensive to create and cannot survive + // destroy-then-recreate, so we share them across all tests. + static std::shared_ptr pmd_port_; + static dpdk::TxRing* txring_; + + static void SetUpTestSuite() { + // Use port 1 (the net_null virtual device). + pmd_port_ = std::make_shared(1, 1, 1, 512, 512); + pmd_port_->InitDriver(); + txring_ = pmd_port_->GetRing(0); + } + static void TearDownTestSuite() { + pmd_port_.reset(); + } + + void SetUp() override { + // Set up addresses. + local_addr_.FromString("10.0.0.1"); + remote_addr_.FromString("10.0.0.2"); + local_port_ = Tcp::Port(5000); + remote_port_ = Tcp::Port(6000); + local_mac_ = Ethernet::Address("00:00:00:00:00:01"); + remote_mac_ = Ethernet::Address("00:00:00:00:00:02"); + + // Create a channel for MsgBuf allocation. + static int channel_id = 0; + std::string chan_name = "tcp_flow_test_" + std::to_string(channel_id++); + channel_mgr_.AddChannel(chan_name.c_str(), kChannelRingSize, + kChannelRingSize, kChannelRingSize, + kChannelRingSize); + channel_ = channel_mgr_.GetChannel(chan_name.c_str()); + + // Create a separate packet pool for crafting RX packets. 
+ pkt_pool_ = std::make_unique( + 8192, dpdk::PmdRing::kDefaultFrameSize); + + // Callback tracking. + callback_called_ = false; + callback_success_ = false; + } + + void TearDown() override {} + + // ── Helpers ── + + /// Application callback that records whether it was invoked. + void AppCallback(shm::Channel* /*channel*/, bool success, const Key& /*key*/) { + callback_called_ = true; + callback_success_ = success; + } + + /// Create a new TcpFlow in kClosed state. + std::unique_ptr MakeFlow() { + return std::make_unique( + local_addr_, local_port_, remote_addr_, remote_port_, local_mac_, + remote_mac_, txring_, [this](auto* ch, bool ok, const auto& k) { + AppCallback(ch, ok, k); + }, + channel_.get()); + } + + /** + * @brief Build a fake TCP packet (Ethernet + IP + TCP + payload). + * + * The caller specifies seq, ack, flags, and an optional payload buffer. + * TCP header length is 20 bytes (no options) unless extra_opts is supplied. + */ + dpdk::Packet* MakePacket(uint32_t seq, uint32_t ack, uint8_t flags, + const uint8_t* payload = nullptr, + size_t payload_len = 0, + const uint8_t* tcp_opts = nullptr, + size_t tcp_opts_len = 0) { + const size_t tcp_hdr_len = sizeof(Tcp) + tcp_opts_len; + const size_t hdr_len = sizeof(Ethernet) + sizeof(Ipv4) + tcp_hdr_len; + const size_t total_len = hdr_len + payload_len; + + dpdk::Packet* pkt = nullptr; + CHECK(pkt_pool_->PacketBulkAlloc(&pkt, 1)); + dpdk::Packet::Reset(pkt); + CHECK_NOTNULL(pkt->append(static_cast(total_len))); + + // Ethernet header. + auto* eh = pkt->head_data(); + eh->src_addr = remote_mac_; + eh->dst_addr = local_mac_; + eh->eth_type = be16_t(Ethernet::kIpv4); + + // IPv4 header. total_length = IP + TCP + payload (no Ethernet). 
+ auto* ipv4h = pkt->head_data(sizeof(Ethernet)); + ipv4h->version_ihl = 0x45; + ipv4h->type_of_service = 0; + ipv4h->total_length = + be16_t(static_cast(sizeof(Ipv4) + tcp_hdr_len + payload_len)); + ipv4h->packet_id = be16_t(0); + ipv4h->fragment_offset = be16_t(0); + ipv4h->time_to_live = 64; + ipv4h->next_proto_id = Ipv4::Proto::kTcp; + ipv4h->hdr_checksum = 0; + ipv4h->src_addr = remote_addr_; + ipv4h->dst_addr = local_addr_; + + // TCP header. + auto* tcph = pkt->head_data(sizeof(Ethernet) + sizeof(Ipv4)); + tcph->src_port = remote_port_; + tcph->dst_port = local_port_; + tcph->seq_num = be32_t(seq); + tcph->ack_num = be32_t(ack); + tcph->set_header_length(static_cast(tcp_hdr_len)); + tcph->flags = flags; + tcph->window = be16_t(65535); + tcph->checksum = 0; + tcph->urgent_ptr = be16_t(0); + + // TCP options (if any). + if (tcp_opts != nullptr && tcp_opts_len > 0) { + uint8_t* opt_dst = reinterpret_cast(tcph) + sizeof(Tcp); + std::memcpy(opt_dst, tcp_opts, tcp_opts_len); + } + + // Payload. + if (payload != nullptr && payload_len > 0) { + uint8_t* dst = pkt->head_data( + static_cast(hdr_len)); + std::memcpy(dst, payload, payload_len); + } + + return pkt; + } + + /// Build the standard MSS option bytes (Kind=2, Length=4, MSS value). + static std::vector MakeMSSOption(uint16_t mss) { + std::vector opt(4); + opt[0] = 2; // Kind: MSS + opt[1] = 4; // Length + uint16_t mss_net = htobe16(mss); + std::memcpy(&opt[2], &mss_net, sizeof(mss_net)); + return opt; + } + + /// Create a MsgBuf train of the given size, filled with sequential bytes. 
+ shm::MsgBuf* CreateMsg(size_t msg_len) { + std::vector data(msg_len); + for (size_t i = 0; i < msg_len; i++) { + data[i] = static_cast(i & 0xFF); + } + return CreateMsgFromData(data); + } + + shm::MsgBuf* CreateMsgFromData(const std::vector& data) { + const auto msgbuf_size = channel_->GetUsableBufSize(); + auto num_msgbufs = + (data.size() + msgbuf_size - 1) / msgbuf_size; + + shm::MsgBuf* head = nullptr; + shm::MsgBuf* tail = nullptr; + size_t data_offset = 0; + + while (num_msgbufs > 0) { + shm::MsgBufBatch batch; + auto msgbuf_nr = + std::min(num_msgbufs, static_cast(batch.GetRoom())); + CHECK(channel_->MsgBufBulkAlloc(&batch, msgbuf_nr)); + + for (auto i = 0; i < batch.GetSize(); i++) { + auto* msgbuf = batch[i]; + auto nbytes = std::min(static_cast(msgbuf_size), + data.size() - data_offset); + utils::Copy(CHECK_NOTNULL(msgbuf->append(nbytes)), + &data[data_offset], nbytes); + data_offset += nbytes; + + if (head == nullptr) { + head = msgbuf; + tail = msgbuf; + } else { + tail->set_next(msgbuf); + tail = msgbuf; + } + } + num_msgbufs -= msgbuf_nr; + batch.Clear(); + } + + head->set_msg_length(data.size()); + head->set_last(tail->index()); + head->mark_first(); + tail->mark_last(); + return head; + } + + // ── Members ── + Ipv4::Address local_addr_; + Ipv4::Address remote_addr_; + Tcp::Port local_port_; + Tcp::Port remote_port_; + Ethernet::Address local_mac_; + Ethernet::Address remote_mac_; + shm::ChannelManager channel_mgr_; + std::shared_ptr channel_; + std::unique_ptr pkt_pool_; + + bool callback_called_; + bool callback_success_; +}; + +// Static member definitions. +std::shared_ptr TcpFlowTest::pmd_port_; +dpdk::TxRing* TcpFlowTest::txring_; + +// ═══════════════════════════════════════════════════════════════ +// Sequence Number Comparison Helpers +// ═══════════════════════════════════════════════════════════════ + +TEST_F(TcpFlowTest, SeqComparisons) { + // Basic ordering. 
+ EXPECT_TRUE(TcpFlow::SeqLt(1, 2)); + EXPECT_FALSE(TcpFlow::SeqLt(2, 1)); + EXPECT_TRUE(TcpFlow::SeqLeq(2, 2)); + EXPECT_TRUE(TcpFlow::SeqGt(3, 2)); + EXPECT_FALSE(TcpFlow::SeqGt(2, 3)); + EXPECT_TRUE(TcpFlow::SeqGeq(3, 3)); + + // Wrapping: 0xFFFFFFFF < 0x00000000 in TCP sequence space. + EXPECT_TRUE(TcpFlow::SeqLt(0xFFFFFFFF, 0)); + EXPECT_TRUE(TcpFlow::SeqGt(0, 0xFFFFFFFF)); + EXPECT_TRUE(TcpFlow::SeqLt(0xFFFFFFFE, 0x00000001)); +} + +// ═══════════════════════════════════════════════════════════════ +// Construction & Initial State +// ═══════════════════════════════════════════════════════════════ + +TEST_F(TcpFlowTest, InitialState) { + auto flow = MakeFlow(); + EXPECT_EQ(flow->state(), TcpFlow::State::kClosed); + EXPECT_EQ(flow->channel(), channel_.get()); + EXPECT_EQ(flow->key().local_addr, local_addr_); + EXPECT_EQ(flow->key().remote_addr, remote_addr_); + EXPECT_EQ(flow->key().local_port, local_port_); + EXPECT_EQ(flow->key().remote_port, remote_port_); + EXPECT_EQ(flow->peer_mss_, static_cast(TcpFlow::kDefaultMSS)); + EXPECT_EQ(flow->snd_nxt_, flow->snd_isn_); + EXPECT_EQ(flow->snd_una_, flow->snd_isn_); +} + +// ═══════════════════════════════════════════════════════════════ +// Active Open (Client) Handshake +// ═══════════════════════════════════════════════════════════════ + +TEST_F(TcpFlowTest, ActiveOpen_InitiateHandshake) { + auto flow = MakeFlow(); + + // InitiateHandshake should send SYN and transition to SYN_SENT. + flow->InitiateHandshake(); + EXPECT_EQ(flow->state(), TcpFlow::State::kSynSent); + EXPECT_TRUE(flow->rto_active_); + // SYN consumes one seq: snd_nxt_ should be snd_isn_ + 1. + EXPECT_EQ(flow->snd_nxt_, flow->snd_isn_ + 1); +} + +TEST_F(TcpFlowTest, ActiveOpen_FullHandshake) { + auto flow = MakeFlow(); + flow->InitiateHandshake(); + + uint32_t client_isn = flow->snd_isn_; + uint32_t server_isn = 42000; + + // Server responds with SYN-ACK, acknowledging our SYN. 
+ auto* syn_ack = MakePacket(server_isn, client_isn + 1, + Tcp::kSyn | Tcp::kAck); + flow->InputPacket(syn_ack); + + // Should be ESTABLISHED. + EXPECT_EQ(flow->state(), TcpFlow::State::kEstablished); + EXPECT_EQ(flow->rcv_nxt_, server_isn + 1); + EXPECT_EQ(flow->snd_una_, client_isn + 1); + EXPECT_FALSE(flow->rto_active_); + // Callback should have been invoked with success. + EXPECT_TRUE(callback_called_); + EXPECT_TRUE(callback_success_); + + dpdk::Packet::Free(syn_ack); +} + +TEST_F(TcpFlowTest, ActiveOpen_SynAckWrongAck) { + auto flow = MakeFlow(); + flow->InitiateHandshake(); + + uint32_t client_isn = flow->snd_isn_; + uint32_t server_isn = 42000; + + // Wrong ack — should stay in SYN_SENT. + auto* bad_syn_ack = MakePacket(server_isn, client_isn + 99, + Tcp::kSyn | Tcp::kAck); + flow->InputPacket(bad_syn_ack); + + EXPECT_EQ(flow->state(), TcpFlow::State::kSynSent); + EXPECT_FALSE(callback_called_); + + dpdk::Packet::Free(bad_syn_ack); +} + +// ═══════════════════════════════════════════════════════════════ +// Passive Open (Server) Handshake +// ═══════════════════════════════════════════════════════════════ + +TEST_F(TcpFlowTest, PassiveOpen_StartPassiveOpen) { + auto flow = MakeFlow(); + flow->StartPassiveOpen(); + EXPECT_EQ(flow->state(), TcpFlow::State::kListen); +} + +TEST_F(TcpFlowTest, PassiveOpen_FullHandshake) { + auto flow = MakeFlow(); + flow->StartPassiveOpen(); + + uint32_t client_isn = 12345; + uint32_t server_isn = flow->snd_isn_; + + // Client sends SYN. + auto* syn = MakePacket(client_isn, 0, Tcp::kSyn); + flow->InputPacket(syn); + EXPECT_EQ(flow->state(), TcpFlow::State::kSynReceived); + EXPECT_EQ(flow->rcv_nxt_, client_isn + 1); + EXPECT_TRUE(flow->rto_active_); + + // Client acknowledges our SYN-ACK. 
+ auto* ack = MakePacket(client_isn + 1, server_isn + 1, Tcp::kAck); + flow->InputPacket(ack); + EXPECT_EQ(flow->state(), TcpFlow::State::kEstablished); + EXPECT_EQ(flow->snd_una_, server_isn + 1); + EXPECT_FALSE(flow->rto_active_); + + dpdk::Packet::Free(syn); + dpdk::Packet::Free(ack); +} + +TEST_F(TcpFlowTest, PassiveOpen_SynWithMSSOption) { + auto flow = MakeFlow(); + flow->StartPassiveOpen(); + + uint32_t client_isn = 12345; + auto mss_opt = MakeMSSOption(1460); + + // Client sends SYN with MSS option. + auto* syn = MakePacket(client_isn, 0, Tcp::kSyn, + nullptr, 0, + mss_opt.data(), mss_opt.size()); + flow->InputPacket(syn); + + EXPECT_EQ(flow->state(), TcpFlow::State::kSynReceived); + EXPECT_EQ(flow->peer_mss_, 1460); + + dpdk::Packet::Free(syn); +} + +TEST_F(TcpFlowTest, PassiveOpen_PiggybackOnAck) { + // Linux kernel can piggyback data on the completing handshake ACK. + auto flow = MakeFlow(); + flow->StartPassiveOpen(); + + uint32_t client_isn = 12345; + uint32_t server_isn = flow->snd_isn_; + + // SYN. + auto* syn = MakePacket(client_isn, 0, Tcp::kSyn); + flow->InputPacket(syn); + + // ACK with some payload. + uint8_t payload[] = {0xDE, 0xAD, 0xBE, 0xEF, 0x01, 0x02, 0x03, 0x04}; + auto* ack = MakePacket(client_isn + 1, server_isn + 1, Tcp::kAck, + payload, sizeof(payload)); + flow->InputPacket(ack); + + EXPECT_EQ(flow->state(), TcpFlow::State::kEstablished); + // Data should have been consumed. 
+ EXPECT_EQ(flow->rcv_nxt_, client_isn + 1 + sizeof(payload)); + + dpdk::Packet::Free(syn); + dpdk::Packet::Free(ack); +} + +// ═══════════════════════════════════════════════════════════════ +// RST Handling +// ═══════════════════════════════════════════════════════════════ + +TEST_F(TcpFlowTest, RstInEstablished) { + auto flow = MakeFlow(); + flow->state_ = TcpFlow::State::kEstablished; + flow->rcv_nxt_ = 1000; + + auto* rst = MakePacket(1000, 0, Tcp::kRst); + flow->InputPacket(rst); + EXPECT_EQ(flow->state(), TcpFlow::State::kClosed); + + dpdk::Packet::Free(rst); +} + +TEST_F(TcpFlowTest, RstInSynSent) { + auto flow = MakeFlow(); + flow->InitiateHandshake(); + + auto* rst = MakePacket(0, 0, Tcp::kRst); + flow->InputPacket(rst); + EXPECT_EQ(flow->state(), TcpFlow::State::kClosed); + + dpdk::Packet::Free(rst); +} + +// ═══════════════════════════════════════════════════════════════ +// MSS Option Parsing +// ═══════════════════════════════════════════════════════════════ + +TEST_F(TcpFlowTest, ParseTcpOptions_MSSOnly) { + auto flow = MakeFlow(); + // Build a fake TCP header with MSS option. + uint8_t buf[24] = {}; + auto* tcph = reinterpret_cast(buf); + tcph->set_header_length(24); + // MSS option at offset 20 (after base header). + buf[20] = 2; // Kind: MSS + buf[21] = 4; // Length + uint16_t mss_net = htobe16(8960); + std::memcpy(&buf[22], &mss_net, 2); + + flow->ParseTcpOptions(tcph, 24); + EXPECT_EQ(flow->peer_mss_, 8960); +} + +TEST_F(TcpFlowTest, ParseTcpOptions_MSSWithNopPadding) { + auto flow = MakeFlow(); + // Build TCP header with NOP + NOP + MSS + EOL. 
+ uint8_t buf[28] = {}; + auto* tcph = reinterpret_cast(buf); + tcph->set_header_length(28); + buf[20] = 1; // NOP + buf[21] = 1; // NOP + buf[22] = 2; // Kind: MSS + buf[23] = 4; // Length + uint16_t mss_net = htobe16(1460); + std::memcpy(&buf[24], &mss_net, 2); + buf[26] = 0; // EOL + + flow->ParseTcpOptions(tcph, 28); + EXPECT_EQ(flow->peer_mss_, 1460); +} + +TEST_F(TcpFlowTest, ParseTcpOptions_NoOptions) { + auto flow = MakeFlow(); + uint8_t buf[20] = {}; + auto* tcph = reinterpret_cast(buf); + tcph->set_header_length(20); + + // peer_mss_ should remain at default. + flow->ParseTcpOptions(tcph, 20); + EXPECT_EQ(flow->peer_mss_, static_cast(TcpFlow::kDefaultMSS)); +} + +TEST_F(TcpFlowTest, ParseTcpOptions_ZeroMSS_FallsBackToDefault) { + auto flow = MakeFlow(); + uint8_t buf[24] = {}; + auto* tcph = reinterpret_cast(buf); + tcph->set_header_length(24); + buf[20] = 2; // Kind: MSS + buf[21] = 4; // Length + buf[22] = 0; // MSS = 0 + buf[23] = 0; + + flow->ParseTcpOptions(tcph, 24); + EXPECT_EQ(flow->peer_mss_, static_cast(TcpFlow::kDefaultMSS)); +} + +// ═══════════════════════════════════════════════════════════════ +// TCP Header Length Validation +// ═══════════════════════════════════════════════════════════════ + +TEST_F(TcpFlowTest, InputPacket_InvalidHeaderLength) { + auto flow = MakeFlow(); + flow->state_ = TcpFlow::State::kEstablished; + flow->rcv_nxt_ = 1000; + + // Craft a packet with TCP header length of 16 (< 20, invalid). + auto* pkt = MakePacket(1000, 0, Tcp::kAck); + auto* tcph = pkt->head_data(sizeof(Ethernet) + sizeof(Ipv4)); + tcph->set_header_length(16); // Invalid: below minimum. + + uint32_t old_rcv_nxt = flow->rcv_nxt_; + flow->InputPacket(pkt); + // Should be a no-op — state unchanged, no data consumed. 
+ EXPECT_EQ(flow->state(), TcpFlow::State::kEstablished); + EXPECT_EQ(flow->rcv_nxt_, old_rcv_nxt); + + dpdk::Packet::Free(pkt); +} + +// ═══════════════════════════════════════════════════════════════ +// Data TX Path (OutputMessage) +// ═══════════════════════════════════════════════════════════════ + +TEST_F(TcpFlowTest, OutputMessage_SmallMessage) { + auto flow = MakeFlow(); + flow->state_ = TcpFlow::State::kEstablished; + flow->rcv_nxt_ = 5000; + uint32_t initial_snd_nxt = flow->snd_nxt_; + + // Create a small message (64 bytes). + auto* msgbuf = CreateMsg(64); + + flow->OutputMessage(msgbuf); + + // snd_nxt_ should advance by 4 (length prefix) + 64 (payload). + EXPECT_EQ(flow->snd_nxt_, initial_snd_nxt + 4 + 64); + // RTO should be armed. + EXPECT_TRUE(flow->rto_active_); +} + +TEST_F(TcpFlowTest, OutputMessage_NotEstablished) { + auto flow = MakeFlow(); + // Try to send in CLOSED state — should fail gracefully. + flow->state_ = TcpFlow::State::kClosed; + + auto* msgbuf = CreateMsg(64); + uint32_t initial_snd_nxt = flow->snd_nxt_; + + flow->OutputMessage(msgbuf); + // snd_nxt_ should not change. + EXPECT_EQ(flow->snd_nxt_, initial_snd_nxt); +} + +TEST_F(TcpFlowTest, OutputMessage_LargerThanMSS) { + auto flow = MakeFlow(); + flow->state_ = TcpFlow::State::kEstablished; + flow->rcv_nxt_ = 5000; + flow->peer_mss_ = 100; // Small MSS for segmentation. + uint32_t initial_snd_nxt = flow->snd_nxt_; + + // Create a message larger than MSS. + const size_t msg_len = 300; + auto* msgbuf = CreateMsg(msg_len); + + flow->OutputMessage(msgbuf); + + // Total bytes on wire: 4 (prefix) + 300 (payload) = 304 bytes. 
+ EXPECT_EQ(flow->snd_nxt_, initial_snd_nxt + 4 + msg_len); +} + +// ═══════════════════════════════════════════════════════════════ +// Data RX Path (ConsumePayload) +// ═══════════════════════════════════════════════════════════════ + +TEST_F(TcpFlowTest, ConsumePayload_SingleMessage) { + auto flow = MakeFlow(); + flow->state_ = TcpFlow::State::kEstablished; + + // Build a framed message: 4-byte length prefix + payload. + const std::string payload_str = "Hello, TCP world!"; + const uint32_t msg_len = static_cast(payload_str.size()); + uint32_t net_len = htobe32(msg_len); + + std::vector wire_data(TcpFlow::kMsgLenPrefixSize + msg_len); + std::memcpy(wire_data.data(), &net_len, TcpFlow::kMsgLenPrefixSize); + std::memcpy(wire_data.data() + TcpFlow::kMsgLenPrefixSize, + payload_str.data(), msg_len); + + flow->ConsumePayload(wire_data.data(), wire_data.size()); + + // Message should have been delivered to the channel. + // Verify by receiving from the channel. + std::vector rx_buf(msg_len); + MachnetIovec_t iov; + iov.base = rx_buf.data(); + iov.len = rx_buf.size(); + MachnetMsgHdr_t msghdr; + msghdr.flags = 0; + msghdr.flow_info = {0, 0, 0, 0}; + msghdr.msg_iov = &iov; + msghdr.msg_iovlen = 1; + + auto ret = machnet_recvmsg(channel_->ctx(), &msghdr); + EXPECT_EQ(ret, 1) << "Message should have been delivered to channel"; + EXPECT_EQ(std::string(rx_buf.begin(), rx_buf.end()), payload_str); +} + +TEST_F(TcpFlowTest, ConsumePayload_SplitAcrossPackets) { + auto flow = MakeFlow(); + flow->state_ = TcpFlow::State::kEstablished; + + // Build a framed message. 
+ const uint32_t msg_len = 100; + uint32_t net_len = htobe32(msg_len); + std::vector payload(msg_len); + for (uint32_t i = 0; i < msg_len; i++) payload[i] = static_cast(i); + + std::vector wire_data(TcpFlow::kMsgLenPrefixSize + msg_len); + std::memcpy(wire_data.data(), &net_len, TcpFlow::kMsgLenPrefixSize); + std::memcpy(wire_data.data() + TcpFlow::kMsgLenPrefixSize, + payload.data(), msg_len); + + // Split at the middle of the length prefix (2 bytes). + flow->ConsumePayload(wire_data.data(), 2); + // Should not have a pending message yet. + EXPECT_EQ(flow->rx_pending_msg_len_, 0u); + + // Send the rest. + flow->ConsumePayload(wire_data.data() + 2, wire_data.size() - 2); + + // Verify message was delivered. + std::vector rx_buf(msg_len); + MachnetIovec_t iov; + iov.base = rx_buf.data(); + iov.len = rx_buf.size(); + MachnetMsgHdr_t msghdr; + msghdr.flags = 0; + msghdr.flow_info = {0, 0, 0, 0}; + msghdr.msg_iov = &iov; + msghdr.msg_iovlen = 1; + + auto ret = machnet_recvmsg(channel_->ctx(), &msghdr); + EXPECT_EQ(ret, 1); + EXPECT_EQ(rx_buf, payload); +} + +TEST_F(TcpFlowTest, ConsumePayload_MultipleMessages) { + auto flow = MakeFlow(); + flow->state_ = TcpFlow::State::kEstablished; + + // Two back-to-back messages in a single payload. + const uint32_t msg1_len = 10; + const uint32_t msg2_len = 20; + + std::vector msg1_payload(msg1_len, 0xAA); + std::vector msg2_payload(msg2_len, 0xBB); + + std::vector wire_data; + uint32_t net_len; + + // Message 1. + net_len = htobe32(msg1_len); + wire_data.insert(wire_data.end(), reinterpret_cast(&net_len), + reinterpret_cast(&net_len) + 4); + wire_data.insert(wire_data.end(), msg1_payload.begin(), msg1_payload.end()); + + // Message 2. + net_len = htobe32(msg2_len); + wire_data.insert(wire_data.end(), reinterpret_cast(&net_len), + reinterpret_cast(&net_len) + 4); + wire_data.insert(wire_data.end(), msg2_payload.begin(), msg2_payload.end()); + + flow->ConsumePayload(wire_data.data(), wire_data.size()); + + // Read message 1. 
+ { + std::vector rx_buf(msg1_len); + MachnetIovec_t iov{rx_buf.data(), rx_buf.size()}; + MachnetMsgHdr_t msghdr; + msghdr.flags = 0; + msghdr.flow_info = {0, 0, 0, 0}; + msghdr.msg_iov = &iov; + msghdr.msg_iovlen = 1; + EXPECT_EQ(machnet_recvmsg(channel_->ctx(), &msghdr), 1); + EXPECT_EQ(rx_buf, msg1_payload); + } + // Read message 2. + { + std::vector rx_buf(msg2_len); + MachnetIovec_t iov{rx_buf.data(), rx_buf.size()}; + MachnetMsgHdr_t msghdr; + msghdr.flags = 0; + msghdr.flow_info = {0, 0, 0, 0}; + msghdr.msg_iov = &iov; + msghdr.msg_iovlen = 1; + EXPECT_EQ(machnet_recvmsg(channel_->ctx(), &msghdr), 1); + EXPECT_EQ(rx_buf, msg2_payload); + } +} + +// ═══════════════════════════════════════════════════════════════ +// HandleEstablished — Data + ACK +// ═══════════════════════════════════════════════════════════════ + +TEST_F(TcpFlowTest, HandleEstablished_DataPacket) { + auto flow = MakeFlow(); + flow->state_ = TcpFlow::State::kEstablished; + flow->rcv_nxt_ = 1000; + flow->snd_una_ = 500; + flow->snd_nxt_ = 500; + + // Build a framed message as payload. + const uint32_t msg_len = 16; + uint32_t net_len = htobe32(msg_len); + std::vector payload(TcpFlow::kMsgLenPrefixSize + msg_len); + std::memcpy(payload.data(), &net_len, TcpFlow::kMsgLenPrefixSize); + for (uint32_t i = 0; i < msg_len; i++) { + payload[TcpFlow::kMsgLenPrefixSize + i] = static_cast(i); + } + + auto* pkt = MakePacket(1000, 500, Tcp::kAck | Tcp::kPsh, + payload.data(), payload.size()); + flow->InputPacket(pkt); + + // rcv_nxt should advance by the payload size. + EXPECT_EQ(flow->rcv_nxt_, 1000 + payload.size()); + + // Read the delivered message. 
+ std::vector rx_buf(msg_len); + MachnetIovec_t iov{rx_buf.data(), rx_buf.size()}; + MachnetMsgHdr_t msghdr; + msghdr.flags = 0; + msghdr.flow_info = {0, 0, 0, 0}; + msghdr.msg_iov = &iov; + msghdr.msg_iovlen = 1; + EXPECT_EQ(machnet_recvmsg(channel_->ctx(), &msghdr), 1); + + for (uint32_t i = 0; i < msg_len; i++) { + EXPECT_EQ(rx_buf[i], static_cast(i)); + } + + dpdk::Packet::Free(pkt); +} + +// ═══════════════════════════════════════════════════════════════ +// HandleEstablished — Retransmitted SYN +// ═══════════════════════════════════════════════════════════════ + +TEST_F(TcpFlowTest, HandleEstablished_RetransmittedSyn) { + // In ESTABLISHED, receiving a SYN should re-send ACK (kernel missed it). + auto flow = MakeFlow(); + flow->state_ = TcpFlow::State::kEstablished; + flow->rcv_nxt_ = 2000; + flow->snd_nxt_ = 3000; + flow->snd_una_ = 3000; + + auto* syn = MakePacket(1999, 0, Tcp::kSyn); + auto state_before = flow->state(); + flow->InputPacket(syn); + + // State should not change — we just re-ACK. + EXPECT_EQ(flow->state(), state_before); + EXPECT_EQ(flow->rcv_nxt_, 2000u); + + dpdk::Packet::Free(syn); +} + +// ═══════════════════════════════════════════════════════════════ +// Retransmission Overlap (ProcessInOrderPayload) +// ═══════════════════════════════════════════════════════════════ + +TEST_F(TcpFlowTest, ProcessInOrderPayload_ExactMatch) { + auto flow = MakeFlow(); + flow->state_ = TcpFlow::State::kEstablished; + flow->rcv_nxt_ = 1000; + + // Segment exactly at rcv_nxt_. 
+ uint8_t data[] = {1, 2, 3, 4, 5}; + auto* pkt = MakePacket(1000, 0, Tcp::kAck, data, sizeof(data)); + + bool consumed = flow->ProcessInOrderPayload( + pkt, 1000, sizeof(data), + sizeof(Ethernet) + sizeof(Ipv4) + sizeof(Tcp)); + + EXPECT_TRUE(consumed); + EXPECT_EQ(flow->rcv_nxt_, 1005u); + + dpdk::Packet::Free(pkt); +} + +TEST_F(TcpFlowTest, ProcessInOrderPayload_PartialOverlap) { + auto flow = MakeFlow(); + flow->state_ = TcpFlow::State::kEstablished; + flow->rcv_nxt_ = 1003; // We already received bytes 1000-1002. + + // Retransmitted segment starts at 1000 but extends to 1009. + uint8_t data[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + auto* pkt = MakePacket(1000, 0, Tcp::kAck, data, sizeof(data)); + + bool consumed = flow->ProcessInOrderPayload( + pkt, 1000, sizeof(data), + sizeof(Ethernet) + sizeof(Ipv4) + sizeof(Tcp)); + + EXPECT_TRUE(consumed); + // Should have consumed only the new part: bytes 1003-1009 = 7 bytes. + EXPECT_EQ(flow->rcv_nxt_, 1010u); + + dpdk::Packet::Free(pkt); +} + +TEST_F(TcpFlowTest, ProcessInOrderPayload_PureDuplicate) { + auto flow = MakeFlow(); + flow->state_ = TcpFlow::State::kEstablished; + flow->rcv_nxt_ = 1010; // Already past this segment. + + uint8_t data[5] = {1, 2, 3, 4, 5}; + auto* pkt = MakePacket(1005, 0, Tcp::kAck, data, sizeof(data)); + + bool consumed = flow->ProcessInOrderPayload( + pkt, 1005, sizeof(data), + sizeof(Ethernet) + sizeof(Ipv4) + sizeof(Tcp)); + + // Pure duplicate — nothing consumed. 
+ EXPECT_FALSE(consumed); + EXPECT_EQ(flow->rcv_nxt_, 1010u); + + dpdk::Packet::Free(pkt); +} + +TEST_F(TcpFlowTest, ProcessInOrderPayload_EmptyPayload) { + auto flow = MakeFlow(); + flow->state_ = TcpFlow::State::kEstablished; + flow->rcv_nxt_ = 1000; + + auto* pkt = MakePacket(1000, 0, Tcp::kAck); + + bool consumed = flow->ProcessInOrderPayload( + pkt, 1000, 0, sizeof(Ethernet) + sizeof(Ipv4) + sizeof(Tcp)); + + EXPECT_FALSE(consumed); + EXPECT_EQ(flow->rcv_nxt_, 1000u); + + dpdk::Packet::Free(pkt); +} + +// ═══════════════════════════════════════════════════════════════ +// FIN Handling +// ═══════════════════════════════════════════════════════════════ + +TEST_F(TcpFlowTest, FIN_InEstablished) { + auto flow = MakeFlow(); + flow->state_ = TcpFlow::State::kEstablished; + flow->rcv_nxt_ = 2000; + flow->snd_una_ = 3000; + flow->snd_nxt_ = 3000; + + // FIN at the expected sequence position (no payload). + auto* fin_pkt = MakePacket(2000, 3000, Tcp::kFin | Tcp::kAck); + flow->InputPacket(fin_pkt); + + EXPECT_EQ(flow->state(), TcpFlow::State::kCloseWait); + EXPECT_EQ(flow->rcv_nxt_, 2001u); // FIN consumes one seq. + + dpdk::Packet::Free(fin_pkt); +} + +TEST_F(TcpFlowTest, FIN_WithData) { + auto flow = MakeFlow(); + flow->state_ = TcpFlow::State::kEstablished; + flow->rcv_nxt_ = 2000; + flow->snd_una_ = 3000; + flow->snd_nxt_ = 3000; + + // FIN with 10 bytes of payload. FIN is at seq 2000+10=2010. + uint8_t payload[10] = {}; + auto* fin_pkt = MakePacket(2000, 3000, Tcp::kFin | Tcp::kAck, + payload, sizeof(payload)); + flow->InputPacket(fin_pkt); + + EXPECT_EQ(flow->state(), TcpFlow::State::kCloseWait); + EXPECT_EQ(flow->rcv_nxt_, 2011u); // 10 data + 1 FIN. + + dpdk::Packet::Free(fin_pkt); +} + +TEST_F(TcpFlowTest, FIN_RetransmittedFin) { + auto flow = MakeFlow(); + flow->state_ = TcpFlow::State::kEstablished; + flow->rcv_nxt_ = 2001; // Already processed FIN at 2000. 
+ flow->snd_una_ = 3000; + flow->snd_nxt_ = 3000; + + // Retransmitted FIN at seq 2000 (we already processed it). + auto* dup_fin = MakePacket(2000, 3000, Tcp::kFin | Tcp::kAck); + flow->InputPacket(dup_fin); + + // Should stay in ESTABLISHED (or wherever it was) — fin_seq < rcv_nxt. + // Actually, we're at seq 2000+0=2000 < 2001, so it's a retransmit → re-ACK. + // State should not change again. + EXPECT_EQ(flow->rcv_nxt_, 2001u); + + dpdk::Packet::Free(dup_fin); +} + +// ═══════════════════════════════════════════════════════════════ +// Shutdown (Active Close) +// ═══════════════════════════════════════════════════════════════ + +TEST_F(TcpFlowTest, ShutDown_FromEstablished) { + auto flow = MakeFlow(); + flow->state_ = TcpFlow::State::kEstablished; + flow->rcv_nxt_ = 1000; + uint32_t old_snd_nxt = flow->snd_nxt_; + + flow->ShutDown(); + + EXPECT_EQ(flow->state(), TcpFlow::State::kFinWait1); + // FIN consumes one seq. + EXPECT_EQ(flow->snd_nxt_, old_snd_nxt + 1); +} + +TEST_F(TcpFlowTest, ShutDown_FromCloseWait) { + auto flow = MakeFlow(); + flow->state_ = TcpFlow::State::kCloseWait; + flow->rcv_nxt_ = 1000; + uint32_t old_snd_nxt = flow->snd_nxt_; + + flow->ShutDown(); + + EXPECT_EQ(flow->state(), TcpFlow::State::kLastAck); + EXPECT_EQ(flow->snd_nxt_, old_snd_nxt + 1); +} + +TEST_F(TcpFlowTest, ShutDown_FromSynSent) { + auto flow = MakeFlow(); + flow->InitiateHandshake(); + + flow->ShutDown(); + + // Non-established state → send RST, go to CLOSED. + EXPECT_EQ(flow->state(), TcpFlow::State::kClosed); +} + +// ═══════════════════════════════════════════════════════════════ +// Full Graceful Close (4-way) +// ═══════════════════════════════════════════════════════════════ + +TEST_F(TcpFlowTest, GracefulClose_FinWait1_to_TimeWait) { + auto flow = MakeFlow(); + flow->state_ = TcpFlow::State::kEstablished; + flow->rcv_nxt_ = 2000; + flow->snd_nxt_ = 3000; + flow->snd_una_ = 3000; + + // We initiate close. 
+ flow->ShutDown(); + EXPECT_EQ(flow->state(), TcpFlow::State::kFinWait1); + EXPECT_EQ(flow->snd_nxt_, 3001u); + + // Peer ACKs our FIN. + auto* ack = MakePacket(2000, 3001, Tcp::kAck); + flow->InputPacket(ack); + EXPECT_EQ(flow->state(), TcpFlow::State::kFinWait2); + + // Peer sends FIN. + auto* peer_fin = MakePacket(2000, 3001, Tcp::kFin | Tcp::kAck); + flow->InputPacket(peer_fin); + EXPECT_EQ(flow->state(), TcpFlow::State::kTimeWait); + EXPECT_EQ(flow->rcv_nxt_, 2001u); + + dpdk::Packet::Free(ack); + dpdk::Packet::Free(peer_fin); +} + +TEST_F(TcpFlowTest, GracefulClose_LastAck) { + auto flow = MakeFlow(); + flow->state_ = TcpFlow::State::kEstablished; + flow->rcv_nxt_ = 2000; + flow->snd_nxt_ = 3000; + flow->snd_una_ = 3000; + + // Peer sends FIN first. + auto* peer_fin = MakePacket(2000, 3000, Tcp::kFin | Tcp::kAck); + flow->InputPacket(peer_fin); + EXPECT_EQ(flow->state(), TcpFlow::State::kCloseWait); + + // We respond with FIN. + flow->ShutDown(); + EXPECT_EQ(flow->state(), TcpFlow::State::kLastAck); + + // Peer ACKs our FIN. + auto* final_ack = MakePacket(2001, flow->snd_nxt_, Tcp::kAck); + flow->InputPacket(final_ack); + EXPECT_EQ(flow->state(), TcpFlow::State::kClosed); + + dpdk::Packet::Free(peer_fin); + dpdk::Packet::Free(final_ack); +} + +// ═══════════════════════════════════════════════════════════════ +// PeriodicCheck — RTO and TIME_WAIT +// ═══════════════════════════════════════════════════════════════ + +TEST_F(TcpFlowTest, PeriodicCheck_ClosedReturnsFalse) { + auto flow = MakeFlow(); + EXPECT_FALSE(flow->PeriodicCheck()); +} + +TEST_F(TcpFlowTest, PeriodicCheck_TimeWaitCountdown) { + auto flow = MakeFlow(); + flow->state_ = TcpFlow::State::kTimeWait; + flow->time_wait_remaining_ = TcpFlow::kTimeWaitTicks; + + for (uint32_t i = 0; i < TcpFlow::kTimeWaitTicks; i++) { + EXPECT_TRUE(flow->PeriodicCheck()); + } + // One more should transition to CLOSED and return false. 
+ EXPECT_FALSE(flow->PeriodicCheck()); + EXPECT_EQ(flow->state(), TcpFlow::State::kClosed); +} + +TEST_F(TcpFlowTest, PeriodicCheck_RTO_SynSentRetransmit) { + auto flow = MakeFlow(); + flow->InitiateHandshake(); + EXPECT_TRUE(flow->rto_active_); + + // Burn through the RTO timer. + for (uint32_t i = 0; i < TcpFlow::kInitialRTO; i++) { + EXPECT_TRUE(flow->PeriodicCheck()); + } + // Next check should trigger retransmission. + uint32_t retransmit_before = flow->retransmit_count_; + EXPECT_TRUE(flow->PeriodicCheck()); + EXPECT_GT(flow->retransmit_count_, retransmit_before); + EXPECT_EQ(flow->state(), TcpFlow::State::kSynSent); +} + +TEST_F(TcpFlowTest, PeriodicCheck_MaxRetransmissions) { + auto flow = MakeFlow(); + flow->InitiateHandshake(); + + // Exhaust all retransmissions. + flow->retransmit_count_ = TcpFlow::kMaxRetransmissions; + flow->rto_remaining_ = 0; + + EXPECT_FALSE(flow->PeriodicCheck()); + EXPECT_EQ(flow->state(), TcpFlow::State::kClosed); + // Callback should be invoked with failure (since we were in SYN_SENT). + EXPECT_TRUE(callback_called_); + EXPECT_FALSE(callback_success_); +} + +TEST_F(TcpFlowTest, PeriodicCheck_EstablishedNoRTO) { + auto flow = MakeFlow(); + flow->state_ = TcpFlow::State::kEstablished; + flow->rto_active_ = false; + + // Should just return true without changing state. + EXPECT_TRUE(flow->PeriodicCheck()); + EXPECT_EQ(flow->state(), TcpFlow::State::kEstablished); +} + +// ═══════════════════════════════════════════════════════════════ +// AdvanceSndUna +// ═══════════════════════════════════════════════════════════════ + +TEST_F(TcpFlowTest, AdvanceSndUna_NormalAck) { + auto flow = MakeFlow(); + flow->snd_una_ = 100; + flow->snd_nxt_ = 200; + flow->rto_active_ = true; + flow->retransmit_count_ = 2; + + flow->AdvanceSndUna(150); + EXPECT_EQ(flow->snd_una_, 150u); + EXPECT_EQ(flow->retransmit_count_, 0u); + EXPECT_TRUE(flow->rto_active_); // Still unacked data. 
+} + +TEST_F(TcpFlowTest, AdvanceSndUna_AllAcked) { + auto flow = MakeFlow(); + flow->snd_una_ = 100; + flow->snd_nxt_ = 200; + flow->rto_active_ = true; + + flow->AdvanceSndUna(200); + EXPECT_EQ(flow->snd_una_, 200u); + EXPECT_FALSE(flow->rto_active_); // All data acknowledged. +} + +TEST_F(TcpFlowTest, AdvanceSndUna_StaleAck) { + auto flow = MakeFlow(); + flow->snd_una_ = 100; + flow->snd_nxt_ = 200; + + // ACK for something already acknowledged — should be a no-op. + flow->AdvanceSndUna(50); + EXPECT_EQ(flow->snd_una_, 100u); +} + +TEST_F(TcpFlowTest, AdvanceSndUna_FutureAck) { + auto flow = MakeFlow(); + flow->snd_una_ = 100; + flow->snd_nxt_ = 200; + + // ACK beyond snd_nxt_ — should be ignored. + flow->AdvanceSndUna(250); + EXPECT_EQ(flow->snd_una_, 100u); +} + +// ═══════════════════════════════════════════════════════════════ +// StateToString +// ═══════════════════════════════════════════════════════════════ + +TEST_F(TcpFlowTest, StateToString) { + EXPECT_STREQ(TcpFlow::StateToString(TcpFlow::State::kClosed), "CLOSED"); + EXPECT_STREQ(TcpFlow::StateToString(TcpFlow::State::kListen), "LISTEN"); + EXPECT_STREQ(TcpFlow::StateToString(TcpFlow::State::kSynSent), "SYN_SENT"); + EXPECT_STREQ(TcpFlow::StateToString(TcpFlow::State::kSynReceived), + "SYN_RECEIVED"); + EXPECT_STREQ(TcpFlow::StateToString(TcpFlow::State::kEstablished), + "ESTABLISHED"); + EXPECT_STREQ(TcpFlow::StateToString(TcpFlow::State::kFinWait1), + "FIN_WAIT_1"); + EXPECT_STREQ(TcpFlow::StateToString(TcpFlow::State::kFinWait2), + "FIN_WAIT_2"); + EXPECT_STREQ(TcpFlow::StateToString(TcpFlow::State::kCloseWait), + "CLOSE_WAIT"); + EXPECT_STREQ(TcpFlow::StateToString(TcpFlow::State::kLastAck), + "LAST_ACK"); + EXPECT_STREQ(TcpFlow::StateToString(TcpFlow::State::kTimeWait), + "TIME_WAIT"); +} + +// ═══════════════════════════════════════════════════════════════ +// Match +// ═══════════════════════════════════════════════════════════════ + +TEST_F(TcpFlowTest, Match_CorrectPacket) { + auto flow = 
MakeFlow(); + auto* pkt = MakePacket(0, 0, Tcp::kSyn); + EXPECT_TRUE(flow->Match(pkt)); + dpdk::Packet::Free(pkt); +} + +TEST_F(TcpFlowTest, Match_WrongSrcAddr) { + auto flow = MakeFlow(); + auto* pkt = MakePacket(0, 0, Tcp::kSyn); + // Tamper with src IP. + auto* ipv4h = pkt->head_data(sizeof(Ethernet)); + Ipv4::Address wrong_ip; + wrong_ip.FromString("192.168.1.1"); + ipv4h->src_addr = wrong_ip; + EXPECT_FALSE(flow->Match(pkt)); + dpdk::Packet::Free(pkt); +} + +// ═══════════════════════════════════════════════════════════════ +// Ethernet Padding Resilience (IP total_length vs packet->length()) +// ═══════════════════════════════════════════════════════════════ + +TEST_F(TcpFlowTest, EthernetPadding_SmallPacket) { + // When Ethernet pads a frame to 60 bytes, packet->length() overcounts. + // The TCP InputPacket path should use IP total_length instead. + auto flow = MakeFlow(); + flow->state_ = TcpFlow::State::kEstablished; + flow->rcv_nxt_ = 1000; + flow->snd_una_ = 500; + flow->snd_nxt_ = 500; + + // Create a packet with 2 bytes of payload. + uint8_t payload[2] = {0x41, 0x42}; + auto* pkt = MakePacket(1000, 500, Tcp::kAck, payload, 2); + + // Now pad the packet to 60 bytes (simulating Ethernet minimum). + // The actual payload_len from IP header is correct (2 bytes), but + // packet->length() may be larger. + size_t actual_pkt_len = pkt->length(); + // Note: packet->length() may be >= 60 due to Ethernet padding. + // The test verifies that the IP total_length path works correctly + // regardless of packet->length(). + (void)actual_pkt_len; + + flow->InputPacket(pkt); + + // Should have consumed exactly 2 bytes, not more. 
+ EXPECT_EQ(flow->rcv_nxt_, 1002u); + + dpdk::Packet::Free(pkt); +} + +// ═══════════════════════════════════════════════════════════════ +// IP total_length too small +// ═══════════════════════════════════════════════════════════════ + +TEST_F(TcpFlowTest, InputPacket_IpTotalLengthTooSmall) { + auto flow = MakeFlow(); + flow->state_ = TcpFlow::State::kEstablished; + flow->rcv_nxt_ = 1000; + + auto* pkt = MakePacket(1000, 500, Tcp::kAck); + // Corrupt IP total_length to be smaller than IP + TCP header. + auto* ipv4h = pkt->head_data(sizeof(Ethernet)); + ipv4h->total_length = be16_t(10); // Way too small. + + uint32_t old_rcv_nxt = flow->rcv_nxt_; + flow->InputPacket(pkt); + EXPECT_EQ(flow->rcv_nxt_, old_rcv_nxt); // No change. + + dpdk::Packet::Free(pkt); +} + +} // namespace flow +} // namespace net +} // namespace juggler + +int main(int argc, char** argv) { + ::google::InitGoogleLogging(argv[0]); + testing::InitGoogleTest(&argc, argv); + FLAGS_logtostderr = 1; + + // Initialize DPDK with a null virtual device and no PCI. 
+ auto kEalOpts = juggler::utils::CmdLineOpts( + {"", "-c", "0x1", "-n", "6", "--proc-type=auto", "-m", "1024", + "--log-level", "8", "--vdev=net_null0,copy=1", "--no-pci"}); + auto d = juggler::dpdk::Dpdk(); + d.InitDpdk(kEalOpts); + + return RUN_ALL_TESTS(); +} diff --git a/src/ext/machnet.c b/src/ext/machnet.c index d104c7f5..275b5d64 100644 --- a/src/ext/machnet.c +++ b/src/ext/machnet.c @@ -410,7 +410,7 @@ void *machnet_attach() { } int machnet_connect(void *channel_ctx, const char *src_ip, const char *dst_ip, - uint16_t dst_port, MachnetFlow_t *flow) { + uint16_t dst_port, MachnetFlow_t *flow, int protocol) { assert(flow != NULL); MachnetChannelCtx_t *ctx = channel_ctx; @@ -425,7 +425,8 @@ int machnet_connect(void *channel_ctx, const char *src_ip, const char *dst_ip, MachnetCtrlQueueEntry_t req; memset(&req, 0, sizeof(req)); req.id = ctx->ctrl_ctx.req_id++; - req.opcode = MACHNET_CTRL_OP_CREATE_FLOW; + req.opcode = (protocol == MACHNET_PROTO_TCP) ? MACHNET_CTRL_OP_TCP_CREATE_FLOW + : MACHNET_CTRL_OP_CREATE_FLOW; req.flow_info.src_ip = ntohl(inet_addr(src_ip)); req.flow_info.dst_ip = ntohl(inet_addr(dst_ip)); req.flow_info.dst_port = dst_port; @@ -466,7 +467,7 @@ int machnet_connect(void *channel_ctx, const char *src_ip, const char *dst_ip, } int machnet_listen(void *channel_ctx, const char *local_ip, - uint16_t local_port) { + uint16_t local_port, int protocol) { assert(channel_ctx != NULL); MachnetChannelCtx_t *ctx = channel_ctx; @@ -478,7 +479,8 @@ int machnet_listen(void *channel_ctx, const char *local_ip, MachnetCtrlQueueEntry_t req; memset(&req, 0, sizeof(req)); req.id = ctx->ctrl_ctx.req_id++; - req.opcode = MACHNET_CTRL_OP_LISTEN; + req.opcode = (protocol == MACHNET_PROTO_TCP) ? 
MACHNET_CTRL_OP_TCP_LISTEN + : MACHNET_CTRL_OP_LISTEN; req.listener_info.ip = ntohl(inet_addr(local_ip)); req.listener_info.port = local_port; diff --git a/src/ext/machnet.h b/src/ext/machnet.h index 831e2798..603e6e22 100644 --- a/src/ext/machnet.h +++ b/src/ext/machnet.h @@ -87,27 +87,30 @@ void *machnet_attach(); /** * @brief Listens for incoming messages on a specific IP and port. - * @param[in] channel The channel associated to the listener. - * @param[in] ip The local IP address to listen on. + * @param[in] channel_ctx The channel associated to the listener. + * @param[in] local_ip The local IP address to listen on. * @param[in] port The local port to listen on. + * @param[in] protocol MACHNET_PROTO_UDP (default) or MACHNET_PROTO_TCP. * @return 0 on success, -1 on failure. */ -int machnet_listen(void *channel_ctx, const char *local_ip, uint16_t port); +int machnet_listen(void *channel_ctx, const char *local_ip, uint16_t port, + int protocol); /** * @brief Creates a new connection to a remote peer. - * @param[in] channel The channel associated with the connection. + * @param[in] channel_ctx The channel associated with the connection. * @param[in] local_ip The local IP address. * @param[in] remote_ip The remote IP address. * @param[in] remote_port The remote port. * @param[out] flow A pointer to a `MachnetFlow_t` structure that will be * filled by the function upon success. + * @param[in] protocol MACHNET_PROTO_UDP (default) or MACHNET_PROTO_TCP. * @return 0 on success, -1 on failure. `flow` is filled with the flow * information on success. */ int machnet_connect(void *channel_ctx, const char *local_ip, const char *remote_ip, uint16_t remote_port, - MachnetFlow_t *flow); + MachnetFlow_t *flow, int protocol); /** * Enqueue one message for transmission to a remote peer over the network. 
diff --git a/src/ext/machnet_common.h b/src/ext/machnet_common.h index 7e649d6a..7058b2c7 100644 --- a/src/ext/machnet_common.h +++ b/src/ext/machnet_common.h @@ -74,10 +74,14 @@ typedef uint32_t MachnetRingSlot_t; static_assert(sizeof(MachnetRingSlot_t) % 4 == 0, "MachnetRingSlot_t must be 32-bit aligned"); +// Protocol selector for machnet_connect() and machnet_listen(). +#define MACHNET_PROTO_UDP 0 +#define MACHNET_PROTO_TCP 1 + // This is the abstraction of a network flow for the applications. It is -// equivalent to the 5-tuple, with just the protocol missing (UDP is always -// assumed). This structure is used to indicate the sender or receiver of a -// message (depending on the direction). Equivalent to `struct sockaddr_in'. +// equivalent to the 5-tuple. This structure is used to indicate the sender or +// receiver of a message (depending on the direction). Equivalent to +// `struct sockaddr_in'. struct MachnetFlow { uint32_t src_ip; uint32_t dst_ip; @@ -171,7 +175,10 @@ struct MachnetCtrlQueueEntry { #define MACHNET_CTRL_OP_CREATE_FLOW 0x0001 #define MACHNET_CTRL_OP_DESTROY_FLOW 0x0002 #define MACHNET_CTRL_OP_LISTEN 0x0003 -#define MACHNET_CTRL_OP_STATUS 0x0004; +#define MACHNET_CTRL_OP_STATUS 0x0004 +#define MACHNET_CTRL_OP_TCP_CREATE_FLOW 0x0011 +#define MACHNET_CTRL_OP_TCP_DESTROY_FLOW 0x0012 +#define MACHNET_CTRL_OP_TCP_LISTEN 0x0013 uint32_t opcode; #define MACHNET_CTRL_STATUS_OK 0x0000 #define MACHNET_CTRL_STATUS_ERROR 0x0001 diff --git a/src/include/channel.h b/src/include/channel.h index 05ee17fb..f0e93f20 100644 --- a/src/include/channel.h +++ b/src/include/channel.h @@ -31,7 +31,8 @@ class MachnetEngine; // forward declaration namespace juggler { namespace net { namespace flow { -class Flow; // forward declaration +class Flow; // forward declaration +class TcpFlow; // forward declaration } // namespace flow } // namespace net } // namespace juggler @@ -49,6 +50,7 @@ namespace shm { class ShmChannel { public: using Flow = juggler::net::flow::Flow; 
+ using TcpFlow = juggler::net::flow::TcpFlow; using Listener = juggler::net::flow::Listener; ShmChannel() = delete; ShmChannel(const ShmChannel &) = delete; @@ -432,6 +434,14 @@ class Channel : public ShmChannel { */ std::list> &GetActiveFlows() { return active_flows_; } + /** + * @brief Gets the list of active TCP flows. + * @return A reference to the list of active TCP flows. + */ + std::list> &GetActiveTcpFlows() { + return active_tcp_flows_; + } + /** * @brief Gets the list of listeners associated with the channel. * @return A reference to the list of listeners. @@ -451,9 +461,25 @@ class Channel : public ShmChannel { return std::prev(active_flows_.end()); } + /** + * @brief Creates a new TCP flow associated with `this' channel object. + * @param params The parameters pack to be forwarded to the constructor of the + * TcpFlow. + * @return A const iterator to the newly created TCP flow. + */ + const std::list>::const_iterator CreateTcpFlow( + auto &&...params) { + active_tcp_flows_.emplace_back(std::make_unique( + std::forward(params)..., this)); + return std::prev(active_tcp_flows_.end()); + } + void RemoveFlow( const std::list>::const_iterator &flow_it); + void RemoveTcpFlow( + const std::list>::const_iterator &flow_it); + /** * @brief Adds a listener to the channel (i.e., an IP address and port pair). * @param params The parameters pack to be forwarded to the constructor of the @@ -489,6 +515,8 @@ class Channel : public ShmChannel { std::unordered_set listeners_; // List of active flows associated with this channel. std::list> active_flows_; + // List of active TCP flows associated with this channel. + std::list> active_tcp_flows_; // DPDK external memory region. 
rte_device *attached_dev_{nullptr}; diff --git a/src/include/machnet_engine.h b/src/include/machnet_engine.h index 80e4ebd8..24db494b 100644 --- a/src/include/machnet_engine.h +++ b/src/include/machnet_engine.h @@ -14,6 +14,8 @@ #include #include #include +#include +#include #include #include @@ -342,8 +344,10 @@ class MachnetEngine { using Arp = net::Arp; using Ipv4 = net::Ipv4; using Udp = net::Udp; + using Tcp = net::Tcp; using Icmp = net::Icmp; using Flow = net::flow::Flow; + using TcpFlow = net::flow::TcpFlow; using PmdPort = juggler::dpdk::PmdPort; // Slow timer (periodic processing) interval in microseconds. const size_t kSlowTimerIntervalUs = 1000000; // 2ms @@ -382,6 +386,9 @@ class MachnetEngine { listeners_.emplace( ipv4_addr, std::unordered_map>()); + tcp_listeners_.emplace( + ipv4_addr, + std::unordered_map>()); } } @@ -497,7 +504,7 @@ class MachnetEngine { for (const auto &entry : shared_state_->GetArpTableEntries()) { s += "\t\t" + std::get<0>(entry) + " -> " + std::get<1>(entry) + "\n"; } - s += "\tListeners:\n"; + s += "\tListeners (UDP):\n"; for (const auto &listeners_for_ip : listeners_) { const auto ip = listeners_for_ip.first.ToString(); for (const auto &listener : listeners_for_ip.second) { @@ -505,12 +512,27 @@ class MachnetEngine { " <-> [" + listener.second->GetName() + "]" + "\n"; } } - s += "\tActive flows:\n"; + s += "\tListeners (TCP):\n"; + for (const auto &listeners_for_ip : tcp_listeners_) { + const auto ip = listeners_for_ip.first.ToString(); + for (const auto &listener : listeners_for_ip.second) { + s += "\t\tTCP " + ip + ":" + + std::to_string(listener.first.port.value()) + " <-> [" + + listener.second->GetName() + "]" + "\n"; + } + } + s += "\tActive flows (UDP):\n"; for (const auto &[key, flow_it] : active_flows_map_) { s += "\t\t"; s += (*flow_it)->ToString(); s += "\n"; } + s += "\tActive flows (TCP):\n"; + for (const auto &[key, flow_it] : active_tcp_flows_map_) { + s += "\t\t"; + s += (*flow_it)->ToString(); + s += "\n"; + } 
s += "\n"; LOG(INFO) << s; } @@ -560,25 +582,34 @@ class MachnetEngine { const auto &local_ip = ch_listener.addr; const auto &local_port = ch_listener.port; - if (listeners_.find(local_ip) == listeners_.end()) { - LOG(ERROR) << "No listeners for IP " << local_ip.ToString(); - continue; + // Try removing from UDP listeners. + if (listeners_.find(local_ip) != listeners_.end()) { + auto &listeners_for_ip = listeners_[local_ip]; + if (listeners_for_ip.find(local_port) != listeners_for_ip.end()) { + shared_state_->UnregisterListener(local_ip, local_port); + listeners_for_ip.erase(local_port); + continue; + } } - auto &listeners_for_ip = listeners_[local_ip]; - if (listeners_for_ip.find(local_port) == listeners_for_ip.end()) { - LOG(ERROR) << utils::Format("Listener not found %s:%hu", - local_ip.ToString().c_str(), - local_port.port.value()); - continue; + // Try removing from TCP listeners. + if (tcp_listeners_.find(local_ip) != tcp_listeners_.end()) { + auto &tcp_listeners_for_ip = tcp_listeners_[local_ip]; + if (tcp_listeners_for_ip.find(local_port) != + tcp_listeners_for_ip.end()) { + shared_state_->UnregisterListener(local_ip, local_port); + tcp_listeners_for_ip.erase(local_port); + continue; + } } - shared_state_->UnregisterListener(local_ip, local_port); - listeners_for_ip.erase(local_port); + LOG(ERROR) << utils::Format("Listener not found %s:%hu", + local_ip.ToString().c_str(), + local_port.port.value()); } const auto &channel_flows = channel->GetActiveFlows(); - // Remove from the engine's map all the flows associated with this + // Remove from the engine's map all the UDP flows associated with this // channel. for (const auto &flow : channel_flows) { if (active_flows_map_.find(flow->key()) != active_flows_map_.end()) { @@ -593,6 +624,22 @@ class MachnetEngine { } } + // Remove all TCP flows associated with this channel. 
+ const auto &channel_tcp_flows = channel->GetActiveTcpFlows(); + for (const auto &tcp_flow : channel_tcp_flows) { + if (active_tcp_flows_map_.find(tcp_flow->key()) != + active_tcp_flows_map_.end()) { + shared_state_->SrcPortRelease(tcp_flow->key().local_addr, + tcp_flow->key().local_port); + LOG(INFO) << "Removing TCP flow " << tcp_flow->key().ToString(); + tcp_flow->ShutDown(); + active_tcp_flows_map_.erase(tcp_flow->key()); + } else { + LOG(WARNING) << "TCP Flow " << tcp_flow->key().ToString() + << " is not in the list of active TCP flows"; + } + } + // Finally remove the channel. channels_.erase(it); } @@ -678,6 +725,59 @@ class MachnetEngine { } // clang-format on break; + case MACHNET_CTRL_OP_TCP_CREATE_FLOW: + { + const Ipv4::Address src_addr(req.flow_info.src_ip); + if (!shared_state_->IsLocalIpv4Address(src_addr)) { + LOG(ERROR) << "Source IP " << src_addr.ToString() + << " is not local. Cannot create TCP flow."; + emit_completion(false); + break; + } + const Ipv4::Address dst_addr(req.flow_info.dst_ip); + const Tcp::Port dst_port(req.flow_info.dst_port); + LOG(INFO) << "Request to create TCP flow " + << src_addr.ToString() << " -> " + << dst_addr.ToString() << ":" + << dst_port.port.value(); + pending_tcp_requests_.emplace_back(periodic_ticks_, req, channel); + } + break; + case MACHNET_CTRL_OP_TCP_DESTROY_FLOW: + break; + case MACHNET_CTRL_OP_TCP_LISTEN: + { + const Ipv4::Address local_ip(req.listener_info.ip); + const Tcp::Port local_port(req.listener_info.port); + if (!shared_state_->IsLocalIpv4Address(local_ip) || + tcp_listeners_.find(local_ip) == tcp_listeners_.end()) { + emit_completion(false); + break; + } + + auto &listeners_on_ip = tcp_listeners_[local_ip]; + if (listeners_on_ip.find(local_port) != listeners_on_ip.end()) { + LOG(ERROR) << "Cannot register TCP listener for IP " + << local_ip.ToString() << " and port " + << local_port.port.value(); + emit_completion(false); + break; + } + + if (!shared_state_->RegisterListener(local_ip, 
local_port, + rxring_->GetRingId())) { + LOG(ERROR) << "Cannot register TCP listener for IP " + << local_ip.ToString() << " and port " + << local_port.port.value(); + emit_completion(false); + break; + } + + listeners_on_ip.emplace(local_port, channel); + channel->AddListener(local_ip, local_port); + emit_completion(true); + } + break; default: LOG(ERROR) << "Unknown control plane request opcode: " << req.opcode; @@ -777,6 +877,83 @@ class MachnetEngine { active_flows_map_.emplace((*flow_it)->key(), flow_it); it = pending_requests_.erase(it); } + + // Process pending TCP flow creation requests. + for (auto it = pending_tcp_requests_.begin(); + it != pending_tcp_requests_.end();) { + const auto &[timestamp_, req, channel] = *it; + if (periodic_ticks_ - timestamp_ > kPendingRequestTimeoutSlowTicks) { + LOG(ERROR) << utils::Format( + "TCP pending request timeout: [ID: %lu, Opcode: %u]", req.id, + req.opcode); + it = pending_tcp_requests_.erase(it); + continue; + } + + const Ipv4::Address src_addr(req.flow_info.src_ip); + const Ipv4::Address dst_addr(req.flow_info.dst_ip); + const Tcp::Port dst_port(req.flow_info.dst_port); + + auto remote_l2_addr = + shared_state_->GetL2Addr(txring_, src_addr, dst_addr); + if (!remote_l2_addr.has_value()) { + it++; + continue; + } + + // L2 address resolved. Allocate a source port (TCP uses the same port + // space — RSS hashing works the same way for TCP 4-tuples). 
+ auto rss_lambda = [src_addr, dst_addr, dst_port, + rss_key = pmd_port_->GetRSSKey(), pmd_port = pmd_port_, + rx_queue_id = + rxring_->GetRingId()](uint16_t port) -> bool { + rte_thash_tuple reversed_ipv4_l3_l4_tuple; + reversed_ipv4_l3_l4_tuple.v4.src_addr = dst_addr.address.value(); + reversed_ipv4_l3_l4_tuple.v4.dst_addr = src_addr.address.value(); + reversed_ipv4_l3_l4_tuple.v4.sport = dst_port.port.value(); + reversed_ipv4_l3_l4_tuple.v4.dport = port; + + auto reversed_rss_hash = rte_softrss( + reinterpret_cast(&reversed_ipv4_l3_l4_tuple), + RTE_THASH_V4_L4_LEN, rss_key.data()); + if (pmd_port->GetRSSRxQueue(reversed_rss_hash) != rx_queue_id) + return false; + if (pmd_port->GetRSSRxQueue(__builtin_bswap32(reversed_rss_hash)) != + rx_queue_id) + return false; + return true; + }; + + auto src_port = shared_state_->SrcPortAlloc(src_addr, rss_lambda); + if (!src_port.has_value()) { + LOG(ERROR) << "Cannot allocate source port for TCP flow " + << src_addr.ToString(); + it = pending_tcp_requests_.erase(it); + continue; + } + + auto application_callback = + [req_id = req.id](shm::Channel *channel, bool success, + const juggler::net::flow::Key &flow_key) { + MachnetCtrlQueueEntry_t resp; + resp.id = req_id; + resp.opcode = MACHNET_CTRL_OP_STATUS; + resp.status = + success ? 
MACHNET_CTRL_STATUS_OK : MACHNET_CTRL_STATUS_ERROR; + resp.flow_info.src_ip = flow_key.local_addr.address.value(); + resp.flow_info.src_port = flow_key.local_port.port.value(); + resp.flow_info.dst_ip = flow_key.remote_addr.address.value(); + resp.flow_info.dst_port = flow_key.remote_port.port.value(); + channel->EnqueueCtrlCompletions(&resp, 1); + }; + const auto &flow_it = channel->CreateTcpFlow( + src_addr, src_port.value(), dst_addr, dst_port, + pmd_port_->GetL2Addr(), remote_l2_addr.value(), txring_, + application_callback); + (*flow_it)->InitiateHandshake(); + active_tcp_flows_map_.emplace((*flow_it)->key(), flow_it); + it = pending_tcp_requests_.erase(it); + } } /** @@ -798,6 +975,24 @@ class MachnetEngine { } ++it; } + + // Handle TCP flow retransmissions. + for (auto it = active_tcp_flows_map_.begin(); + it != active_tcp_flows_map_.end();) { + const auto &flow_it = it->second; + auto is_active = (*flow_it)->PeriodicCheck(); + if (!is_active) { + LOG(INFO) << "TCP Flow " << (*flow_it)->key().ToString() + << " is no longer active. Removing."; + auto channel = (*flow_it)->channel(); + shared_state_->SrcPortRelease((*flow_it)->key().local_addr, + (*flow_it)->key().local_port); + channel->RemoveTcpFlow(flow_it); + it = active_tcp_flows_map_.erase(it); + continue; + } + ++it; + } } /** @@ -841,24 +1036,29 @@ class MachnetEngine { const auto *eh = pkt->head_data(); const auto *ipv4h = pkt->head_data(sizeof(Ethernet)); - const auto *udph = pkt->head_data(sizeof(Ethernet) + sizeof(Ipv4)); - const net::flow::Key pkt_key(ipv4h->dst_addr, udph->dst_port, - ipv4h->src_addr, udph->src_port); // Check ivp4 header length. + // The packet may be padded to the Ethernet minimum frame size (60 bytes), + // so we allow pkt->length() >= the expected IP total length. 
+ const size_t expected_len = sizeof(Ethernet) + ipv4h->total_length.value(); // clang-format off - if (pkt->length() != sizeof(Ethernet) + ipv4h->total_length.value()) [[unlikely]] { // NOLINT + if (pkt->length() < expected_len) [[unlikely]] { // NOLINT // clang-format on - LOG(WARNING) << "IPv4 packet length mismatch (expected: " + LOG(WARNING) << "IPv4 packet too short (expected: " << ipv4h->total_length.value() - << ", actual: " << pkt->length() << ")"; + << ", actual: " << pkt->length() - sizeof(Ethernet) << ")"; return; } switch (ipv4h->next_proto_id) { // clang-format off [[likely]] case Ipv4::kUdp: + { // clang-format on + const auto *udph = pkt->head_data(sizeof(Ethernet) + sizeof(Ipv4)); + const net::flow::Key pkt_key(ipv4h->dst_addr, udph->dst_port, + ipv4h->src_addr, udph->src_port); + if (active_flows_map_.find(pkt_key) != active_flows_map_.end()) { const auto &flow_it = active_flows_map_[pkt_key]; (*flow_it)->InputPacket(pkt); @@ -909,6 +1109,64 @@ class MachnetEngine { } } + } + break; + + case Ipv4::kTcp: + { + if (pkt->length() < sizeof(Ethernet) + sizeof(Ipv4) + sizeof(Tcp)) + [[unlikely]] return; + + const auto *tcph = + pkt->head_data(sizeof(Ethernet) + sizeof(Ipv4)); + const net::flow::Key pkt_key(ipv4h->dst_addr, tcph->dst_port, + ipv4h->src_addr, tcph->src_port); + + // Check active TCP flows. + if (active_tcp_flows_map_.find(pkt_key) != + active_tcp_flows_map_.end()) { + const auto &flow_it = active_tcp_flows_map_[pkt_key]; + (*flow_it)->InputPacket(pkt); + return; + } + + // Check TCP listeners for incoming SYN. 
+ const auto &local_ipv4_addr = ipv4h->dst_addr; + const auto &local_tcp_port = tcph->dst_port; + if (tcp_listeners_.find(local_ipv4_addr) != tcp_listeners_.end()) { + const auto &listeners_on_ip = tcp_listeners_[local_ipv4_addr]; + if (listeners_on_ip.find(local_tcp_port) == listeners_on_ip.end()) { + LOG(INFO) << "TCP: No listener on port " + << local_tcp_port.port.value(); + return; + } + + // Only accept SYN packets on listening ports. + if (!(tcph->flags & Tcp::kSyn) || (tcph->flags & Tcp::kAck)) { + LOG(WARNING) + << "TCP: Received non-SYN packet on a listening port"; + return; + } + + const auto &channel = listeners_on_ip.at(local_tcp_port); + const auto &remote_ipv4_addr = ipv4h->src_addr; + const auto &remote_tcp_port = tcph->src_port; + + auto empty_callback = [](shm::Channel *, bool, + const net::flow::Key &) {}; + const auto &flow_it = channel->CreateTcpFlow( + local_ipv4_addr, local_tcp_port, remote_ipv4_addr, + remote_tcp_port, pmd_port_->GetL2Addr(), eh->src_addr, txring_, + empty_callback); + active_tcp_flows_map_.insert({pkt_key, flow_it}); + + // Set the flow to LISTEN state so the SYN is handled correctly. + (*flow_it)->StartPassiveOpen(); + + // Process the SYN packet on the new flow. + (*flow_it)->InputPacket(pkt); + } + } break; // clang-format off case Ipv4::kIcmp: @@ -985,16 +1243,33 @@ class MachnetEngine { const auto *flow_info = msg->flow(); const net::flow::Key msg_key(flow_info->src_ip, flow_info->src_port, flow_info->dst_ip, flow_info->dst_port); - if (active_flows_map_.find(msg_key) == active_flows_map_.end()) { - LOG(ERROR) << "Message received for a non-existing flow! 
" - << utils::Format("(Channel: %s, 5-tuple hash: %lu, Flow: %s)", - channel->GetName().c_str(), - std::hash{}(msg_key), - msg_key.ToString().c_str()); + + VLOG(1) << "process_msg: key=" << msg_key.ToString() + << " msg_len=" << msg->msg_length() + << " udp_flows=" << active_flows_map_.size() + << " tcp_flows=" << active_tcp_flows_map_.size(); + + // Check UDP flows first (most common path). + if (active_flows_map_.find(msg_key) != active_flows_map_.end()) { + const auto &flow_it = active_flows_map_[msg_key]; + (*flow_it)->OutputMessage(msg); return; } - const auto &flow_it = active_flows_map_[msg_key]; - (*flow_it)->OutputMessage(msg); + + // Check TCP flows. + if (active_tcp_flows_map_.find(msg_key) != active_tcp_flows_map_.end()) { + const auto &flow_it = active_tcp_flows_map_[msg_key]; + VLOG(1) << "process_msg: routing to TCP flow " + << (*flow_it)->key().ToString(); + (*flow_it)->OutputMessage(msg); + return; + } + + LOG(ERROR) << "Message received for a non-existing flow! " + << utils::Format("(Channel: %s, 5-tuple hash: %lu, Flow: %s)", + channel->GetName().c_str(), + std::hash{}(msg_key), + msg_key.ToString().c_str()); } private: @@ -1032,15 +1307,25 @@ class MachnetEngine { uint64_t last_periodic_timestamp_{0}; // Clock ticks for the slow timer. uint64_t periodic_ticks_{0}; - // Listeners for incoming packets. + // Listeners for incoming packets (UDP). std::unordered_map< Ipv4::Address, std::unordered_map>> listeners_{}; + // Listeners for incoming TCP connections. + std::unordered_map< + Ipv4::Address, + std::unordered_map>> + tcp_listeners_{}; // Unordered map of active flows. std::unordered_map>::const_iterator> active_flows_map_{}; + // Unordered map of active TCP flows. + std::unordered_map< + net::flow::Key, + const std::list>::const_iterator> + active_tcp_flows_map_{}; // Vector of channels to be added to the list of active channels. std::vector channels_to_enqueue_{}; // Vector of channels to be removed from the list of active channels. 
@@ -1049,6 +1334,10 @@ class MachnetEngine { std::list>> pending_requests_{}; + // List of pending TCP control plane requests. + std::list>> + pending_tcp_requests_{}; }; } // namespace juggler diff --git a/src/include/packet.h b/src/include/packet.h index b18e6bbb..00467576 100644 --- a/src/include/packet.h +++ b/src/include/packet.h @@ -101,6 +101,10 @@ class alignas(juggler::hardware_constructive_interference_size) Packet { offload_ipv4_csum(); mbuf_.ol_flags |= (RTE_MBUF_F_TX_UDP_CKSUM); } + void offload_tcpv4_csum() { + offload_ipv4_csum(); + mbuf_.ol_flags |= (RTE_MBUF_F_TX_TCP_CKSUM); + } /** * @brief Attach external buffer to this packet mbuf. diff --git a/src/include/tcp.h b/src/include/tcp.h new file mode 100644 index 00000000..8a6040bc --- /dev/null +++ b/src/include/tcp.h @@ -0,0 +1,66 @@ +/** + * @file tcp.h + * @brief TCP header definition for wire-format parsing/construction. + * + * This mirrors the structure of udp.h, providing a packed TCP header struct + * with big-endian typed fields for use in the Machnet userspace networking + * stack. + */ +#ifndef SRC_INCLUDE_TCP_H_ +#define SRC_INCLUDE_TCP_H_ + +#include +#include +#include + +#include +#include + +namespace juggler { +namespace net { + +/** + * @struct Tcp + * @brief Wire-format TCP header (20 bytes minimum, no options). + */ +struct __attribute__((packed)) Tcp { + // Reuse Udp::Port for TCP ports — same 2-byte big-endian representation. + using Port = Udp::Port; + + // TCP flag bits (in the flags byte). + enum Flags : uint8_t { + kFin = 0x01, + kSyn = 0x02, + kRst = 0x04, + kPsh = 0x08, + kAck = 0x10, + kUrg = 0x20, + }; + + Port src_port; + Port dst_port; + be32_t seq_num; ///< Sequence number. + be32_t ack_num; ///< Acknowledgement number. + uint8_t data_offset; ///< Upper 4 bits = data offset (in 32-bit words). + uint8_t flags; ///< TCP flags (lower 6 bits meaningful). + be16_t window; ///< Receive window size. + uint16_t checksum; ///< TCP checksum (or 0 if offloaded). 
+ be16_t urgent_ptr; ///< Urgent pointer. + + /// @brief Returns the header length in bytes (data_offset field × 4). + uint8_t header_length() const { return (data_offset >> 4) * 4; } + + /// @brief Sets the data offset field for a given header length in bytes. + /// @param len Header length in bytes (must be a multiple of 4, >= 20). + void set_header_length(uint8_t len) { + data_offset = static_cast((len / 4) << 4); + } + + std::string ToString() const; +}; +static_assert(sizeof(Tcp) == 20, "TCP header must be 20 bytes"); + +} // namespace net +} // namespace juggler + +#endif // SRC_INCLUDE_TCP_H_ diff --git a/src/include/tcp_flow.h b/src/include/tcp_flow.h new file mode 100644 index 00000000..77266b67 --- /dev/null +++ b/src/include/tcp_flow.h @@ -0,0 +1,1092 @@ +/** + * @file tcp_flow.h + * @brief TCP flow implementation for Machnet. + * + * This provides a TCP-based transport path alongside the existing UDP-based + * Machnet protocol. A TcpFlow speaks standard TCP (3-way handshake, sequence + * numbers, ACKs, FIN) and translates between Machnet's message-based shared + * memory channel API and TCP byte streams. + * + * Key design decisions: + * - Reuses the same Channel / MsgBuf shared memory infrastructure. + * - Each TcpFlow is bound to one Channel, just like a UDP Flow. + * - Messages are framed on the wire with a 4-byte length prefix so that the + * receiver can reconstruct message boundaries from the TCP byte stream. + * - Uses the same flow::Key structure (the key is protocol-agnostic: IPs + + * ports). + * - The TcpFlow manages its own TCP state machine including connection + * establishment, data transfer (with simple sliding-window flow control), + * and teardown. 
/// TCP connection states (standard simplified).
enum class State {
  kClosed,
  kListen,
  kSynSent,
  kSynReceived,
  kEstablished,
  kFinWait1,
  kFinWait2,
  kCloseWait,
  kLastAck,
  kTimeWait,
};

/// @brief Maps a State to its conventional upper-case name.
/// @return A string literal; "UNKNOWN" for any value that is not a declared
/// enumerator.
static constexpr const char* StateToString(State state) {
  // Indexed by enumerator value; order must follow the State declaration.
  constexpr const char* kNames[] = {
      "CLOSED",     "LISTEN",     "SYN_SENT",   "SYN_RECEIVED", "ESTABLISHED",
      "FIN_WAIT_1", "FIN_WAIT_2", "CLOSE_WAIT", "LAST_ACK",     "TIME_WAIT",
  };
  constexpr size_t kCount = sizeof(kNames) / sizeof(kNames[0]);
  static_assert(static_cast<size_t>(State::kTimeWait) + 1 == kCount,
                "name table out of sync with State");

  // Negative or oversized raw values wrap to a large index and fall through.
  const size_t idx = static_cast<size_t>(state);
  return idx < kCount ? kNames[idx] : "UNKNOWN";
}
/**
 * @brief Constructs a flow in the CLOSED state for one address/port 4-tuple.
 *
 * @param local_addr      Local IPv4 address (part of the flow key).
 * @param local_port      Local TCP port (part of the flow key).
 * @param remote_addr     Remote IPv4 address (part of the flow key).
 * @param remote_port     Remote TCP port (part of the flow key).
 * @param local_l2_addr   Source MAC written into outgoing frames.
 * @param remote_l2_addr  Destination MAC written into outgoing frames.
 * @param txring          TX ring used for all transmissions; CHECK-ed to be
 *                        non-null and to expose a packet pool.
 * @param callback        Stored; invoked elsewhere in this class as
 *                        callback(channel, success, key).
 * @param channel         Channel this flow is bound to; CHECK-ed non-null.
 */
TcpFlow(const Ipv4::Address& local_addr, const Tcp::Port& local_port,
        const Ipv4::Address& remote_addr, const Tcp::Port& remote_port,
        const Ethernet::Address& local_l2_addr,
        const Ethernet::Address& remote_l2_addr, dpdk::TxRing* txring,
        ApplicationCallback callback, shm::Channel* channel)
    : key_(local_addr, local_port, remote_addr, remote_port),
      local_l2_addr_(local_l2_addr),
      remote_l2_addr_(remote_l2_addr),
      state_(State::kClosed),
      txring_(CHECK_NOTNULL(txring)),
      callback_(std::move(callback)),
      channel_(CHECK_NOTNULL(channel)),
      // TCP sequence / ack tracking
      snd_una_(0),
      snd_nxt_(0),
      snd_isn_(GenerateISN()),
      rcv_nxt_(0),
      rcv_wnd_(kInitialWindowSize),
      snd_wnd_(kInitialWindowSize),
      rto_ticks_(kInitialRTO),
      rto_remaining_(kInitialRTO),
      rto_active_(false),
      retransmit_count_(0),
      // RX reassembly
      rx_buf_offset_(0),
      rx_pending_msg_len_(0),
      rx_msg_train_head_(nullptr),
      rx_msg_train_tail_(nullptr) {
  CHECK_NOTNULL(txring_->GetPacketPool());
  // snd_una_/snd_nxt_ are declared before snd_isn_, so they cannot be
  // initialised from it in the member-init list; seed them from the ISN here.
  snd_nxt_ = snd_isn_;
  snd_una_ = snd_isn_;
}
key_ == other.key(); } + + std::string ToString() const { + return utils::Format( + "TCP %s [%s] <-> [%s] snd_una=%u snd_nxt=%u rcv_nxt=%u", + key_.ToString().c_str(), StateToString(state_), + channel_->GetName().c_str(), snd_una_, snd_nxt_, rcv_nxt_); + } + + bool Match(const dpdk::Packet* packet) const { + const auto* ih = packet->head_data(sizeof(Ethernet)); + const auto* tcph = + packet->head_data(sizeof(Ethernet) + sizeof(Ipv4)); + return (ih->src_addr == key_.remote_addr && + ih->dst_addr == key_.local_addr && + tcph->src_port == key_.remote_port && + tcph->dst_port == key_.local_port); + } + + // ────────────────── Active Open (Client) ────────────────── + + void InitiateHandshake() { + CHECK(state_ == State::kClosed); + SendSyn(); + state_ = State::kSynSent; + rto_active_ = true; + rto_remaining_ = rto_ticks_; + } + + // ────────────────── Passive Open (Server) ────────────────── + + /** + * @brief Set the flow into LISTEN-like state for a passive open. + * + * Called by the engine when a SYN arrives on a listening port. The engine + * creates a new TcpFlow and calls StartPassiveOpen() *before* InputPacket() + * so that the SYN is correctly processed. + */ + void StartPassiveOpen() { + CHECK(state_ == State::kClosed); + state_ = State::kListen; + } + + // ────────────────── Shutdown ────────────────── + + void ShutDown() { + if (state_ == State::kEstablished || state_ == State::kCloseWait) { + SendFin(); + state_ = (state_ == State::kEstablished) ? State::kFinWait1 + : State::kLastAck; + } else { + SendRst(); + state_ = State::kClosed; + } + rto_active_ = false; + } + + // ────────────────── RX Path ────────────────── + + /** + * @brief Process an incoming TCP packet. 
+ */ + void InputPacket(const dpdk::Packet* packet) { + const auto* ipv4h = packet->head_data(sizeof(Ethernet)); + const auto* tcph = + packet->head_data(sizeof(Ethernet) + sizeof(Ipv4)); + const uint8_t tcp_hdr_len = tcph->header_length(); + + // Validate TCP header length (min 20, max 60 bytes). + if (tcp_hdr_len < sizeof(Tcp) || tcp_hdr_len > 60) [[unlikely]] { + LOG(WARNING) << "TCP: invalid header length " + << static_cast(tcp_hdr_len); + return; + } + + const size_t net_hdr_len = sizeof(Ethernet) + sizeof(Ipv4) + tcp_hdr_len; + const uint32_t seg_seq = tcph->seq_num.value(); + const uint32_t seg_ack = tcph->ack_num.value(); + const uint8_t flags = tcph->flags; + // Use IP total_length rather than packet->length() to compute the TCP + // payload size. Ethernet frames may be padded to the 60-byte minimum, + // so packet->length() can overcount by up to 6 bytes, corrupting the + // TCP reassembly / deframing state machine. + const size_t ip_total_len = ipv4h->total_length.value(); + if (ip_total_len < sizeof(Ipv4) + tcp_hdr_len) [[unlikely]] { + LOG(WARNING) << "TCP: IP total_length too small for TCP header"; + return; + } + const size_t payload_len = + ip_total_len - sizeof(Ipv4) - tcp_hdr_len; + + // ── RST handling (any state) ── + if (flags & Tcp::kRst) { + LOG(INFO) << "TCP RST received on " << key_.ToString(); + state_ = State::kClosed; + return; + } + + switch (state_) { + case State::kListen: + // Passive open: incoming SYN on a newly created flow. + // Parse TCP options from the kernel's SYN (MSS, etc.). + if (flags & Tcp::kSyn) { + ParseTcpOptions(tcph, tcp_hdr_len); + rcv_nxt_ = seg_seq + 1; // SYN consumes one seq. 
+ SendSynAck(); + state_ = State::kSynReceived; + rto_active_ = true; + rto_remaining_ = rto_ticks_; + } + break; + case State::kSynSent: + HandleSynSent(tcph, seg_seq, seg_ack, flags); + break; + case State::kSynReceived: + HandleSynReceived(tcph, seg_seq, seg_ack, flags, packet, payload_len, + net_hdr_len); + break; + case State::kEstablished: + HandleEstablished(tcph, seg_seq, seg_ack, flags, packet, payload_len, + net_hdr_len); + break; + case State::kFinWait1: + HandleFinWait1(tcph, seg_seq, seg_ack, flags, packet, payload_len, + net_hdr_len); + break; + case State::kFinWait2: + HandleFinWait2(tcph, seg_seq, seg_ack, flags, packet, payload_len, + net_hdr_len); + break; + case State::kCloseWait: + // We already received FIN; waiting for app to close. + if (flags & Tcp::kAck) { + AdvanceSndUna(seg_ack); + } + break; + case State::kLastAck: + if (flags & Tcp::kAck) { + AdvanceSndUna(seg_ack); + state_ = State::kClosed; + } + break; + case State::kTimeWait: + // Absorb duplicates; stay in TIME_WAIT. + break; + case State::kClosed: + // Stale packet on closed connection; ignore or send RST. + if (!(flags & Tcp::kRst)) { + SendRst(); + } + break; + default: + break; + } + } + + // ────────────────── TX Path ────────────────── + + /** + * @brief Send an application message over this TCP flow. + * + * The message is framed with a 4-byte length prefix, then segmented into + * TCP-sized packets and transmitted. + */ + void OutputMessage(shm::MsgBuf* msg) { + if (state_ != State::kEstablished && state_ != State::kCloseWait) { + LOG(ERROR) << "Cannot send on TCP flow in state " << StateToString(state_); + return; + } + + VLOG(1) << "TCP OutputMessage: " << key_.ToString() + << " msg_len=" << msg->msg_length() + << " snd_nxt=" << snd_nxt_ << " snd_una=" << snd_una_; + + // Gather the full message payload from the MsgBuf train. + // We copy the payload into a contiguous buffer to simplify TCP + // segmentation. For a zero-copy path this could be optimized later. 
+ const uint32_t msg_len = msg->msg_length(); + const uint32_t total_len = kMsgLenPrefixSize + msg_len; + + // We'll transmit the length prefix + payload as a TCP byte stream. + // Segment into MSS-sized TCP packets. + uint32_t bytes_sent = 0; + + // First, send the 4-byte length prefix. + uint8_t len_buf[kMsgLenPrefixSize]; + uint32_t net_len = htobe32(msg_len); + std::memcpy(len_buf, &net_len, kMsgLenPrefixSize); + + // Walk through the MsgBuf chain and send data. + // We interleave the length prefix with the first payload chunk. + auto* cur_buf = msg; + size_t prefix_remaining = kMsgLenPrefixSize; + size_t buf_offset = 0; // Offset within the current MsgBuf. + bool first_segment = true; + + while (bytes_sent < total_len) { + auto* packet = txring_->GetPacketPool()->PacketAlloc(); + if (packet == nullptr) { + LOG(ERROR) << "Failed to allocate packet for TCP TX"; + return; + } + dpdk::Packet::Reset(packet); + + const size_t hdr_len = sizeof(Ethernet) + sizeof(Ipv4) + sizeof(Tcp); + size_t payload_room = peer_mss_; + size_t pkt_payload_len = 0; + + // Allocate space for headers + max payload. + size_t max_pkt_len = + hdr_len + std::min(payload_room, total_len - bytes_sent); + CHECK_NOTNULL(packet->append(static_cast(max_pkt_len))); + + uint8_t* payload_dst = + packet->head_data(static_cast(hdr_len)); + + // Copy length prefix into the first segment(s). + if (prefix_remaining > 0) { + size_t prefix_to_copy = std::min(prefix_remaining, payload_room); + std::memcpy(payload_dst, + len_buf + (kMsgLenPrefixSize - prefix_remaining), + prefix_to_copy); + payload_dst += prefix_to_copy; + pkt_payload_len += prefix_to_copy; + prefix_remaining -= prefix_to_copy; + payload_room -= prefix_to_copy; + } + + // Copy message payload from MsgBuf chain. 
+ while (payload_room > 0 && cur_buf != nullptr) { + const size_t avail = cur_buf->length() - buf_offset; + const size_t to_copy = std::min(avail, payload_room); + std::memcpy(payload_dst, + static_cast(cur_buf->head_data()) + buf_offset, + to_copy); + payload_dst += to_copy; + pkt_payload_len += to_copy; + payload_room -= to_copy; + buf_offset += to_copy; + + if (buf_offset == cur_buf->length()) { + // Fully consumed this MsgBuf; move to next in chain. + buf_offset = 0; + if (cur_buf->is_sg() || cur_buf->has_chain()) { + cur_buf = channel_->GetMsgBuf(cur_buf->next()); + } else { + cur_buf = nullptr; + } + } + } + + // Adjust actual packet length if we didn't fill the max. + // We already appended max_pkt_len; trim if needed. + // Actually, since we append the exact max and copy into it, the packet + // length is already correct if max_pkt_len was right. For safety: + // (packet length is set by append) + + // Prepare headers. + uint8_t tcp_flags = Tcp::kAck; + if (first_segment) { + tcp_flags |= Tcp::kPsh; // Push on first segment for low latency. + first_segment = false; + } + + PrepareL2Header(packet); + PrepareL3Header(packet); + PrepareL4Header(packet, snd_nxt_, rcv_nxt_, tcp_flags); + packet->offload_tcpv4_csum(); + + snd_nxt_ += pkt_payload_len; + bytes_sent += pkt_payload_len; + + txring_->SendPackets(&packet, 1); + } + + if (!rto_active_) { + rto_active_ = true; + rto_remaining_ = rto_ticks_; + } + + // Free the MsgBuf chain (the engine normally tracks this, but since TCP + // does its own segmentation, we consume the buffers here). + FreeMsgBufChain(msg); + } + + // ────────────────── Periodic Check ────────────────── + + /** + * @brief Called periodically by the engine to handle retransmissions and + * time-waits. + * @return false if the flow should be removed. 
+ */ + bool PeriodicCheck() { + if (state_ == State::kClosed) return false; + if (state_ == State::kTimeWait) { + if (time_wait_remaining_ > 0) { + time_wait_remaining_--; + return true; + } + state_ = State::kClosed; + return false; + } + + if (!rto_active_) return true; + + if (rto_remaining_ > 0) { + rto_remaining_--; + return true; + } + + // RTO expired. + retransmit_count_++; + if (retransmit_count_ > kMaxRetransmissions) { + LOG(ERROR) << "TCP max retransmissions reached on " << key_.ToString(); + if (state_ == State::kSynSent) { + callback_(channel_, false, key_); + } + state_ = State::kClosed; + return false; + } + + // Retransmit based on state. + switch (state_) { + case State::kSynSent: + LOG(INFO) << "TCP retransmitting SYN"; + SendSyn(); + break; + case State::kSynReceived: + LOG(INFO) << "TCP retransmitting SYN-ACK"; + SendSynAck(); + break; + default: + // For established connections, a full retransmission mechanism + // would require buffering sent data. For now, we just reset the timer. + break; + } + rto_remaining_ = rto_ticks_; + return true; + } + + private: + // ──────────────── Helpers: Header Preparation ──────────────── + + void PrepareL2Header(dpdk::Packet* packet) const { + auto* eh = packet->head_data(); + eh->src_addr = local_l2_addr_; + eh->dst_addr = remote_l2_addr_; + eh->eth_type = be16_t(Ethernet::kIpv4); + packet->set_l2_len(sizeof(*eh)); + } + + void PrepareL3Header(dpdk::Packet* packet) const { + auto* ipv4h = packet->head_data(sizeof(Ethernet)); + ipv4h->version_ihl = 0x45; + ipv4h->type_of_service = 0; + ipv4h->packet_id = be16_t(0x1513); + ipv4h->fragment_offset = be16_t(0x4000); // Don't Fragment. 
+ ipv4h->time_to_live = 64; + ipv4h->next_proto_id = Ipv4::Proto::kTcp; + ipv4h->total_length = be16_t(packet->length() - sizeof(Ethernet)); + ipv4h->src_addr = key_.local_addr; + ipv4h->dst_addr = key_.remote_addr; + ipv4h->hdr_checksum = 0; + packet->set_l3_len(sizeof(*ipv4h)); + } + + void PrepareL4Header(dpdk::Packet* packet, uint32_t seq, uint32_t ack, + uint8_t flags) const { + auto* tcph = packet->head_data(sizeof(Ethernet) + sizeof(Ipv4)); + tcph->src_port = key_.local_port; + tcph->dst_port = key_.remote_port; + tcph->seq_num = be32_t(seq); + tcph->ack_num = be32_t(ack); + tcph->set_header_length(sizeof(Tcp)); // 20 bytes, no options. + tcph->flags = flags; + tcph->window = be16_t(rcv_wnd_); + tcph->checksum = 0; + tcph->urgent_ptr = be16_t(0); + } + + // ──────────────── Helpers: Send Control Packets ──────────────── + + void SendControlPacket(uint32_t seq, uint32_t ack, uint8_t flags) { + auto* packet = CHECK_NOTNULL(txring_->GetPacketPool()->PacketAlloc()); + dpdk::Packet::Reset(packet); + + const size_t pkt_len = sizeof(Ethernet) + sizeof(Ipv4) + sizeof(Tcp); + CHECK_NOTNULL(packet->append(static_cast(pkt_len))); + + PrepareL2Header(packet); + PrepareL3Header(packet); + PrepareL4Header(packet, seq, ack, flags); + packet->offload_tcpv4_csum(); + + txring_->SendPackets(&packet, 1); + } + + /// Send a control packet with the MSS option appended (for SYN/SYN-ACK). + /// Linux kernel expects an MSS option; without it, it defaults to 536 bytes. 
+ void SendControlPacketWithMSS(uint32_t seq, uint32_t ack, uint8_t flags, + uint16_t mss) { + auto* packet = CHECK_NOTNULL(txring_->GetPacketPool()->PacketAlloc()); + dpdk::Packet::Reset(packet); + + const size_t tcp_hdr_with_opts = sizeof(Tcp) + kMSSOptionLen; + const size_t pkt_len = sizeof(Ethernet) + sizeof(Ipv4) + tcp_hdr_with_opts; + CHECK_NOTNULL(packet->append(static_cast(pkt_len))); + + PrepareL2Header(packet); + PrepareL3Header(packet); + PrepareL4Header(packet, seq, ack, flags); + + // Override TCP header length to include the MSS option. + auto* tcph = packet->head_data(sizeof(Ethernet) + sizeof(Ipv4)); + tcph->set_header_length(static_cast(tcp_hdr_with_opts)); + + // Write MSS option: Kind=2, Length=4, Value=MSS (big-endian). + uint8_t* opts = reinterpret_cast(tcph) + sizeof(Tcp); + opts[0] = 2; // Kind: Maximum Segment Size + opts[1] = 4; // Length + uint16_t mss_net = htobe16(mss); + std::memcpy(&opts[2], &mss_net, sizeof(mss_net)); + + packet->offload_tcpv4_csum(); + txring_->SendPackets(&packet, 1); + } + + void SendSyn() { + SendControlPacketWithMSS(snd_nxt_, 0, Tcp::kSyn, + static_cast(kDefaultMSS)); + snd_nxt_++; // SYN consumes one sequence number. + } + + void SendSynAck() { + SendControlPacketWithMSS(snd_isn_, rcv_nxt_, Tcp::kSyn | Tcp::kAck, + static_cast(kDefaultMSS)); + snd_nxt_ = snd_isn_ + 1; // SYN-ACK consumes one sequence number. + } + + void SendAck() { SendControlPacket(snd_nxt_, rcv_nxt_, Tcp::kAck); } + + void SendFin() { + SendControlPacket(snd_nxt_, rcv_nxt_, Tcp::kFin | Tcp::kAck); + snd_nxt_++; // FIN consumes one sequence number. + } + + void SendRst() { SendControlPacket(snd_nxt_, 0, Tcp::kRst); } + + // ──────────────── Helpers: TCP Option Parsing ────────────────── + + /// Parse TCP options from a received header. Currently extracts MSS. + /// Linux kernel SYN/SYN-ACK includes MSS, Window Scale, SACK-Permitted, + /// and Timestamps. 
We parse MSS and ignore the rest (since we don't + /// negotiate window scaling, the kernel won't apply it). + void ParseTcpOptions(const Tcp* tcph, uint8_t hdr_len) { + if (hdr_len <= sizeof(Tcp)) return; // No options. + const uint8_t* opts = + reinterpret_cast(tcph) + sizeof(Tcp); + const size_t opts_len = hdr_len - sizeof(Tcp); + size_t i = 0; + while (i < opts_len) { + uint8_t kind = opts[i]; + if (kind == 0) break; // End of Option List. + if (kind == 1) { i++; continue; } // NOP padding. + if (i + 1 >= opts_len) break; + uint8_t opt_len = opts[i + 1]; + if (opt_len < 2 || i + opt_len > opts_len) break; // Malformed. + if (kind == 2 && opt_len == 4) { + // MSS option. + uint16_t mss_net; + std::memcpy(&mss_net, opts + i + 2, sizeof(mss_net)); + peer_mss_ = be16toh(mss_net); + if (peer_mss_ == 0) peer_mss_ = static_cast(kDefaultMSS); + VLOG(1) << "TCP: parsed peer MSS=" << peer_mss_; + } + // Window Scale (kind=3), SACK-Permitted (kind=4), Timestamps (kind=8): + // intentionally ignored — we don't negotiate these options. + i += opt_len; + } + } + + // ──────────────── Helpers: In-order Payload with Overlap ────────────────── + + /** + * @brief Process incoming payload, handling partial retransmission overlaps. + * + * The Linux kernel retransmits aggressively, and a retransmitted segment + * may partially overlap data we already received. This helper skips the + * already-received prefix and delivers only new bytes to ConsumePayload. + * + * @return true if new data was consumed (caller should ACK). + */ + bool ProcessInOrderPayload(const dpdk::Packet* packet, uint32_t seg_seq, + size_t payload_len, size_t net_hdr_len) { + if (payload_len == 0) return false; + + const uint8_t* base_payload = + packet->head_data( + static_cast(net_hdr_len)); + + if (seg_seq == rcv_nxt_) { + // Perfect in-order delivery. 
+ ConsumePayload(base_payload, payload_len); + rcv_nxt_ += static_cast(payload_len); + return true; + } + + // Check for retransmission that partially overlaps new data. + uint32_t seg_end = seg_seq + static_cast(payload_len); + if (SeqLeq(seg_seq, rcv_nxt_) && SeqGt(seg_end, rcv_nxt_)) { + uint32_t overlap = rcv_nxt_ - seg_seq; // Works with wrapping. + size_t new_len = payload_len - overlap; + ConsumePayload(base_payload + overlap, new_len); + rcv_nxt_ += static_cast(new_len); + return true; + } + + // Pure duplicate (seg_end <= rcv_nxt_) or out-of-order gap. + return false; + } + + // ──────────────── State Machine Handlers ──────────────── + + void HandleSynSent(const Tcp* tcph, uint32_t seg_seq, uint32_t seg_ack, + uint8_t flags) { + if ((flags & Tcp::kSyn) && (flags & Tcp::kAck)) { + // SYN-ACK received from Linux kernel. + if (seg_ack != snd_nxt_) { + LOG(ERROR) << "TCP SYN-ACK with wrong ack: " << seg_ack + << " expected " << snd_nxt_; + return; + } + rcv_nxt_ = seg_seq + 1; // SYN consumes one seq. + snd_una_ = seg_ack; + snd_wnd_ = tcph->window.value(); + + // Parse TCP options from the kernel's SYN-ACK (MSS, etc.). + ParseTcpOptions(tcph, tcph->header_length()); + + // Send ACK to complete 3-way handshake. + SendAck(); + state_ = State::kEstablished; + rto_active_ = false; + retransmit_count_ = 0; + + // Notify application. + callback_(channel_, true, key_); + } else if (flags & Tcp::kSyn) { + // Simultaneous open: SYN without ACK. 
+ ParseTcpOptions(tcph, tcph->header_length()); + rcv_nxt_ = seg_seq + 1; + SendSynAck(); + state_ = State::kSynReceived; + } + } + + void HandleSynReceived(const Tcp* tcph, uint32_t seg_seq, + uint32_t seg_ack, uint8_t flags, + const dpdk::Packet* packet, size_t payload_len, + size_t net_hdr_len) { + if (flags & Tcp::kAck) { + if (seg_ack == snd_nxt_) { + state_ = State::kEstablished; + snd_una_ = seg_ack; + snd_wnd_ = tcph->window.value(); + rto_active_ = false; + retransmit_count_ = 0; + + // Linux kernel can piggyback data on the completing handshake ACK. + if (ProcessInOrderPayload(packet, seg_seq, payload_len, net_hdr_len)) { + SendAck(); + } + } + } + } + + void HandleEstablished(const Tcp* tcph, uint32_t seg_seq, uint32_t seg_ack, + uint8_t flags, const dpdk::Packet* packet, + size_t payload_len, size_t net_hdr_len) { + VLOG(1) << "TCP HandleEstablished: " << key_.ToString() + << " seq=" << seg_seq << " ack=" << seg_ack + << " flags=0x" << std::hex << static_cast(flags) << std::dec + << " payload_len=" << payload_len; + + // Handle retransmitted SYN(-ACK) from the kernel — it missed our + // final handshake ACK. Re-send the ACK so the kernel can proceed. + if (flags & Tcp::kSyn) { + SendAck(); + return; + } + + // Process ACK. + if (flags & Tcp::kAck) { + AdvanceSndUna(seg_ack); + snd_wnd_ = tcph->window.value(); + } + + // Process payload data with overlap handling for kernel retransmissions. + if (payload_len > 0) { + if (ProcessInOrderPayload(packet, seg_seq, payload_len, net_hdr_len)) { + SendAck(); + } else { + // Duplicate or out-of-order — send dup ACK to trigger fast retransmit. + SendAck(); + } + } + + // FIN handling — only accept if the FIN is at the expected sequence. + if (flags & Tcp::kFin) { + uint32_t fin_seq = seg_seq + static_cast(payload_len); + if (fin_seq == rcv_nxt_) { + rcv_nxt_++; // FIN consumes one sequence number. 
+ SendAck(); + state_ = State::kCloseWait; + } else if (SeqLt(fin_seq, rcv_nxt_)) { + // Retransmitted FIN we already processed — re-ACK. + SendAck(); + } + // fin_seq > rcv_nxt_: gap ahead of FIN; ignore for now, kernel will + // retransmit the missing data. + } + } + + void HandleFinWait1(const Tcp* tcph, uint32_t seg_seq, uint32_t seg_ack, + uint8_t flags, const dpdk::Packet* packet, + size_t payload_len, size_t net_hdr_len) { + if (flags & Tcp::kAck) { + AdvanceSndUna(seg_ack); + } + + // Process incoming data with overlap handling. + if (payload_len > 0) { + if (ProcessInOrderPayload(packet, seg_seq, payload_len, net_hdr_len)) { + SendAck(); + } else { + SendAck(); + } + } + + bool our_fin_acked = (snd_una_ == snd_nxt_); + + if (flags & Tcp::kFin) { + uint32_t fin_seq = seg_seq + static_cast(payload_len); + if (fin_seq == rcv_nxt_) { + rcv_nxt_++; + } + SendAck(); + state_ = State::kTimeWait; + time_wait_remaining_ = kTimeWaitTicks; + } else if (our_fin_acked) { + state_ = State::kFinWait2; + } + } + + void HandleFinWait2(const Tcp* /*tcph*/, uint32_t seg_seq, + uint32_t /*seg_ack*/, uint8_t flags, + const dpdk::Packet* packet, size_t payload_len, + size_t net_hdr_len) { + // Process incoming data with overlap handling. + if (payload_len > 0) { + if (ProcessInOrderPayload(packet, seg_seq, payload_len, net_hdr_len)) { + SendAck(); + } else { + SendAck(); + } + } + + if (flags & Tcp::kFin) { + uint32_t fin_seq = seg_seq + static_cast(payload_len); + if (fin_seq == rcv_nxt_) { + rcv_nxt_++; + } + SendAck(); + state_ = State::kTimeWait; + time_wait_remaining_ = kTimeWaitTicks; + } + } + + // ──────────────── RX Reassembly / Deframing ──────────────── + + /** + * @brief Consume incoming TCP payload bytes and reassemble framed messages. + * + * Messages on the wire are preceded by a 4-byte big-endian length prefix. + * This function accumulates bytes and delivers complete messages to the + * channel for the application to consume. 
+ */ + void ConsumePayload(const uint8_t* data, size_t len) { + size_t offset = 0; + while (offset < len) { + // Phase 1: Read the message length prefix if we haven't yet. + if (rx_pending_msg_len_ == 0) { + // We need 4 bytes for the length prefix. + while (rx_len_buf_offset_ < kMsgLenPrefixSize && offset < len) { + rx_len_buf_[rx_len_buf_offset_++] = data[offset++]; + } + if (rx_len_buf_offset_ < kMsgLenPrefixSize) { + return; // Need more data for the length prefix. + } + uint32_t net_msg_len; + std::memcpy(&net_msg_len, rx_len_buf_, kMsgLenPrefixSize); + rx_pending_msg_len_ = be32toh(net_msg_len); + rx_buf_offset_ = 0; + rx_len_buf_offset_ = 0; + + if (rx_pending_msg_len_ == 0 || + rx_pending_msg_len_ > MACHNET_MSG_MAX_LEN) { + LOG(ERROR) << "Invalid TCP message length: " << rx_pending_msg_len_; + rx_pending_msg_len_ = 0; + return; + } + } + + // Phase 2: Copy payload bytes into MsgBuf(s). + size_t remaining_for_msg = rx_pending_msg_len_ - rx_buf_offset_; + size_t available = len - offset; + size_t to_consume = std::min(remaining_for_msg, available); + + // Allocate MsgBufs and copy data. + size_t consumed = 0; + while (consumed < to_consume) { + // Need a new MsgBuf? + if (rx_cur_msgbuf_ == nullptr) { + rx_cur_msgbuf_ = channel_->MsgBufAlloc(); + if (rx_cur_msgbuf_ == nullptr) { + LOG(ERROR) << "TCP RX: Failed to allocate MsgBuf. Dropping data."; + // Reset state for this message. + rx_pending_msg_len_ = 0; + rx_buf_offset_ = 0; + rx_cur_msgbuf_ = nullptr; + // Free any partial train. + if (rx_msg_train_head_ != nullptr) { + FreeMsgBufChain(rx_msg_train_head_); + rx_msg_train_head_ = nullptr; + rx_msg_train_tail_ = nullptr; + } + return; + } + // Set up the msgbuf. 
+ bool is_first = (rx_msg_train_head_ == nullptr); + if (is_first) { + rx_cur_msgbuf_->set_flags(MACHNET_MSGBUF_FLAGS_SYN); + rx_cur_msgbuf_->set_msg_length(rx_pending_msg_len_); + rx_cur_msgbuf_->set_src_ip(key_.remote_addr.address.value()); + rx_cur_msgbuf_->set_src_port(key_.remote_port.port.value()); + rx_cur_msgbuf_->set_dst_ip(key_.local_addr.address.value()); + rx_cur_msgbuf_->set_dst_port(key_.local_port.port.value()); + rx_msg_train_head_ = rx_cur_msgbuf_; + rx_msg_train_tail_ = rx_cur_msgbuf_; + } else { + rx_cur_msgbuf_->set_flags(MACHNET_MSGBUF_FLAGS_SG); + rx_msg_train_tail_->set_next(rx_cur_msgbuf_); + rx_msg_train_tail_ = rx_cur_msgbuf_; + } + } + + size_t buf_room = channel_->GetUsableBufSize() - rx_cur_msgbuf_->length(); + if (buf_room == 0) { + // Current MsgBuf is full. Mark as SG and get a new one. + rx_cur_msgbuf_->set_flags(rx_cur_msgbuf_->flags() | + MACHNET_MSGBUF_FLAGS_SG); + rx_cur_msgbuf_ = nullptr; + continue; + } + size_t chunk = std::min(to_consume - consumed, buf_room); + auto* dst = rx_cur_msgbuf_->append(static_cast(chunk)); + std::memcpy(dst, data + offset + consumed, chunk); + consumed += chunk; + } + + offset += consumed; + rx_buf_offset_ += consumed; + + // Check if message is complete. + if (rx_buf_offset_ >= rx_pending_msg_len_) { + // Mark the last MsgBuf. + if (rx_cur_msgbuf_ != nullptr) { + // Clear SG flag and set FIN on the last buffer. + uint8_t f = rx_cur_msgbuf_->flags(); + f &= ~MACHNET_MSGBUF_FLAGS_SG; + f |= MACHNET_MSGBUF_FLAGS_FIN; + rx_cur_msgbuf_->set_flags(f); + } + + // Deliver to application. + if (rx_msg_train_head_ != nullptr) { + auto nr = channel_->EnqueueMessages(&rx_msg_train_head_, 1); + if (nr != 1) { + LOG(ERROR) + << "TCP: Failed to deliver message to channel. Dropping."; + FreeMsgBufChain(rx_msg_train_head_); + } + } + + // Reset for next message. 
+ rx_pending_msg_len_ = 0; + rx_buf_offset_ = 0; + rx_cur_msgbuf_ = nullptr; + rx_msg_train_head_ = nullptr; + rx_msg_train_tail_ = nullptr; + } + } + } + + // ──────────────── Helpers ──────────────── + + void AdvanceSndUna(uint32_t ack) { + if (SeqGt(ack, snd_una_) && SeqLeq(ack, snd_nxt_)) { + snd_una_ = ack; + retransmit_count_ = 0; + if (snd_una_ == snd_nxt_) { + rto_active_ = false; // All data acknowledged. + } else { + rto_remaining_ = rto_ticks_; + } + } + } + + void FreeMsgBufChain(shm::MsgBuf* head) { + shm::MsgBufBatch to_free; + auto* cur = head; + while (cur != nullptr) { + shm::MsgBuf* next_buf = nullptr; + if (cur->is_sg() || cur->has_chain()) { + next_buf = channel_->GetMsgBuf(cur->next()); + } + to_free.Append(cur, cur->index()); + if (to_free.IsFull()) { + channel_->MsgBufBulkFree(&to_free); + } + cur = next_buf; + } + if (to_free.GetSize() > 0) { + channel_->MsgBufBulkFree(&to_free); + } + } + + /// @brief Generate a pseudo-random initial sequence number. + static uint32_t GenerateISN() { + // Use a simple timestamp-based ISN. In production this should be more + // robust (RFC 6528), but for a userspace stack this is sufficient. + return static_cast(__builtin_ia32_rdtsc() & 0xFFFFFFFF); + } + + // TCP sequence number comparison helpers (handles wrapping). + static bool SeqLt(uint32_t a, uint32_t b) { + return static_cast(a - b) < 0; + } + static bool SeqLeq(uint32_t a, uint32_t b) { + return static_cast(a - b) <= 0; + } + static bool SeqGt(uint32_t a, uint32_t b) { + return static_cast(a - b) > 0; + } + static bool SeqGeq(uint32_t a, uint32_t b) { + return static_cast(a - b) >= 0; + } + + // ──────────────── Data Members ──────────────── + + const Key key_; + const Ethernet::Address local_l2_addr_; + const Ethernet::Address remote_l2_addr_; + State state_; + dpdk::TxRing* txring_; + ApplicationCallback callback_; + shm::Channel* channel_; + + // TCP send-side state. + uint32_t snd_una_; ///< Oldest unacknowledged sequence number. 
+ uint32_t snd_nxt_; ///< Next sequence number to send. + uint32_t snd_isn_; ///< Initial send sequence number. + + // TCP receive-side state. + uint32_t rcv_nxt_; ///< Next expected receive sequence number. + uint16_t rcv_wnd_; ///< Receive window (advertised to peer). + uint16_t snd_wnd_; ///< Send window (from peer). + + /// Peer's MSS learned from TCP options in SYN/SYN-ACK. + /// If the peer (Linux kernel) doesn't send an MSS option we fall back to + /// kDefaultMSS. This is used in OutputMessage for segmentation. + uint16_t peer_mss_{static_cast(kDefaultMSS)}; + + // Retransmission timer (in periodic tick units). + uint32_t rto_ticks_; + uint32_t rto_remaining_; + bool rto_active_; + uint32_t retransmit_count_; + + /// TIME_WAIT countdown (in periodic tick units). + uint32_t time_wait_remaining_{kTimeWaitTicks}; + + // RX reassembly state for message deframing. + uint8_t rx_len_buf_[kMsgLenPrefixSize]{}; ///< Partial length prefix buffer. + uint8_t rx_len_buf_offset_{0}; + uint32_t rx_buf_offset_; ///< Bytes received for current message. + uint32_t rx_pending_msg_len_; ///< Expected length of current message. + shm::MsgBuf* rx_cur_msgbuf_{nullptr}; ///< Current MsgBuf being filled. + shm::MsgBuf* rx_msg_train_head_; ///< Head of current message train. + shm::MsgBuf* rx_msg_train_tail_; ///< Tail of current message train. +}; + +} // namespace flow +} // namespace net +} // namespace juggler + +namespace std { + +template <> +struct hash { + size_t operator()(const juggler::net::flow::TcpFlow& flow) const { + const auto& key = flow.key(); + return juggler::utils::hash(reinterpret_cast(&key), + sizeof(key)); + } +}; + +} // namespace std + +#endif // SRC_INCLUDE_TCP_FLOW_H_