From d68e957534c84095a33482ce0718c507e9c3519e Mon Sep 17 00:00:00 2001 From: Damian Peckett Date: Thu, 23 Oct 2025 09:11:32 +0200 Subject: [PATCH 1/2] [icxtunnel] implement batched io --- go.mod | 2 +- go.sum | 4 +- pkg/cmd/alpha/tunnel_relay.go | 8 +- pkg/cmd/alpha/tunnel_run.go | 9 +- pkg/netstack/icx_network.go | 193 ++++++++----- pkg/netstack/icx_network_test.go | 11 +- pkg/tunnel/batchpc/batchpc.go | 224 +++++++++++++++ pkg/tunnel/batchpc/batchpc_test.go | 227 +++++++++++++++ pkg/tunnel/bifurcate/bifurcate.go | 131 +++++---- pkg/tunnel/bifurcate/bifurcate_test.go | 351 +++++++++++++++++------ pkg/tunnel/bifurcate/chanpc.go | 135 +++++++-- pkg/tunnel/l2pc/l2pc.go | 251 ++++++++++++----- pkg/tunnel/l2pc/l2pc_test.go | 374 +++++++++++++++++++------ pkg/tunnel/router/icx_netlink_linux.go | 2 +- pkg/tunnel/router/options.go | 8 +- 15 files changed, 1536 insertions(+), 394 deletions(-) create mode 100644 pkg/tunnel/batchpc/batchpc.go create mode 100644 pkg/tunnel/batchpc/batchpc_test.go diff --git a/go.mod b/go.mod index 638fcf55..41ee6fda 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/adrg/xdg v0.5.3 github.com/alphadose/haxmap v1.4.1 github.com/anatol/vmtest v0.0.0-20250318022921-2f32244e2f0f - github.com/apoxy-dev/icx v0.12.1 + github.com/apoxy-dev/icx v0.14.0 github.com/avast/retry-go/v4 v4.6.1 github.com/bramvdbogaerde/go-scp v1.5.0 github.com/buraksezer/olric v0.5.6 diff --git a/go.sum b/go.sum index 68416e53..adbe702c 100644 --- a/go.sum +++ b/go.sum @@ -123,8 +123,8 @@ github.com/apoxy-dev/apiserver-runtime v0.0.0-20251017224250-220a8896ee57 h1:p2e github.com/apoxy-dev/apiserver-runtime v0.0.0-20251017224250-220a8896ee57/go.mod h1:k8K1q/QnsxMM7/wsiga/cJWGW/38G907ex7JPFw0B04= github.com/apoxy-dev/connect-ip-go v0.0.0-20250530062404-603929a73f45 h1:SwPk1n/oSVX7YwlNpC9KNH9YaYkcL/k6OfqSGVnxyyI= github.com/apoxy-dev/connect-ip-go v0.0.0-20250530062404-603929a73f45/go.mod h1:z5rtgIizc+/K27UtB0occwZgqg/mz3IqgyUJW8aubbI= -github.com/apoxy-dev/icx v0.12.1 h1:VaczJSdujpsO8NjS0RvxiF55fco+iKZyurcZu4ddeP8= -github.com/apoxy-dev/icx v0.12.1/go.mod h1:QNPhLVUVbbSVSyERjmgGN4K8vzSC6bvZlN0tyflYf0U= +github.com/apoxy-dev/icx v0.14.0 h1:3BXuhRysBsK2isLu7Z3+1pMiySu2eI0Ts5iObw6fp60= +github.com/apoxy-dev/icx v0.14.0/go.mod h1:QNPhLVUVbbSVSyERjmgGN4K8vzSC6bvZlN0tyflYf0U= github.com/apoxy-dev/quic-go v0.0.0-20250530165952-53cca597715e h1:10GIpiVyKoRgCyr0J2TvJtdn17bsFHN+ROWkeVJpcOU= github.com/apoxy-dev/quic-go v0.0.0-20250530165952-53cca597715e/go.mod h1:MFlGGpcpJqRAfmYi6NC2cptDPSxRWTOGNuP4wqrWmzQ= github.com/apoxy-dev/upgrade-cli v0.0.0-20240213232412-a56c3a52fa0e h1:FBNxMQD93z2ththupB/BYKLEaMWaEr+G+sJWJqU2wC4= diff --git a/pkg/cmd/alpha/tunnel_relay.go b/pkg/cmd/alpha/tunnel_relay.go index b51209f5..7ca36896 100644 --- a/pkg/cmd/alpha/tunnel_relay.go +++ b/pkg/cmd/alpha/tunnel_relay.go @@ -14,6 +14,7 @@ import ( "github.com/apoxy-dev/apoxy/pkg/cryptoutils" "github.com/apoxy-dev/apoxy/pkg/tunnel" + "github.com/apoxy-dev/apoxy/pkg/tunnel/batchpc" "github.com/apoxy-dev/apoxy/pkg/tunnel/bifurcate" "github.com/apoxy-dev/apoxy/pkg/tunnel/controllers" "github.com/apoxy-dev/apoxy/pkg/tunnel/hasher" @@ -49,11 +50,16 @@ var tunnelRelayCmd = &cobra.Command{ } // One UDP socket shared between Geneve (data) and QUIC (control). 
- pc, err := net.ListenPacket("udp", listenAddress) + lis, err := net.ListenPacket("udp", listenAddress) if err != nil { return fmt.Errorf("failed to create UDP listener: %w", err) } + pc, err := batchpc.New("udp", lis) + if err != nil { + return fmt.Errorf("failed to create batch packet conn: %w", err) + } + pcGeneve, pcQuic := bifurcate.Bifurcate(pc) defer pcGeneve.Close() defer pcQuic.Close() diff --git a/pkg/cmd/alpha/tunnel_run.go b/pkg/cmd/alpha/tunnel_run.go index 4fccae05..21181b23 100644 --- a/pkg/cmd/alpha/tunnel_run.go +++ b/pkg/cmd/alpha/tunnel_run.go @@ -21,6 +21,7 @@ import ( "github.com/apoxy-dev/apoxy/pkg/netstack" "github.com/apoxy-dev/apoxy/pkg/tunnel/api" + "github.com/apoxy-dev/apoxy/pkg/tunnel/batchpc" "github.com/apoxy-dev/apoxy/pkg/tunnel/bifurcate" "github.com/apoxy-dev/apoxy/pkg/tunnel/conntrackpc" "github.com/apoxy-dev/apoxy/pkg/tunnel/router" @@ -47,11 +48,15 @@ var tunnelRunCmd = &cobra.Command{ } // One UDP socket shared between Geneve (data) and QUIC (control). - pc, err := net.ListenPacket("udp", ":0") + lis, err := net.ListenPacket("udp", ":0") if err != nil { return fmt.Errorf("failed to create UDP socket: %w", err) } - defer pc.Close() + + pc, err := batchpc.New("udp", lis) + if err != nil { + return fmt.Errorf("failed to create batch packet conn: %w", err) + } pcGeneve, pcQuic := bifurcate.Bifurcate(pc) defer pcGeneve.Close() diff --git a/pkg/netstack/icx_network.go b/pkg/netstack/icx_network.go index 12152724..112853ee 100644 --- a/pkg/netstack/icx_network.go +++ b/pkg/netstack/icx_network.go @@ -1,3 +1,4 @@ +// icx_network.go package netstack import ( @@ -28,22 +29,26 @@ import ( stdnet "net" + "github.com/apoxy-dev/apoxy/pkg/tunnel/batchpc" "github.com/apoxy-dev/apoxy/pkg/tunnel/l2pc" ) // TODO (dpeckett): nuke this at some point and merge the logic into the router. type ICXNetwork struct { network.Network - handler *icx.Handler - phy *l2pc.L2PacketConn - ep *channel.Endpoint - stack *stack.Stack - ipt *IPTables - nicID tcpip.NICID - pcapFile *os.File - incomingPacket chan *buffer.View - pktPool sync.Pool - closeOnce sync.Once + handler *icx.Handler + phy *l2pc.L2PacketConn + ep *channel.Endpoint + stack *stack.Stack + ipt *IPTables + nicID tcpip.NICID + pcapFile *os.File + + // Wakeup channel for batching outbound sends. Capacity 1 to coalesce notifies. + wakeOutbound chan struct{} + + pktPool sync.Pool + closeOnce sync.Once } // NewICXNetwork creates a new ICXNetwork instance with the given handler, physical connection, MTU, and resolve configuration. @@ -107,15 +112,17 @@ func NewICXNetwork(handler *icx.Handler, phy *l2pc.L2PacketConn, mtu int, resolv }) net := &ICXNetwork{ - Network: network.Netstack(ipstack, nicID, resolveConf), - handler: handler, - phy: phy, - ep: linkEP, - stack: ipstack, - ipt: ipt, - nicID: nicID, - pcapFile: pcapFile, - incomingPacket: make(chan *buffer.View), + Network: network.Netstack(ipstack, nicID, resolveConf), + handler: handler, + phy: phy, + ep: linkEP, + stack: ipstack, + ipt: ipt, + nicID: nicID, + pcapFile: pcapFile, + + wakeOutbound: make(chan struct{}, 1), + pktPool: sync.Pool{ New: func() any { b := make([]byte, 0, 65535) @@ -129,14 +136,13 @@ func NewICXNetwork(handler *icx.Handler, phy *l2pc.L2PacketConn, mtu int, resolv } // WriteNotify is called by the channel endpoint when netstack has an outbound packet ready. +// We just coalesce a wakeup; actual draining/batching happens in the outbound pump. 
func (net *ICXNetwork) WriteNotify() { - pkt := net.ep.Read() - if pkt == nil { - return + select { + case net.wakeOutbound <- struct{}{}: + default: + // already awake; coalesce } - view := pkt.ToView() - pkt.DecRef() - net.incomingPacket <- view } // Close cleans up the network stack and closes the underlying resources. @@ -146,8 +152,8 @@ func (net *ICXNetwork) Close() error { net.ep.Close() - if net.incomingPacket != nil { - close(net.incomingPacket) + if net.wakeOutbound != nil { + close(net.wakeOutbound) } if net.pcapFile != nil { @@ -159,47 +165,105 @@ func (net *ICXNetwork) Close() error { } // Start copies packets to and from netstack and icx. -// This is a blocking call that runs until either side is closed. +// Upstream (netstack -> phy) uses batched I/O; downstream remains as before. func (net *ICXNetwork) Start() error { + const tickMs = 100 // periodically flush scheduled frames (ToPhy) + var g errgroup.Group - // Outbound: netstack (L3) -> ICX -> L2PacketConn.WriteFrame + // Outbound: netstack (L3) -> ICX -> L2PacketConn.WriteBatchFrames (batched) g.Go(func() error { - // Avoid a busy loop. - ticker := time.NewTicker(100 * time.Millisecond) + type owned struct { + msg batchpc.Message + buf *[]byte // owner to return to pool + } + putOwned := func(v []owned) { + for i := range v { + if v[i].buf != nil { + net.pktPool.Put(v[i].buf) + v[i].buf = nil + } + } + } + + // Reuse per-iteration scratch arrays to avoid allocs. + // Assumes batchpc.MaxBatchSize is a const. + var ( + batchOwned [batchpc.MaxBatchSize]owned + batchMsgs [batchpc.MaxBatchSize]batchpc.Message + ) + + ticker := time.NewTicker(tickMs * time.Millisecond) defer ticker.Stop() for { + // Wake on notify or periodic tick. select { - case view, ok := <-net.incomingPacket: + case _, ok := <-net.wakeOutbound: if !ok { - return stdnet.ErrClosed // channel closed => done + return stdnet.ErrClosed } + case <-ticker.C: + } - ip := view.AsSlice() // raw IP bytes (v4 or v6) + batch := batchOwned[:] + count := 0 - phyFrame := net.pktPool.Get().(*[]byte) - *phyFrame = (*phyFrame)[:cap(*phyFrame)] - n, _ := net.handler.VirtToPhy(ip, *phyFrame) - if n > 0 { - if err := net.phy.WriteFrame((*phyFrame)[:n]); err != nil { - net.pktPool.Put(phyFrame) - return fmt.Errorf("writing phy frame failed: %w", err) - } + // Drain endpoint fully into the batch. + for count < batchpc.MaxBatchSize { + pkt := net.ep.Read() + if pkt == nil { + break } - net.pktPool.Put(phyFrame) + view := pkt.ToView() + pkt.DecRef() + + ip := view.AsSlice() // raw L3 bytes + b := batch[count].buf + if b == nil { + b = net.pktPool.Get().(*[]byte) + *b = (*b)[:cap(*b)] + batch[count].buf = b + } + *b = (*b)[:cap(*b)] + if n, _ := net.handler.VirtToPhy(ip, *b); n > 0 { + batch[count].msg.Buf = (*b)[:n] + count++ + } + } - case <-ticker.C: - phyFrame := net.pktPool.Get().(*[]byte) - *phyFrame = (*phyFrame)[:cap(*phyFrame)] - - if n := net.handler.ToPhy(*phyFrame); n > 0 { - if err := net.phy.WriteFrame((*phyFrame)[:n]); err != nil { - net.pktPool.Put(phyFrame) - return fmt.Errorf("writing scheduled phy frame failed: %w", err) - } + // Coalesce scheduled frames (ToPhy) onto the same batch. + for count < batchpc.MaxBatchSize { + b := batch[count].buf + if b == nil { + b = net.pktPool.Get().(*[]byte) + *b = (*b)[:cap(*b)] + batch[count].buf = b + } + *b = (*b)[:cap(*b)] + if n := net.handler.ToPhy(*b); n > 0 { + batch[count].msg.Buf = (*b)[:n] + count++ + } else { + break } - net.pktPool.Put(phyFrame) + } + + if count == 0 { + // Nothing to send this cycle. 
+ putOwned(batch) + continue + } + + // Send in one go. + msgs := batchMsgs[:count] + for i := 0; i < count; i++ { + msgs[i] = batch[i].msg + } + n, err := net.phy.WriteBatchFrames(msgs, 0) + putOwned(batch) + if err != nil { + return fmt.Errorf("writing batched phy frames failed after %d/%d: %w", n, count, err) } } }) @@ -247,8 +311,8 @@ func (net *ICXNetwork) Start() error { pkb.DecRef() default: // drop silently - net.pktPool.Put(virtFrame) } + net.pktPool.Put(virtFrame) } }) @@ -332,24 +396,3 @@ func (net *ICXNetwork) ForwardTo(ctx context.Context, upstream network.Network) return nil } - -func prefixToSubnet(p netip.Prefix) (tcpip.Subnet, error) { - addr := tcpip.AddrFromSlice(p.Addr().AsSlice()) - - totalBits := 128 - if p.Addr().Is4() { - totalBits = 32 - } - ones := p.Bits() - if ones < 0 || ones > totalBits { - return tcpip.Subnet{}, fmt.Errorf("invalid prefix length %d", ones) - } - - maskBytes := make([]byte, totalBits/8) - for i := 0; i < ones; i++ { - maskBytes[i/8] |= 1 << (7 - uint(i%8)) - } - mask := tcpip.MaskFromBytes(maskBytes) - - return tcpip.NewSubnet(addr, mask) -} diff --git a/pkg/netstack/icx_network_test.go b/pkg/netstack/icx_network_test.go index 1f908e12..a99c2e9b 100644 --- a/pkg/netstack/icx_network_test.go +++ b/pkg/netstack/icx_network_test.go @@ -19,6 +19,7 @@ import ( "github.com/apoxy-dev/icx" "github.com/apoxy-dev/apoxy/pkg/netstack" + "github.com/apoxy-dev/apoxy/pkg/tunnel/batchpc" "github.com/apoxy-dev/apoxy/pkg/tunnel/bifurcate" "github.com/apoxy-dev/apoxy/pkg/tunnel/l2pc" ) @@ -31,12 +32,18 @@ func TestICXNetwork_Speed(t *testing.T) { slog.SetLogLoggerLevel(slog.LevelDebug) // Create two underlying UDP packet conns on localhost - pcA, err := net.ListenPacket("udp", "127.0.0.1:0") + connA, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + pcA, err := batchpc.New("udp4", connA) require.NoError(t, err) pcAGeneve, _ := bifurcate.Bifurcate(pcA) - pcB, err := net.ListenPacket("udp", "127.0.0.1:0") + connB, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + pcB, err := batchpc.New("udp4", connB) require.NoError(t, err) pcBGeneve, _ := bifurcate.Bifurcate(pcB) diff --git a/pkg/tunnel/batchpc/batchpc.go b/pkg/tunnel/batchpc/batchpc.go new file mode 100644 index 00000000..4c0a41c1 --- /dev/null +++ b/pkg/tunnel/batchpc/batchpc.go @@ -0,0 +1,224 @@ +package batchpc + +import ( + "fmt" + "net" + "sync" + + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +// MaxBatchSize is the maximum number of packets that can be read/written in a single batch. +const MaxBatchSize = 64 + +// Message represents a single packet for batched I/O. +type Message struct { + Buf []byte + Addr net.Addr +} + +// BatchPacketConn is a PacketConn with batched I/O using Messages. +type BatchPacketConn interface { + net.PacketConn + ReadBatch(msgs []Message, flags int) (int, error) + WriteBatch(msgs []Message, flags int) (int, error) +} + +// New creates a pooled BatchPacketConn wrapping a UDP PacketConn. +// network must be one of: "udp", "udp4", "udp6" (empty treated as "udp"). +// If network == "udp", we infer from LocalAddr(); if ambiguous we prefer IPv6. +// Only *net.UDPConn is supported. 
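+//
+// Typical usage (a sketch; the caller supplies and owns the Message buffers):
+//
+//	lis, _ := net.ListenPacket("udp", ":0")
+//	bpc, _ := batchpc.New("udp", lis)
+//	msgs := make([]batchpc.Message, batchpc.MaxBatchSize)
+//	for i := range msgs {
+//		msgs[i].Buf = make([]byte, 1500)
+//	}
+//	n, _ := bpc.ReadBatch(msgs, 0) // msgs[:n] carry the resized Buf and source Addr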
+func New(network string, pc net.PacketConn) (BatchPacketConn, error) { + uc, ok := pc.(*net.UDPConn) + if !ok { + return nil, fmt.Errorf("batchudp: only *net.UDPConn is supported") + } + return newFromUDPConn(network, uc) +} + +func resolveNetwork(network string, pc net.PacketConn) (string, error) { + switch network { + case "", "udp": + // Infer from LocalAddr if possible; prefer IPv6 when ambiguous. + if ua, _ := pc.LocalAddr().(*net.UDPAddr); ua != nil && ua.IP != nil { + if ua.IP.To4() != nil { + return "udp4", nil + } + return "udp6", nil + } + return "udp6", nil + case "udp4": + return "udp4", nil + case "udp6": + return "udp6", nil + default: + return "", fmt.Errorf("batchudp: unsupported network %q (want udp, udp4, udp6)", network) + } +} + +func newFromUDPConn(network string, pc *net.UDPConn) (BatchPacketConn, error) { + nw, err := resolveNetwork(network, pc) + if err != nil { + return nil, err + } + switch nw { + case "udp4": + return newBatch4(pc), nil + case "udp6": + return newBatch6(pc), nil + default: + // unreachable due to resolveNetwork + return nil, fmt.Errorf("batchudp: unknown network %q", nw) + } +} + +// IPv4 implementation. +type batch4 struct { + net.PacketConn // for net.PacketConn interface + ipv4pc *ipv4.PacketConn // for batch I/O + msgPool sync.Pool +} + +func newBatch4(pc net.PacketConn) *batch4 { + return &batch4{ + PacketConn: pc, + ipv4pc: ipv4.NewPacketConn(pc), + msgPool: sync.Pool{ + New: func() any { + s := make([]ipv4.Message, MaxBatchSize) + return &s + }, + }, + } +} + +func (b *batch4) getTmp(n int) *[]ipv4.Message { + ps := b.msgPool.Get().(*[]ipv4.Message) + if cap(*ps) < n { + // grow once; keep for reuse (amortized) + ns := make([]ipv4.Message, n) + *ps = ns + } + *ps = (*ps)[:n] + // zero out fields we set (only Buffers/Addr/N are touched by kernel) + for i := range *ps { + (*ps)[i].Buffers = (*ps)[i].Buffers[:0] + (*ps)[i].Addr = nil + (*ps)[i].N = 0 + } + return ps +} + +func (b *batch4) putTmp(ps *[]ipv4.Message) { b.msgPool.Put(ps) } + +func (b *batch4) ReadBatch(msgs []Message, flags int) (int, error) { + if len(msgs) == 0 { + return 0, nil + } + tmp := b.getTmp(len(msgs)) + for i := range msgs { + (*tmp)[i].Buffers = [][]byte{msgs[i].Buf} + } + n, err := b.ipv4pc.ReadBatch(*tmp, flags) + if n > 0 { + for i := 0; i < n; i++ { + if len((*tmp)[i].Buffers) > 0 { + msgs[i].Buf = (*tmp)[i].Buffers[0][:(*tmp)[i].N] + } else { + msgs[i].Buf = msgs[i].Buf[:0] + } + msgs[i].Addr = (*tmp)[i].Addr + } + } + b.putTmp(tmp) + return n, err +} + +func (b *batch4) WriteBatch(msgs []Message, flags int) (int, error) { + if len(msgs) == 0 { + return 0, nil + } + tmp := b.getTmp(len(msgs)) + for i := range msgs { + (*tmp)[i].Buffers = [][]byte{msgs[i].Buf} + (*tmp)[i].Addr = msgs[i].Addr + } + n, err := b.ipv4pc.WriteBatch(*tmp, flags) + b.putTmp(tmp) + return n, err +} + +// IPv6 implementation. 
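+// batch6 mirrors batch4, delegating batched reads and writes to an
+// ipv6.PacketConn from golang.org/x/net; only the message type differs.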
+type batch6 struct { + net.PacketConn + ipv6pc *ipv6.PacketConn + msgPool sync.Pool +} + +func newBatch6(pc net.PacketConn) *batch6 { + return &batch6{ + PacketConn: pc, + ipv6pc: ipv6.NewPacketConn(pc), + msgPool: sync.Pool{ + New: func() any { + s := make([]ipv6.Message, MaxBatchSize) + return &s + }, + }, + } +} + +func (b *batch6) getTmp(n int) *[]ipv6.Message { + ps := b.msgPool.Get().(*[]ipv6.Message) + if cap(*ps) < n { + ns := make([]ipv6.Message, n) + *ps = ns + } + *ps = (*ps)[:n] + for i := range *ps { + (*ps)[i].Buffers = (*ps)[i].Buffers[:0] + (*ps)[i].Addr = nil + (*ps)[i].N = 0 + } + return ps +} + +func (b *batch6) putTmp(ps *[]ipv6.Message) { b.msgPool.Put(ps) } + +func (b *batch6) ReadBatch(msgs []Message, flags int) (int, error) { + if len(msgs) == 0 { + return 0, nil + } + tmp := b.getTmp(len(msgs)) + for i := range msgs { + (*tmp)[i].Buffers = [][]byte{msgs[i].Buf} + } + n, err := b.ipv6pc.ReadBatch(*tmp, flags) + if n > 0 { + for i := 0; i < n; i++ { + if len((*tmp)[i].Buffers) > 0 { + msgs[i].Buf = (*tmp)[i].Buffers[0][:(*tmp)[i].N] + } else { + msgs[i].Buf = msgs[i].Buf[:0] + } + msgs[i].Addr = (*tmp)[i].Addr + } + } + b.putTmp(tmp) + return n, err +} + +func (b *batch6) WriteBatch(msgs []Message, flags int) (int, error) { + if len(msgs) == 0 { + return 0, nil + } + tmp := b.getTmp(len(msgs)) + for i := range msgs { + (*tmp)[i].Buffers = [][]byte{msgs[i].Buf} + (*tmp)[i].Addr = msgs[i].Addr + } + n, err := b.ipv6pc.WriteBatch(*tmp, flags) + b.putTmp(tmp) + return n, err +} diff --git a/pkg/tunnel/batchpc/batchpc_test.go b/pkg/tunnel/batchpc/batchpc_test.go new file mode 100644 index 00000000..4cfa0a14 --- /dev/null +++ b/pkg/tunnel/batchpc/batchpc_test.go @@ -0,0 +1,227 @@ +package batchpc_test + +import ( + "errors" + "net" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/apoxy-dev/apoxy/pkg/tunnel/batchpc" +) + +func TestBatchPC(t *testing.T) { + t.Run("unsupported network", func(t *testing.T) { + pc := makeUDPListener(t, "udp4", "127.0.0.1:0") + t.Cleanup(func() { require.NoError(t, pc.Close()) }) + _, err := batchpc.New("tcp", pc) + require.Error(t, err) + require.ErrorContains(t, err, "unsupported network") + }) + + t.Run("non-udpconn not supported", func(t *testing.T) { + // Use a dummy PacketConn to confirm we reject non-*net.UDPConn. 
+ pc, err := net.ListenPacket("udp4", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, pc.Close()) }) + // Wrap in a custom type that is NOT *net.UDPConn + type justPC struct{ net.PacketConn } + _, err = batchpc.New("udp4", justPC{PacketConn: pc}) + require.Error(t, err) + require.ErrorContains(t, err, "only *net.UDPConn is supported") + }) + + t.Run("zero-length Read/Write", func(t *testing.T) { + pc := makeUDPListener(t, "udp4", "127.0.0.1:0") + t.Cleanup(func() { require.NoError(t, pc.Close()) }) + bc := mustBatch(t, "udp4", pc) + + n, err := bc.ReadBatch(nil, 0) + require.NoError(t, err) + require.Equal(t, 0, n) + + n, err = bc.WriteBatch(nil, 0) + require.NoError(t, err) + require.Equal(t, 0, n) + }) + + type scenario struct { + name string + netListen string // actual socket family + addr string + hintSrv string + hintCli string + counts []int // how many packets per send burst + } + tests := []scenario{ + { + name: "ipv4-direct-explicit", + netListen: "udp4", addr: "127.0.0.1:0", + hintSrv: "udp4", hintCli: "udp4", + counts: []int{4}, + }, + { + name: "ipv6-direct-explicit", + netListen: "udp6", addr: "[::1]:0", + hintSrv: "udp6", hintCli: "udp6", + counts: []int{3}, + }, + { + name: "ipv4-infer-empty-and-udp", + netListen: "udp4", addr: "127.0.0.1:0", + hintSrv: "", hintCli: "udp", // exercises resolveNetwork inference + counts: []int{4}, + }, + { + name: "ipv6-infer-empty-and-udp", + netListen: "udp6", addr: "[::1]:0", + hintSrv: "", hintCli: "udp", + counts: []int{4}, + }, + { + name: "ipv4-pool-grows-beyond-MaxBatchSize", + netListen: "udp4", addr: "127.0.0.1:0", + hintSrv: "udp4", hintCli: "udp4", + // send once with MaxBatchSize+1 to cover pool growth path + counts: []int{batchpc.MaxBatchSize + 1}, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + // Server + srvUDP := makeUDPListener(t, tc.netListen, tc.addr) + t.Cleanup(func() { require.NoError(t, srvUDP.Close()) }) + srv, err := batchpc.New(tc.hintSrv, srvUDP) + require.NoError(t, err) + + // Client + cliUDP := makeUDPListener(t, tc.netListen, tc.addr) + t.Cleanup(func() { require.NoError(t, cliUDP.Close()) }) + cli, err := batchpc.New(tc.hintCli, cliUDP) + require.NoError(t, err) + + total := 0 + for _, c := range tc.counts { + total += c + } + done := startEcho(t, srv, total) + + // Send bursts to server and read replies + serverAddr := srv.LocalAddr() + + // Writable window + for _, burst := range tc.counts { + out := make([]batchpc.Message, burst) + for i := 0; i < burst; i++ { + out[i] = batchpc.Message{ + Buf: []byte(tc.name + "/" + strconv.Itoa(i)), + Addr: serverAddr, + } + } + _ = cli.SetWriteDeadline(time.Now().Add(3 * time.Second)) + wn, err := cli.WriteBatch(out, 0) + require.NoError(t, err) + require.Equal(t, burst, wn) + } + + // Read back all replies (order-preserving within batch) + in := make([]batchpc.Message, total) + for i := range in { + in[i].Buf = make([]byte, 1500) + } + got := 0 + dead := time.Now().Add(5 * time.Second) + for got < total && time.Now().Before(dead) { + _ = cli.SetReadDeadline(time.Now().Add(3 * time.Second)) + rn, err := cli.ReadBatch(in[got:], 0) + require.NoError(t, err) + require.Greater(t, rn, 0) + got += rn + } + require.Equal(t, total, got, "didn't receive all echoes") + + // Basic validation of content trim and Addr non-nil. 
+ for i := 0; i < got; i++ { + require.NotEmpty(t, in[i].Buf) + require.NotNil(t, in[i].Addr) + } + + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("echo server did not finish") + } + }) + } +} + +func makeUDPListener(t *testing.T, network, addr string) *net.UDPConn { + t.Helper() + pc, err := net.ListenPacket(network, addr) + if err != nil { + t.Skipf("skip %s: %v", network, err) + } + uc, ok := pc.(*net.UDPConn) + require.True(t, ok, "expected *net.UDPConn, got %T", pc) + return uc +} + +// startEcho spins an echo loop that reads batches and writes them back. +// It stops after echoing 'want' packets total. +func startEcho(t *testing.T, bc batchpc.BatchPacketConn, want int) chan struct{} { + t.Helper() + done := make(chan struct{}) + go func() { + defer close(done) + left := want + bufs := make([]batchpc.Message, 0, 64) + for left > 0 { + // resize receive window to what's left (bounded) + nwin := left + if nwin > 32 { + nwin = 32 + } + if cap(bufs) < nwin { + bufs = make([]batchpc.Message, 0, nwin) + } + bufs = bufs[:nwin] + for i := range bufs { + bufs[i].Buf = make([]byte, 1500) + bufs[i].Addr = nil + } + _ = bc.SetReadDeadline(time.Now().Add(3 * time.Second)) + rn, err := bc.ReadBatch(bufs, 0) + if errors.Is(err, net.ErrClosed) { + return + } + require.NoError(t, err) + require.Greater(t, rn, 0) + + out := make([]batchpc.Message, rn) + for i := 0; i < rn; i++ { + out[i] = batchpc.Message{ + Buf: append([]byte(nil), bufs[i].Buf...), // exact copy/length + Addr: bufs[i].Addr, + } + } + _ = bc.SetWriteDeadline(time.Now().Add(3 * time.Second)) + wn, err := bc.WriteBatch(out, 0) + require.NoError(t, err) + require.Equal(t, rn, wn) + + left -= rn + } + }() + return done +} + +func mustBatch(t *testing.T, network string, pc net.PacketConn) batchpc.BatchPacketConn { + t.Helper() + bc, err := batchpc.New(network, pc) + require.NoError(t, err) + return bc +} diff --git a/pkg/tunnel/bifurcate/bifurcate.go b/pkg/tunnel/bifurcate/bifurcate.go index 4ecbbfe2..5b3e9c20 100644 --- a/pkg/tunnel/bifurcate/bifurcate.go +++ b/pkg/tunnel/bifurcate/bifurcate.go @@ -1,99 +1,114 @@ package bifurcate import ( - "net" "sync" "github.com/apoxy-dev/icx/geneve" "gvisor.dev/gvisor/pkg/tcpip/header" -) -type packet struct { - buf []byte - addr net.Addr -} + "github.com/apoxy-dev/apoxy/pkg/tunnel/batchpc" +) -var packetPool = sync.Pool{ +var messagePool = sync.Pool{ New: func() any { - buf := make([]byte, 65535) - return &packet{ - buf: buf, - addr: nil, - } + return &batchpc.Message{Buf: make([]byte, 65535)} }, } // Bifurcate splits incoming packets from `pc` into geneve and other channels. -func Bifurcate(pc net.PacketConn) (net.PacketConn, net.PacketConn) { +func Bifurcate(pc batchpc.BatchPacketConn) (batchpc.BatchPacketConn, batchpc.BatchPacketConn) { geneveConn := newChanPacketConn(pc) otherConn := newChanPacketConn(pc) // Local copies we can nil out when a side is closed. geneveCh := geneveConn.ch otherCh := otherConn.ch - geneveClosed := geneveConn.closed - otherClosed := otherConn.closed + var geneveClosed <-chan struct{} = geneveConn.closed + var otherClosed <-chan struct{} = otherConn.closed go func() { + // Reusable read batch (values) for kernel I/O. + msgs := make([]batchpc.Message, batchpc.MaxBatchSize) + // Shadow array of pooled message pointers we own & recycle. + pm := make([]*batchpc.Message, batchpc.MaxBatchSize) + for { // If both sides are gone, stop. 
if geneveCh == nil && otherCh == nil { return } - // Reuse packet buffer - p := packetPool.Get().(*packet) - p.buf = p.buf[:cap(p.buf)] + // Prepare buffers for a full batch read. + for i := range msgs { + if pm[i] == nil { + pm[i] = messagePool.Get().(*batchpc.Message) + } + // Reset/expand the buffer we hand to the kernel. + pm[i].Buf = pm[i].Buf[:cap(pm[i].Buf)] + pm[i].Addr = nil + + msgs[i].Buf = pm[i].Buf + msgs[i].Addr = nil + } - n, addr, err := pc.ReadFrom(p.buf) + n, err := pc.ReadBatch(msgs, 0) if err != nil { - packetPool.Put(p) - // Propagate underlying error/closure to both children. + // Return any outstanding pooled messages. + for i := 0; i < len(pm); i++ { + if pm[i] != nil { + messagePool.Put(pm[i]) + pm[i] = nil + } + } _ = geneveConn.Close() _ = otherConn.Close() return } + if n == 0 { + continue + } - p.addr = addr - p.buf = p.buf[:n] + // Classify into destination batches (slices referencing pooled messages). + gBatch := make([]*batchpc.Message, 0, n) + oBatch := make([]*batchpc.Message, 0, n) - if isGeneve(p.buf) { - for { - // If that side is closed, drop the packet. - if geneveCh == nil { - packetPool.Put(p) - break - } - select { - case geneveCh <- p: - // delivered - break - case <-geneveClosed: - // Stop sending to this side going forward. - geneveCh = nil - geneveClosed = nil - // try loop again, which will drop since geneveCh==nil - continue - } - break + for i := 0; i < n; i++ { + m := pm[i] + // msgs[i].Buf has been resized by underlying BatchPacketConn ReadBatch. + m.Buf = msgs[i].Buf + m.Addr = msgs[i].Addr + + if isGeneve(m.Buf) { + gBatch = append(gBatch, m) + } else { + oBatch = append(oBatch, m) } - } else { - for { - if otherCh == nil { - packetPool.Put(p) - break - } - select { - case otherCh <- p: - break - case <-otherClosed: - otherCh = nil - otherClosed = nil - continue + + // Detach so we don't double-put on error paths. + pm[i] = nil + } + + // Helper to send a batch or recycle if receiver closed. + sendBatch := func(ch chan []*batchpc.Message, closed <-chan struct{}, batch []*batchpc.Message) (chan []*batchpc.Message, <-chan struct{}) { + if ch == nil || len(batch) == 0 { + return ch, closed + } + select { + case ch <- batch: + // Delivered; ownership of messages transfers to receiver. + case <-closed: + // Receiver closed: recycle messages. + for _, m := range batch { + messagePool.Put(m) } - break + ch = nil + closed = nil } + return ch, closed } + + geneveCh, geneveClosed = sendBatch(geneveCh, geneveClosed, gBatch) + otherCh, otherClosed = sendBatch(otherCh, otherClosed, oBatch) } }() @@ -113,7 +128,9 @@ func isGeneve(b []byte) bool { } // Check for valid protocol types (IPv4 or IPv6) or EtherType unknown (out-of-band messages). 
- if hdr.ProtocolType != uint16(header.IPv4ProtocolNumber) && hdr.ProtocolType != uint16(header.IPv6ProtocolNumber) && hdr.ProtocolType != 0 { + if hdr.ProtocolType != uint16(header.IPv4ProtocolNumber) && + hdr.ProtocolType != uint16(header.IPv6ProtocolNumber) && + hdr.ProtocolType != 0 { return false } diff --git a/pkg/tunnel/bifurcate/bifurcate_test.go b/pkg/tunnel/bifurcate/bifurcate_test.go index b806e5bb..8ca2e8aa 100644 --- a/pkg/tunnel/bifurcate/bifurcate_test.go +++ b/pkg/tunnel/bifurcate/bifurcate_test.go @@ -4,83 +4,196 @@ import ( "bytes" "errors" "net" + "sync" "testing" "time" "github.com/apoxy-dev/icx/geneve" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "gvisor.dev/gvisor/pkg/tcpip/header" + "github.com/apoxy-dev/apoxy/pkg/tunnel/batchpc" "github.com/apoxy-dev/apoxy/pkg/tunnel/bifurcate" ) -type MockPacketConn struct { - mock.Mock - readQueue chan readResult - addr net.Addr - closed bool -} +func TestBifurcate_RoutesWithReadFrom(t *testing.T) { + mockConn := newMockBatchPacketConn() + remote := &net.UDPAddr{IP: net.IPv4(10, 1, 1, 1), Port: 9999} -type readResult struct { - data []byte - addr net.Addr - err error + genevePkt := createGenevePacket(t) + nonGenevePkt := createNonGenevePacket() + + mockConn.enqueue(genevePkt, remote) + mockConn.enqueue(nonGenevePkt, remote) + + geneveConn, otherConn := bifurcate.Bifurcate(mockConn) + + // Read Geneve + buf := make([]byte, 1024) + n, addr, err := geneveConn.ReadFrom(buf) + require.NoError(t, err) + require.Equal(t, remote.String(), addr.String()) + require.True(t, bytes.Equal(buf[:n], genevePkt)) + + // Read other + n, addr, err = otherConn.ReadFrom(buf) + require.NoError(t, err) + require.Equal(t, remote.String(), addr.String()) + require.Equal(t, string(nonGenevePkt), string(buf[:n])) } -func NewMockPacketConn() *MockPacketConn { - return &MockPacketConn{ - readQueue: make(chan readResult, 10), - addr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 12345}, +func TestBifurcate_ReadBatch_RoutesBatchesToBoth(t *testing.T) { + mockConn := newMockBatchPacketConn() + remoteG := &net.UDPAddr{IP: net.IPv4(10, 0, 0, 1), Port: 10001} + remoteO := &net.UDPAddr{IP: net.IPv4(10, 0, 0, 2), Port: 10002} + + genevePkt := createGenevePacket(t) + nonGenevePkt := createNonGenevePacket() + + // Enqueue a mixed stream larger than child batch request sizes. 
+ for i := 0; i < 5; i++ { + mockConn.enqueue(genevePkt, remoteG) + } + for i := 0; i < 3; i++ { + mockConn.enqueue(nonGenevePkt, remoteO) } -} -func (m *MockPacketConn) ReadFrom(p []byte) (int, net.Addr, error) { - result, ok := <-m.readQueue - if !ok { - return 0, nil, errors.New("mock read closed") + geneveConn, otherConn := bifurcate.Bifurcate(mockConn) + + // Child buffers + makeMsgs := func(n int, sz int) []batchpc.Message { + msgs := make([]batchpc.Message, n) + for i := range msgs { + msgs[i].Buf = make([]byte, sz) + } + return msgs } - n := copy(p, result.data) - return n, result.addr, result.err + + // Read a batch from both sides + gmsgs := makeMsgs(4, 256) + n1, err := geneveConn.ReadBatch(gmsgs, 0) + require.NoError(t, err) + require.Equal(t, 4, n1) + for i := 0; i < n1; i++ { + require.True(t, bytes.Equal(gmsgs[i].Buf, genevePkt)) + require.Equal(t, remoteG.String(), gmsgs[i].Addr.String()) + } + + omsgs := makeMsgs(2, 256) + n2, err := otherConn.ReadBatch(omsgs, 0) + require.NoError(t, err) + require.Equal(t, 2, n2) + for i := 0; i < n2; i++ { + require.Equal(t, string(nonGenevePkt), string(omsgs[i].Buf)) + require.Equal(t, remoteO.String(), omsgs[i].Addr.String()) + } + + // Read remaining packets from both + gmsgs2 := makeMsgs(8, 256) + n3, err := geneveConn.ReadBatch(gmsgs2, 0) + require.NoError(t, err) + require.Equal(t, 1, n3) // 5 total geneve, 4 already read + require.True(t, bytes.Equal(gmsgs2[0].Buf, genevePkt)) + + omsgs2 := makeMsgs(8, 256) + n4, err := otherConn.ReadBatch(omsgs2, 0) + require.NoError(t, err) + require.Equal(t, 1, n4) // 3 total non-geneve, 2 already read + require.Equal(t, string(nonGenevePkt), string(omsgs2[0].Buf)) } -func (m *MockPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) { - args := m.Called(p, addr) - return args.Int(0), args.Error(1) +func TestBifurcate_ChildReadBatchDrainsPending(t *testing.T) { + mockConn := newMockBatchPacketConn() + remote := &net.UDPAddr{IP: net.IPv4(10, 2, 3, 4), Port: 4242} + genevePkt := createGenevePacket(t) + + // Enqueue several geneve packets so the bifurcator sends a whole batch. + for i := 0; i < 6; i++ { + mockConn.enqueue(genevePkt, remote) + } + + geneveConn, _ := bifurcate.Bifurcate(mockConn) + + msgs := make([]batchpc.Message, 8) + for i := range msgs { + msgs[i].Buf = make([]byte, 256) + } + n, err := geneveConn.ReadBatch(msgs, 0) + require.NoError(t, err) + require.Equal(t, 6, n) + for i := 0; i < n; i++ { + require.True(t, bytes.Equal(msgs[i].Buf, genevePkt)) + } } -func (m *MockPacketConn) Close() error { - m.closed = true - close(m.readQueue) - return nil +func TestBifurcate_WriteBatchForwardsToUnderlying(t *testing.T) { + mockConn := newMockBatchPacketConn() + geneveConn, otherConn := bifurcate.Bifurcate(mockConn) + + dst := &net.UDPAddr{IP: net.IPv4(203, 0, 113, 1), Port: 9999} + payloads := [][]byte{ + []byte("a"), + []byte("bb"), + []byte("ccc"), + } + msgs := make([]batchpc.Message, len(payloads)) + for i := range msgs { + msgs[i].Buf = payloads[i] + msgs[i].Addr = dst + } + + // Send via geneve child + n1, err := geneveConn.WriteBatch(msgs, 0) + require.NoError(t, err) + require.Equal(t, len(payloads), n1) + + // Send via other child + n2, err := otherConn.WriteBatch(msgs, 0) + require.NoError(t, err) + require.Equal(t, len(payloads), n2) + + // Verify underlying was called and captured content. 
+ mockConn.mu.Lock() + defer mockConn.mu.Unlock() + require.GreaterOrEqual(t, mockConn.writeBatchCalls, 2) + require.Len(t, mockConn.lastWriteBatch, len(payloads)) + for i := range payloads { + require.Equal(t, string(payloads[i]), string(mockConn.lastWriteBatch[i])) + require.Equal(t, dst.String(), mockConn.lastWriteBatchTo[i].String()) + } } -func (m *MockPacketConn) LocalAddr() net.Addr { return m.addr } -func (m *MockPacketConn) SetDeadline(t time.Time) error { return nil } -func (m *MockPacketConn) SetReadDeadline(t time.Time) error { return nil } -func (m *MockPacketConn) SetWriteDeadline(t time.Time) error { return nil } +func TestBifurcate_ClosesBothOnReadError(t *testing.T) { + mockConn := newMockBatchPacketConn() + // Simulate read error: close the queue so next read fails. + close(mockConn.readQueue) + + geneveConn, otherConn := bifurcate.Bifurcate(mockConn) + + // Give the goroutine a breath to observe the close. + time.Sleep(50 * time.Millisecond) -// --- Helpers --- + buf := make([]byte, 1024) + _, _, err := geneveConn.ReadFrom(buf) + require.ErrorIs(t, err, net.ErrClosed) + + _, _, err = otherConn.ReadFrom(buf) + require.ErrorIs(t, err, net.ErrClosed) +} func createGenevePacket(t *testing.T) []byte { - header := geneve.Header{ + h := geneve.Header{ Version: 0, ProtocolType: uint16(header.IPv4ProtocolNumber), VNI: 0x123456, NumOptions: 2, Options: [2]geneve.Option{ - { - Class: geneve.ClassExperimental, - Type: 1, - }, - { - Class: geneve.ClassExperimental, - Type: 2, - }, + {Class: geneve.ClassExperimental, Type: 1}, + {Class: geneve.ClassExperimental, Type: 2}, }, } buf := make([]byte, 128) - n, err := header.MarshalBinary(buf) + n, err := h.MarshalBinary(buf) require.NoError(t, err) return buf[:n] } @@ -89,50 +202,130 @@ func createNonGenevePacket() []byte { return []byte("this is not a geneve packet") } -func TestBifurcate(t *testing.T) { - t.Run("routes geneve and non-geneve packets to correct connections", func(t *testing.T) { - mockConn := NewMockPacketConn() - remote := &net.UDPAddr{IP: net.IPv4(10, 1, 1, 1), Port: 9999} +type readResult struct { + data []byte + addr net.Addr + err error +} - // Prepare packets - genevePkt := createGenevePacket(t) - nonGenevePkt := createNonGenevePacket() +type mockBatchPacketConn struct { + readQueue chan readResult + addr net.Addr - mockConn.readQueue <- readResult{data: genevePkt, addr: remote} - mockConn.readQueue <- readResult{data: nonGenevePkt, addr: remote} + mu sync.Mutex + closed bool + writeToCalls int + lastWriteToBuf []byte + lastWriteToAddr net.Addr + writeBatchCalls int + lastWriteBatch [][]byte + lastWriteBatchTo []net.Addr +} - geneveConn, otherConn := bifurcate.Bifurcate(mockConn) +func newMockBatchPacketConn() *mockBatchPacketConn { + return &mockBatchPacketConn{ + readQueue: make(chan readResult, 64), + addr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 12345}, + } +} - // Read from geneveConn - buf := make([]byte, 1024) - n, addr, err := geneveConn.ReadFrom(buf) - require.NoError(t, err) - require.Equal(t, remote.String(), addr.String()) - require.True(t, bytes.HasPrefix(buf[:n], genevePkt)) +func (pc *mockBatchPacketConn) ReadFrom(p []byte) (int, net.Addr, error) { + // Convenience shim using ReadBatch semantics + msgs := []batchpc.Message{{Buf: p}} + n, err := pc.ReadBatch(msgs, 0) + if n == 0 { + return 0, nil, err + } + return len(msgs[0].Buf), msgs[0].Addr, err +} - // Read from otherConn - n, addr, err = otherConn.ReadFrom(buf) - require.NoError(t, err) - require.Equal(t, remote.String(), addr.String()) - 
require.Equal(t, string(buf[:n]), string(nonGenevePkt)) - }) +func (pc *mockBatchPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) { + pc.mu.Lock() + pc.writeToCalls++ + pc.lastWriteToBuf = append(pc.lastWriteToBuf[:0], p...) + pc.lastWriteToAddr = addr + pc.mu.Unlock() + return len(p), nil +} - t.Run("closes both connections on read error", func(t *testing.T) { - mockConn := NewMockPacketConn() - // simulate read error by closing channel - close(mockConn.readQueue) +func (pc *mockBatchPacketConn) Close() error { + pc.mu.Lock() + if !pc.closed { + pc.closed = true + close(pc.readQueue) + } + pc.mu.Unlock() + return nil +} - geneveConn, otherConn := bifurcate.Bifurcate(mockConn) +func (pc *mockBatchPacketConn) LocalAddr() net.Addr { + return pc.addr +} - // wait for goroutine to detect closure - time.Sleep(50 * time.Millisecond) +func (pc *mockBatchPacketConn) SetDeadline(t time.Time) error { + return nil +} - buf := make([]byte, 1024) +func (pc *mockBatchPacketConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (pc *mockBatchPacketConn) SetWriteDeadline(t time.Time) error { + return nil +} - _, _, err := geneveConn.ReadFrom(buf) - require.ErrorIs(t, err, net.ErrClosed) +func (pc *mockBatchPacketConn) ReadBatch(msgs []batchpc.Message, flags int) (int, error) { + if len(msgs) == 0 { + return 0, nil + } + // First result: block + result, ok := <-pc.readQueue + if !ok { + return 0, errors.New("mock read closed") + } + if result.err != nil { + return 0, result.err + } + n0 := copy(msgs[0].Buf, result.data) + msgs[0].Buf = msgs[0].Buf[:n0] + msgs[0].Addr = result.addr + n := 1 + + // Drain non-blocking + for n < len(msgs) { + select { + case rr, ok := <-pc.readQueue: + if !ok { + return n, errors.New("mock read closed") + } + if rr.err != nil { + return n, rr.err + } + cn := copy(msgs[n].Buf, rr.data) + msgs[n].Buf = msgs[n].Buf[:cn] + msgs[n].Addr = rr.addr + n++ + default: + return n, nil + } + } + return n, nil +} + +func (pc *mockBatchPacketConn) WriteBatch(msgs []batchpc.Message, flags int) (int, error) { + pc.mu.Lock() + pc.writeBatchCalls++ + pc.lastWriteBatch = pc.lastWriteBatch[:0] + pc.lastWriteBatchTo = pc.lastWriteBatchTo[:0] + for _, ms := range msgs { + cp := append([]byte(nil), ms.Buf...) + pc.lastWriteBatch = append(pc.lastWriteBatch, cp) + pc.lastWriteBatchTo = append(pc.lastWriteBatchTo, ms.Addr) + } + pc.mu.Unlock() + return len(msgs), nil +} - _, _, err = otherConn.ReadFrom(buf) - require.ErrorIs(t, err, net.ErrClosed) - }) +func (pc *mockBatchPacketConn) enqueue(data []byte, addr net.Addr) { + pc.readQueue <- readResult{data: append([]byte(nil), data...), addr: addr, err: nil} } diff --git a/pkg/tunnel/bifurcate/chanpc.go b/pkg/tunnel/bifurcate/chanpc.go index 4a5fe7b3..c350e462 100644 --- a/pkg/tunnel/bifurcate/chanpc.go +++ b/pkg/tunnel/bifurcate/chanpc.go @@ -3,59 +3,144 @@ package bifurcate import ( "net" "time" + + "github.com/apoxy-dev/apoxy/pkg/tunnel/batchpc" ) type chanPacketConn struct { - pc net.PacketConn // underlying connection - ch chan *packet // incoming packets + pc batchpc.BatchPacketConn + // Incoming batches from the bifurcator goroutine. + ch chan []*batchpc.Message closed chan struct{} + // Locally pending batch from the last receive (not yet fully consumed). 
+ pending []*batchpc.Message + pendingIndex int } -func newChanPacketConn(pc net.PacketConn) *chanPacketConn { +func newChanPacketConn(pc batchpc.BatchPacketConn) *chanPacketConn { return &chanPacketConn{ - ch: make(chan *packet, 1024), + ch: make(chan []*batchpc.Message, 1024), pc: pc, closed: make(chan struct{}), } } -func (c *chanPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - select { - case pkt := <-c.ch: - defer packetPool.Put(pkt) // return packet to pool - n = copy(p, pkt.buf) - return n, pkt.addr, nil - case <-c.closed: - return 0, nil, net.ErrClosed +func (pc *chanPacketConn) ReadFrom(p []byte) (int, net.Addr, error) { + if err := pc.ensurePendingBlocking(); err != nil { + return 0, nil, err } + msg := pc.popOne() + defer messagePool.Put(msg) + + n := copy(p, msg.Buf) + return n, msg.Addr, nil } -func (c *chanPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) { - return c.pc.WriteTo(p, addr) +func (pc *chanPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) { + return pc.pc.WriteTo(p, addr) } -func (c *chanPacketConn) Close() error { +func (pc *chanPacketConn) Close() error { select { - case <-c.closed: + case <-pc.closed: return nil default: - close(c.closed) + close(pc.closed) return nil } } -func (c *chanPacketConn) LocalAddr() net.Addr { - return c.pc.LocalAddr() +func (pc *chanPacketConn) LocalAddr() net.Addr { + return pc.pc.LocalAddr() +} + +func (pc *chanPacketConn) SetDeadline(t time.Time) error { + return pc.pc.SetDeadline(t) +} + +func (pc *chanPacketConn) SetReadDeadline(t time.Time) error { + return pc.pc.SetReadDeadline(t) +} + +func (pc *chanPacketConn) SetWriteDeadline(t time.Time) error { + return pc.pc.SetWriteDeadline(t) +} + +func (pc *chanPacketConn) ReadBatch(msgs []batchpc.Message, flags int) (int, error) { + if len(msgs) == 0 { + return 0, nil + } + + n := 0 + // 1) Ensure at least one packet (blocking once). + if err := pc.ensurePendingBlocking(); err != nil { + return 0, err + } + + // 2) Fill from pending, then non-blocking drain of further batches. + fill := func() { + for n < len(msgs) && len(pc.pending) > 0 { + msg := pc.popOne() + copied := copy(msgs[n].Buf, msg.Buf) + msgs[n].Buf = msgs[n].Buf[:copied] + msgs[n].Addr = msg.Addr + messagePool.Put(msg) + n++ + } + } + + fill() // consume current pending + + for n < len(msgs) { + if !pc.tryFillPendingNonBlocking() { + break + } + fill() + } + + return n, nil +} + +func (pc *chanPacketConn) WriteBatch(msgs []batchpc.Message, flags int) (int, error) { + return pc.pc.WriteBatch(msgs, flags) } -func (c *chanPacketConn) SetDeadline(t time.Time) error { - return c.pc.SetDeadline(t) +// popOne pulls one message from pending; assumes pending not empty. 
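+// The returned *batchpc.Message is owned by messagePool; once its Buf/Addr have
+// been copied out, the caller must hand it back via messagePool.Put.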
+func (pc *chanPacketConn) popOne() *batchpc.Message { + m := pc.pending[pc.pendingIndex] + pc.pendingIndex++ + if pc.pendingIndex >= len(pc.pending) { + // batch fully consumed + pc.pending = nil + pc.pendingIndex = 0 + } + return m } -func (c *chanPacketConn) SetReadDeadline(t time.Time) error { - return c.pc.SetReadDeadline(t) +func (pc *chanPacketConn) ensurePendingBlocking() error { + if len(pc.pending) > 0 { + return nil + } + select { + case batch := <-pc.ch: + pc.pending = batch + pc.pendingIndex = 0 + return nil + case <-pc.closed: + return net.ErrClosed + } } -func (c *chanPacketConn) SetWriteDeadline(t time.Time) error { - return c.pc.SetWriteDeadline(t) +func (pc *chanPacketConn) tryFillPendingNonBlocking() bool { + if len(pc.pending) > 0 { + return true + } + select { + case batch := <-pc.ch: + pc.pending = batch + pc.pendingIndex = 0 + return true + default: + return false + } } diff --git a/pkg/tunnel/l2pc/l2pc.go b/pkg/tunnel/l2pc/l2pc.go index c7eca7e0..2b40b03f 100644 --- a/pkg/tunnel/l2pc/l2pc.go +++ b/pkg/tunnel/l2pc/l2pc.go @@ -12,6 +12,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" + "github.com/apoxy-dev/apoxy/pkg/tunnel/batchpc" tunnet "github.com/apoxy-dev/apoxy/pkg/tunnel/net" ) @@ -19,7 +20,7 @@ var ErrInvalidFrame = errors.New("invalid frame") // L2PacketConn adapts a net.PacketConn (UDP) to read/write L2 Ethernet frames. type L2PacketConn struct { - pc net.PacketConn + pc batchpc.BatchPacketConn // now the batched PacketConn (still implements net.PacketConn) localAddrs addrselect.List localMAC tcpip.LinkAddress peerMACCache sync.Map @@ -27,7 +28,7 @@ type L2PacketConn struct { } // NewL2PacketConn creates a new L2PacketConn. -func NewL2PacketConn(pc net.PacketConn) (*L2PacketConn, error) { +func NewL2PacketConn(pc batchpc.BatchPacketConn) (*L2PacketConn, error) { ua, ok := pc.LocalAddr().(*net.UDPAddr) if !ok || ua == nil { return nil, fmt.Errorf("PacketConn must be UDP") @@ -85,54 +86,19 @@ func NewL2PacketConn(pc net.PacketConn) (*L2PacketConn, error) { return c, nil } -func (c *L2PacketConn) Close() error { return c.pc.Close() } +func (c *L2PacketConn) Close() error { + return c.pc.Close() +} // WriteFrame consumes an Ethernet frame (IPv4/IPv6 + UDP) and writes the payload // to the underlying PacketConn based on the frame’s dst IP:port. 
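+// When several frames are ready at once, prefer WriteBatchFrames, which hands
+// the whole set to the underlying WriteBatch in a single call.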
func (c *L2PacketConn) WriteFrame(frame []byte) error { - if len(frame) < header.EthernetMinimumSize { - return ErrInvalidFrame - } - - eth := header.Ethernet(frame) - switch eth.Type() { - case header.IPv4ProtocolNumber: - ip := header.IPv4(frame[header.EthernetMinimumSize:]) - if !ip.IsValid(len(ip)) || ip.Protocol() != uint8(header.UDPProtocolNumber) { - return ErrInvalidFrame - } - udpHdr := header.UDP(ip.Payload()) - if len(udpHdr) < header.UDPMinimumSize { - return ErrInvalidFrame - } - dst := &net.UDPAddr{ - IP: net.IP(ip.DestinationAddressSlice()), - Port: int(udpHdr.DestinationPort()), - } - payload := udpHdr.Payload() - _, err := c.pc.WriteTo(payload, dst) - return err - - case header.IPv6ProtocolNumber: - ip6 := header.IPv6(frame[header.EthernetMinimumSize:]) - if !ip6.IsValid(len(ip6)) || ip6.TransportProtocol() != header.UDPProtocolNumber { - return ErrInvalidFrame - } - udpHdr := header.UDP(ip6.Payload()) - if len(udpHdr) < header.UDPMinimumSize { - return ErrInvalidFrame - } - dst := &net.UDPAddr{ - IP: net.IP(ip6.DestinationAddressSlice()), - Port: int(udpHdr.DestinationPort()), - } - payload := udpHdr.Payload() - _, err := c.pc.WriteTo(payload, dst) + payload, dst, err := extractUDPPayloadAndDst(frame) + if err != nil { return err - - default: - return fmt.Errorf("unsupported ethertype: %d", eth.Type()) } + _, err = c.pc.WriteTo(payload, dst) + return err } // ReadFrame reads from PacketConn and emits a full Ethernet frame into dst. @@ -148,30 +114,13 @@ func (c *L2PacketConn) ReadFrame(dst []byte) (int, error) { if err != nil { return 0, err } - remote := raddr.(*net.UDPAddr) - - // Decide offset by family; then shift payload to that offset. - payloadOffset := udp.PayloadOffsetIPv4 - isIPv6 := remote.IP.To4() == nil - if isIPv6 { - payloadOffset = udp.PayloadOffsetIPv6 - } - - if payloadOffset+n > cap(*phy) { - return 0, errors.New("packet too large") + remote, ok := raddr.(*net.UDPAddr) + if !ok || remote == nil { + return 0, fmt.Errorf("unexpected remote addr type %T", raddr) } - copy((*phy)[payloadOffset:], (*phy)[:n]) - - // Build addresses for udp.Encode (note: for an inbound frame, - // src = remote, dst = local). - srcFA := toFullAddr(remote) - dstFA := c.localAddrs.Select(srcFA) - - // Random-but-stable (per remote IP) src MAC. - srcFA.LinkAddr = c.peerMACForIP(remote.IP) // Encode the full Ethernet+IP+UDP frame in-place. - frameLen, err := udp.Encode((*phy)[:], srcFA, dstFA, n, false) + frameLen, err := c.encodeInboundFrame((*phy)[:], n, remote) if err != nil { return 0, err } @@ -184,6 +133,115 @@ func (c *L2PacketConn) ReadFrame(dst []byte) (int, error) { return frameLen, nil } +// WriteBatchFrames consumes a batch of Ethernet frames (IPv4/IPv6 + UDP) +// and writes their UDP payloads to destinations extracted from each frame. +// On return, n is the number of frames successfully queued/written. +func (c *L2PacketConn) WriteBatchFrames(msgs []batchpc.Message, flags int) (int, error) { + if len(msgs) == 0 { + return 0, nil + } + + // Prepare the underlying UDP messages (payload + addr). + umsgs := make([]batchpc.Message, len(msgs)) + for i := range msgs { + payload, dst, err := extractUDPPayloadAndDst(msgs[i].Buf) + if err != nil { + return i, err + } + umsgs[i].Buf = payload + umsgs[i].Addr = dst + } + + // Send in one batch. + return c.pc.WriteBatch(umsgs, flags) +} + +// ReadBatchFrames reads a batch of UDP payloads and emits fully-formed +// Ethernet frames into msgs[i].Buf (resizing the slice length to the frame size). 
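+// Each msgs[i].Buf must be pre-sized to hold the Ethernet/IP/UDP headers plus
+// payload; otherwise the call returns a "destination buffer too small" error.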
+// msgs[i].Addr will be set to the remote *net.UDPAddr for convenience. +func (c *L2PacketConn) ReadBatchFrames(msgs []batchpc.Message, flags int) (int, error) { + if len(msgs) == 0 { + return 0, nil + } + + // Prepare underlying UDP read buffers sourced from our pool. + umsgs := make([]batchpc.Message, len(msgs)) + phys := make([]*[]byte, len(msgs)) // to Put() back after use + + for i := range msgs { + phy := c.pktPool.Get().(*[]byte) + *phy = (*phy)[:cap(*phy)] + phys[i] = phy + umsgs[i].Buf = (*phy)[:] + } + + n, err := c.pc.ReadBatch(umsgs, flags) + + // Return pooled buffers we didn't fill. + for i := n; i < len(phys); i++ { + if phys[i] != nil { + c.pktPool.Put(phys[i]) + phys[i] = nil + } + } + if err != nil && n == 0 { + // If nothing was read, return early with the error. + for i := 0; i < len(phys); i++ { + if phys[i] != nil { + c.pktPool.Put(phys[i]) + } + } + return 0, err + } + + // Translate each UDP payload into a full Ethernet frame and copy to caller buffers. + for i := 0; i < n; i++ { + raddr, ok := umsgs[i].Addr.(*net.UDPAddr) + if !ok || raddr == nil { + // Clean up and return partial progress + error. + for j := i; j < n; j++ { + if phys[j] != nil { + c.pktPool.Put(phys[j]) + } + } + return i, fmt.Errorf("unexpected remote addr type %T", umsgs[i].Addr) + } + + // Encode headers into our scratch buffer. + frameLen, encErr := c.encodeInboundFrame((*phys[i])[:], len(umsgs[i].Buf), raddr) + if encErr != nil { + for j := i; j < n; j++ { + if phys[j] != nil { + c.pktPool.Put(phys[j]) + } + } + return i, encErr + } + + // Copy to caller buffer and set slice length. + if len(msgs[i].Buf) < frameLen { + for j := i; j < n; j++ { + if phys[j] != nil { + c.pktPool.Put(phys[j]) + } + } + return i, errors.New("destination buffer too small") + } + copy(msgs[i].Buf[:frameLen], (*phys[i])[:frameLen]) + msgs[i].Buf = msgs[i].Buf[:frameLen] + msgs[i].Addr = raddr // surface the remote UDP address for diagnostics/metrics + } + + // Return pooled buffers for the successfully processed packets. + for i := 0; i < n; i++ { + if phys[i] != nil { + c.pktPool.Put(phys[i]) + } + } + + return n, err +} + // LocalMAC returns the locally-administered unicast MAC address used by this connection. func (c *L2PacketConn) LocalMAC() tcpip.LinkAddress { return c.localMAC @@ -205,6 +263,69 @@ func (c *L2PacketConn) peerMACForIP(ip net.IP) tcpip.LinkAddress { return newMAC } +// encodeInboundFrame takes an inbound UDP payload (already read from the socket) +// plus its remote addr, and encodes a full Ethernet+IP+UDP frame into buf in place. +// It returns the frame length. buf must be a scratch buffer with capacity >= headers+payload. +func (c *L2PacketConn) encodeInboundFrame(buf []byte, payloadLen int, raddr *net.UDPAddr) (int, error) { + // Decide header room by family and move payload to make space. + payloadOff := udp.PayloadOffsetIPv4 + if raddr.IP.To4() == nil { + payloadOff = udp.PayloadOffsetIPv6 + } + if payloadOff+payloadLen > cap(buf) { + return 0, errors.New("packet too large") + } + // The input layout is [payload ...]; shift it up in-place. + copy(buf[payloadOff:], buf[:payloadLen]) + + // Build addresses for udp.Encode (inbound: src = remote, dst = local). 
+ srcFA := toFullAddr(raddr) + dstFA := c.localAddrs.Select(srcFA) + srcFA.LinkAddr = c.peerMACForIP(raddr.IP) // stable random per peer + + return udp.Encode(buf[:], srcFA, dstFA, payloadLen, false) +} + +// extractUDPPayloadAndDst validates an Ethernet frame (IPv4/IPv6+UDP) +// and returns the UDP payload and destination socket address. +func extractUDPPayloadAndDst(frame []byte) (payload []byte, dst *net.UDPAddr, err error) { + if len(frame) < header.EthernetMinimumSize { + return nil, nil, ErrInvalidFrame + } + eth := header.Ethernet(frame) + switch eth.Type() { + case header.IPv4ProtocolNumber: + ip := header.IPv4(frame[header.EthernetMinimumSize:]) + if !ip.IsValid(len(ip)) || ip.Protocol() != uint8(header.UDPProtocolNumber) { + return nil, nil, ErrInvalidFrame + } + udpHdr := header.UDP(ip.Payload()) + if len(udpHdr) < header.UDPMinimumSize { + return nil, nil, ErrInvalidFrame + } + return udpHdr.Payload(), &net.UDPAddr{ + IP: net.IP(ip.DestinationAddressSlice()), + Port: int(udpHdr.DestinationPort()), + }, nil + + case header.IPv6ProtocolNumber: + ip6 := header.IPv6(frame[header.EthernetMinimumSize:]) + if !ip6.IsValid(len(ip6)) || ip6.TransportProtocol() != header.UDPProtocolNumber { + return nil, nil, ErrInvalidFrame + } + udpHdr := header.UDP(ip6.Payload()) + if len(udpHdr) < header.UDPMinimumSize { + return nil, nil, ErrInvalidFrame + } + return udpHdr.Payload(), &net.UDPAddr{ + IP: net.IP(ip6.DestinationAddressSlice()), + Port: int(udpHdr.DestinationPort()), + }, nil + default: + return nil, nil, fmt.Errorf("unsupported ethertype: %d", eth.Type()) + } +} + func toFullAddr(ua *net.UDPAddr) *tcpip.FullAddress { if ua.IP.To4() != nil { return &tcpip.FullAddress{ diff --git a/pkg/tunnel/l2pc/l2pc_test.go b/pkg/tunnel/l2pc/l2pc_test.go index eb7cbf5f..242db387 100644 --- a/pkg/tunnel/l2pc/l2pc_test.go +++ b/pkg/tunnel/l2pc/l2pc_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/apoxy-dev/apoxy/pkg/tunnel/batchpc" "github.com/apoxy-dev/apoxy/pkg/tunnel/l2pc" "github.com/apoxy-dev/icx/udp" "github.com/stretchr/testify/require" @@ -13,59 +14,13 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) -func makeIPv4Frame(srcMAC, dstMAC tcpip.LinkAddress, srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16, payload []byte) []byte { - buf := make([]byte, udp.PayloadOffsetIPv4+len(payload)) - copy(buf[udp.PayloadOffsetIPv4:], payload) - - src := &tcpip.FullAddress{ - Addr: tcpip.AddrFrom4Slice(srcIP.To4()), - Port: srcPort, - LinkAddr: srcMAC, - } - dst := &tcpip.FullAddress{ - Addr: tcpip.AddrFrom4Slice(dstIP.To4()), - Port: dstPort, - LinkAddr: dstMAC, - } - n, err := udp.Encode(buf, src, dst, len(payload), false /* calc checksum */) - if err != nil { - panic(err) - } - return buf[:n] -} - -func makeIPv6Frame(srcMAC, dstMAC tcpip.LinkAddress, srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16, payload []byte) []byte { - buf := make([]byte, udp.PayloadOffsetIPv6+len(payload)) - copy(buf[udp.PayloadOffsetIPv6:], payload) - - src := &tcpip.FullAddress{ - Addr: tcpip.AddrFrom16Slice(dstTo16(srcIP)), - Port: srcPort, - LinkAddr: srcMAC, - } - dst := &tcpip.FullAddress{ - Addr: tcpip.AddrFrom16Slice(dstTo16(dstIP)), - Port: dstPort, - LinkAddr: dstMAC, - } - n, err := udp.Encode(buf, src, dst, len(payload), false /* calc checksum */) - if err != nil { - panic(err) - } - return buf[:n] -} - -func dstTo16(ip net.IP) []byte { - if ip == nil { - return nil - } - return ip.To16() -} - func TestNewL2PacketConn_UDPOnly(t *testing.T) { - pc, err := net.ListenPacket("udp", 
"127.0.0.1:0") + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + pc, err := batchpc.New("udp4", conn) require.NoError(t, err) - defer pc.Close() + t.Cleanup(func() { require.NoError(t, pc.Close()) }) c, err := l2pc.NewL2PacketConn(pc) require.NoError(t, err) @@ -73,30 +28,30 @@ func TestNewL2PacketConn_UDPOnly(t *testing.T) { require.NotEmpty(t, c.LocalMAC()) } -type notUDPConn struct{ net.PacketConn } +type notUDPConn struct{ batchpc.BatchPacketConn } func (n notUDPConn) LocalAddr() net.Addr { return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234} } func TestNewL2PacketConn_RejectsNonUDP(t *testing.T) { - realPC, err := net.ListenPacket("udp", "127.0.0.1:0") - require.NoError(t, err) - defer realPC.Close() - - _, err = l2pc.NewL2PacketConn(notUDPConn{PacketConn: realPC}) + _, err := l2pc.NewL2PacketConn(notUDPConn{}) require.Error(t, err) require.Contains(t, err.Error(), "PacketConn must be UDP") } func TestWriteFrame_IPv4_UsesPayloadAndDst(t *testing.T) { - pc, err := net.ListenPacket("udp", "127.0.0.1:0") + conn, err := net.ListenPacket("udp", "127.0.0.1:0") require.NoError(t, err) - defer pc.Close() + + pc, err := batchpc.New("udp4", conn) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, pc.Close()) }) + c, err := l2pc.NewL2PacketConn(pc) require.NoError(t, err) peer, err := net.ListenPacket("udp", "127.0.0.1:0") require.NoError(t, err) - defer peer.Close() + t.Cleanup(func() { require.NoError(t, peer.Close()) }) dst := peer.LocalAddr().(*net.UDPAddr) @@ -124,14 +79,15 @@ func TestWriteFrame_IPv4_UsesPayloadAndDst(t *testing.T) { func TestWriteFrame_IPv6_UsesPayloadAndDst(t *testing.T) { peer, err := net.ListenPacket("udp", "[::1]:0") - if err != nil { - t.Skip("IPv6 loopback not available:", err) - } - defer peer.Close() + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, peer.Close()) }) + + conn, err := net.ListenPacket("udp", "[::1]:0") + require.NoError(t, err) - pc, err := net.ListenPacket("udp", "[::1]:0") + pc, err := batchpc.New("udp6", conn) require.NoError(t, err) - defer pc.Close() + t.Cleanup(func() { require.NoError(t, pc.Close()) }) c, err := l2pc.NewL2PacketConn(pc) require.NoError(t, err) @@ -161,9 +117,13 @@ func TestWriteFrame_IPv6_UsesPayloadAndDst(t *testing.T) { } func TestWriteFrame_InvalidFrames(t *testing.T) { - pc, err := net.ListenPacket("udp", "127.0.0.1:0") + conn, err := net.ListenPacket("udp", "127.0.0.1:0") require.NoError(t, err) - defer pc.Close() + + pc, err := batchpc.New("udp4", conn) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, pc.Close()) }) + c, err := l2pc.NewL2PacketConn(pc) require.NoError(t, err) @@ -237,15 +197,19 @@ func TestWriteFrame_InvalidFrames(t *testing.T) { func TestReadFrame_IPv4_EncodesWithUDPEncodeAndStablePeerMAC(t *testing.T) { // Adapter under test - pc, err := net.ListenPacket("udp", "127.0.0.1:0") + conn, err := net.ListenPacket("udp", "127.0.0.1:0") require.NoError(t, err) - defer pc.Close() + + pc, err := batchpc.New("udp4", conn) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, pc.Close()) }) + c, err := l2pc.NewL2PacketConn(pc) require.NoError(t, err) peer, err := net.ListenPacket("udp", "127.0.0.1:0") require.NoError(t, err) - defer peer.Close() + t.Cleanup(func() { require.NoError(t, peer.Close()) }) // Send to adapter _, err = peer.WriteTo([]byte("v4-one"), pc.LocalAddr()) @@ -284,17 +248,19 @@ func TestReadFrame_IPv4_EncodesWithUDPEncodeAndStablePeerMAC(t *testing.T) { } func 
TestReadFrame_IPv6_EncodesWithUDPEncodeAndStablePeerMAC(t *testing.T) { - pc, err := net.ListenPacket("udp", "[::1]:0") - if err != nil { - t.Skip("IPv6 loopback not available:", err) - } - defer pc.Close() + conn, err := net.ListenPacket("udp", "[::1]:0") + require.NoError(t, err) + + pc, err := batchpc.New("udp6", conn) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, pc.Close()) }) + c, err := l2pc.NewL2PacketConn(pc) require.NoError(t, err) peer, err := net.ListenPacket("udp", "[::1]:0") require.NoError(t, err) - defer peer.Close() + t.Cleanup(func() { require.NoError(t, peer.Close()) }) _, err = peer.WriteTo([]byte("v6-one"), pc.LocalAddr()) require.NoError(t, err) @@ -326,15 +292,19 @@ func TestReadFrame_IPv6_EncodesWithUDPEncodeAndStablePeerMAC(t *testing.T) { } func TestReadFrame_BufferTooSmall(t *testing.T) { - pc, err := net.ListenPacket("udp", "127.0.0.1:0") + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + pc, err := batchpc.New("udp4", conn) require.NoError(t, err) - defer pc.Close() + t.Cleanup(func() { require.NoError(t, pc.Close()) }) + c, err := l2pc.NewL2PacketConn(pc) require.NoError(t, err) peer, err := net.ListenPacket("udp", "127.0.0.1:0") require.NoError(t, err) - defer peer.Close() + t.Cleanup(func() { require.NoError(t, peer.Close()) }) _, err = peer.WriteTo([]byte("tiny"), pc.LocalAddr()) require.NoError(t, err) @@ -346,3 +316,247 @@ func TestReadFrame_BufferTooSmall(t *testing.T) { require.Error(t, err) require.Contains(t, err.Error(), "destination buffer too small") } + +func TestWriteBatchFrames_IPv4_SendsAll(t *testing.T) { + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + pc, err := batchpc.New("udp4", conn) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, pc.Close()) }) + + c, err := l2pc.NewL2PacketConn(pc) + require.NoError(t, err) + + peer, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, peer.Close()) }) + + dst := peer.LocalAddr().(*net.UDPAddr) + + // Build three frames to the same peer. + payloads := [][]byte{[]byte("b1"), []byte("b2"), []byte("b3")} + msgs := make([]batchpc.Message, len(payloads)) + for i := range payloads { + frame := makeIPv4Frame( + tcpip.GetRandMacAddr(), + tcpip.GetRandMacAddr(), + net.IPv4(127, 0, 0, 1), + 12340+uint16(i), + dst.IP, + uint16(dst.Port), + payloads[i], + ) + msgs[i].Buf = frame + } + + n, err := c.WriteBatchFrames(msgs, 0) + require.NoError(t, err) + require.Equal(t, len(msgs), n) + + // Receive all three payloads (order is not important). 
+ _ = peer.SetReadDeadline(time.Now().Add(2 * time.Second)) + got := map[string]int{} + for i := 0; i < len(payloads); i++ { + buf := make([]byte, 64) + ni, _, rerr := peer.ReadFrom(buf) + require.NoError(t, rerr) + got[string(buf[:ni])]++ + } + require.Equal(t, 1, got["b1"]) + require.Equal(t, 1, got["b2"]) + require.Equal(t, 1, got["b3"]) +} + +func TestWriteBatchFrames_Empty_NoOp(t *testing.T) { + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + pc, err := batchpc.New("udp4", conn) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, pc.Close()) }) + + c, err := l2pc.NewL2PacketConn(pc) + require.NoError(t, err) + + n, err := c.WriteBatchFrames(nil, 0) + require.NoError(t, err) + require.Equal(t, 0, n) +} + +func TestWriteBatchFrames_InvalidFrameStopsAtIndex(t *testing.T) { + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + pc, err := batchpc.New("udp4", conn) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, pc.Close()) }) + c, err := l2pc.NewL2PacketConn(pc) + require.NoError(t, err) + + peer, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, peer.Close()) }) + + dst := peer.LocalAddr().(*net.UDPAddr) + + // Good, Bad, Good — should fail at index 1 with partial count 1. + good1 := makeIPv4Frame(tcpip.GetRandMacAddr(), tcpip.GetRandMacAddr(), + net.IPv4(127, 0, 0, 1), 1111, dst.IP, uint16(dst.Port), []byte("ok1")) + bad := []byte{0x01, 0x02} // too short → ErrInvalidFrame + good2 := makeIPv4Frame(tcpip.GetRandMacAddr(), tcpip.GetRandMacAddr(), + net.IPv4(127, 0, 0, 1), 2222, dst.IP, uint16(dst.Port), []byte("ok2")) + + msgs := []batchpc.Message{ + {Buf: good1}, + {Buf: bad}, + {Buf: good2}, + } + + n, err := c.WriteBatchFrames(msgs, 0) + require.Error(t, err) + require.True(t, errors.Is(err, l2pc.ErrInvalidFrame)) + require.Equal(t, 1, n, "should report count up to first bad frame") +} + +func TestReadBatchFrames_IPv4_EncodesFramesAndSetsAddr(t *testing.T) { + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + pc, err := batchpc.New("udp4", conn) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, pc.Close()) }) + c, err := l2pc.NewL2PacketConn(pc) + require.NoError(t, err) + + peer, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, peer.Close()) }) + + // Send three datagrams to the adapter. + want := [][]byte{[]byte("r1"), []byte("r2"), []byte("r3")} + for _, w := range want { + _, err := peer.WriteTo(w, pc.LocalAddr()) + require.NoError(t, err) + } + + _ = pc.SetReadDeadline(time.Now().Add(2 * time.Second)) + + // Prepare batch buffers. + msgs := make([]batchpc.Message, len(want)) + for i := range msgs { + msgs[i].Buf = make([]byte, 4096) + } + + n, err := c.ReadBatchFrames(msgs, 0) + require.NoError(t, err) + require.Equal(t, len(want), n) + + // Validate each produced Ethernet+IP+UDP frame and that Addr is set. + got := map[string]int{} + for i := 0; i < n; i++ { + var src tcpip.FullAddress + pl, derr := udp.Decode(msgs[i].Buf, &src, false /* checksum */) + require.NoError(t, derr) + got[string(pl)]++ + require.NotNil(t, msgs[i].Addr) + // Ethernet dst must be local MAC; src MAC should be stable per IP (not strictly checked here). 
+ require.Equal(t, c.LocalMAC(), header.Ethernet(msgs[i].Buf).DestinationAddress()) + } + require.Equal(t, 1, got["r1"]) + require.Equal(t, 1, got["r2"]) + require.Equal(t, 1, got["r3"]) +} + +func TestReadBatchFrames_Empty_NoOp(t *testing.T) { + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + pc, err := batchpc.New("udp4", conn) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, pc.Close()) }) + c, err := l2pc.NewL2PacketConn(pc) + require.NoError(t, err) + + n, err := c.ReadBatchFrames(nil, 0) + require.NoError(t, err) + require.Equal(t, 0, n) +} + +func TestReadBatchFrames_BufferTooSmallAtIndex(t *testing.T) { + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + pc, err := batchpc.New("udp4", conn) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, pc.Close()) }) + c, err := l2pc.NewL2PacketConn(pc) + require.NoError(t, err) + + peer, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, peer.Close()) }) + + // Send two datagrams. + _, err = peer.WriteTo([]byte("x1"), pc.LocalAddr()) + require.NoError(t, err) + _, err = peer.WriteTo([]byte("x2"), pc.LocalAddr()) + require.NoError(t, err) + + _ = pc.SetReadDeadline(time.Now().Add(2 * time.Second)) + + // msgs[0] big enough, msgs[1] deliberately too small (< Ethernet+IPv4+UDP). + msgs := []batchpc.Message{ + {Buf: make([]byte, 4096)}, + {Buf: make([]byte, header.EthernetMinimumSize+header.IPv4MinimumSize+header.UDPMinimumSize-1)}, + } + + n, err := c.ReadBatchFrames(msgs, 0) + require.Error(t, err) + require.Contains(t, err.Error(), "destination buffer too small") + require.Equal(t, 1, n, "should process exactly the first frame before failing on index 1") +} + +func makeIPv4Frame(srcMAC, dstMAC tcpip.LinkAddress, srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16, payload []byte) []byte { + buf := make([]byte, udp.PayloadOffsetIPv4+len(payload)) + copy(buf[udp.PayloadOffsetIPv4:], payload) + + src := &tcpip.FullAddress{ + Addr: tcpip.AddrFrom4Slice(srcIP.To4()), + Port: srcPort, + LinkAddr: srcMAC, + } + dst := &tcpip.FullAddress{ + Addr: tcpip.AddrFrom4Slice(dstIP.To4()), + Port: dstPort, + LinkAddr: dstMAC, + } + n, err := udp.Encode(buf, src, dst, len(payload), false /* calc checksum */) + if err != nil { + panic(err) + } + return buf[:n] +} + +func makeIPv6Frame(srcMAC, dstMAC tcpip.LinkAddress, srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16, payload []byte) []byte { + buf := make([]byte, udp.PayloadOffsetIPv6+len(payload)) + copy(buf[udp.PayloadOffsetIPv6:], payload) + + src := &tcpip.FullAddress{ + Addr: tcpip.AddrFrom16Slice(dstTo16(srcIP)), + Port: srcPort, + LinkAddr: srcMAC, + } + dst := &tcpip.FullAddress{ + Addr: tcpip.AddrFrom16Slice(dstTo16(dstIP)), + Port: dstPort, + LinkAddr: dstMAC, + } + n, err := udp.Encode(buf, src, dst, len(payload), false /* calc checksum */) + if err != nil { + panic(err) + } + return buf[:n] +} + +func dstTo16(ip net.IP) []byte { + if ip == nil { + return nil + } + return ip.To16() +} diff --git a/pkg/tunnel/router/icx_netlink_linux.go b/pkg/tunnel/router/icx_netlink_linux.go index fa215bfb..e59f777b 100644 --- a/pkg/tunnel/router/icx_netlink_linux.go +++ b/pkg/tunnel/router/icx_netlink_linux.go @@ -129,7 +129,7 @@ func NewICXNetlinkRouter(opts ...Option) (*ICXNetlinkRouter, error) { return nil, fmt.Errorf("failed to create handler: %w", err) } - ingressFilter, err := filter.Bind(extAddrs...) 
+ ingressFilter, err := filter.Geneve(extAddrs...) if err != nil { _ = tunDev.Close() return nil, fmt.Errorf("failed to create ingress filter: %w", err) diff --git a/pkg/tunnel/router/options.go b/pkg/tunnel/router/options.go index 3329c01d..30d5fc85 100644 --- a/pkg/tunnel/router/options.go +++ b/pkg/tunnel/router/options.go @@ -1,9 +1,9 @@ package router import ( - "net" "net/netip" + "github.com/apoxy-dev/apoxy/pkg/tunnel/batchpc" "github.com/apoxy-dev/icx" "github.com/dpeckett/network" ) @@ -23,7 +23,7 @@ type routerOptions struct { cksumRecalc bool preserveDefaultGwDsts []netip.Prefix sourcePortHashing bool - pc net.PacketConn + pc batchpc.BatchPacketConn egressGateway bool } @@ -120,9 +120,9 @@ func WithSourcePortHashing(enable bool) Option { } } -// WithPacketConn sets the underlying PacketConn for the ICX router. +// WithPacketConn sets a custom BatchPacketConn for the ICX netstack router. // Only valid for ICX netstack router. -func WithPacketConn(pc net.PacketConn) Option { +func WithPacketConn(pc batchpc.BatchPacketConn) Option { return func(o *routerOptions) { o.pc = pc } From 86ea79ba40f0dd06969f2da7131e24e16c200d8b Mon Sep 17 00:00:00 2001 From: Damian Peckett Date: Fri, 24 Oct 2025 09:55:39 +0200 Subject: [PATCH 2/2] [icxtunnel] handle and bubble up transient errors for increased robustness --- pkg/netstack/icx_network.go | 15 ++- pkg/netstack/icx_network_test.go | 81 +++++++++++---- pkg/tunnel/bifurcate/bifurcate.go | 37 +++++-- pkg/tunnel/bifurcate/bifurcate_test.go | 51 ++++++++- pkg/tunnel/bifurcate/chanpc.go | 137 ++++++++++++++++++++----- 5 files changed, 256 insertions(+), 65 deletions(-) diff --git a/pkg/netstack/icx_network.go b/pkg/netstack/icx_network.go index 112853ee..295ba2b5 100644 --- a/pkg/netstack/icx_network.go +++ b/pkg/netstack/icx_network.go @@ -165,7 +165,6 @@ func (net *ICXNetwork) Close() error { } // Start copies packets to and from netstack and icx. -// Upstream (netstack -> phy) uses batched I/O; downstream remains as before. func (net *ICXNetwork) Start() error { const tickMs = 100 // periodically flush scheduled frames (ToPhy) @@ -260,10 +259,14 @@ func (net *ICXNetwork) Start() error { for i := 0; i < count; i++ { msgs[i] = batch[i].msg } - n, err := net.phy.WriteBatchFrames(msgs, 0) + _, err := net.phy.WriteBatchFrames(msgs, 0) putOwned(batch) if err != nil { - return fmt.Errorf("writing batched phy frames failed after %d/%d: %w", n, count, err) + if errors.Is(err, stdnet.ErrClosed) { + return err + } + slog.Warn("Error writing batched phy frames", slog.Any("error", err)) + continue } } }) @@ -277,7 +280,11 @@ func (net *ICXNetwork) Start() error { n, err := net.phy.ReadFrame(*phyFrame) if err != nil { net.pktPool.Put(phyFrame) - return fmt.Errorf("reading phy frame failed: %w", err) + if errors.Is(err, stdnet.ErrClosed) { + return err + } + slog.Warn("Error reading frame from physical interface", slog.Any("error", err)) + continue } if n == 0 { net.pktPool.Put(phyFrame) diff --git a/pkg/netstack/icx_network_test.go b/pkg/netstack/icx_network_test.go index a99c2e9b..4b614964 100644 --- a/pkg/netstack/icx_network_test.go +++ b/pkg/netstack/icx_network_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/avast/retry-go/v4" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" @@ -171,31 +172,64 @@ func TestICXNetwork_Speed(t *testing.T) { Timeout: 30 * time.Second, } + // helper that GETs URL and reads/discards the body, + // returning bytes read. wrapped in retry at callsite. 
+ fetchAndDrain := func(cl *http.Client, url string) (int64, error) { + resp, err := cl.Get(url) + if err != nil { + return 0, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return 0, fmt.Errorf("status: %s", resp.Status) + } + + n, err := io.Copy(io.Discard, resp.Body) + if err != nil { + return 0, err + } + if n != totalBytes { + return 0, fmt.Errorf("unexpected byte count: got %d, want %d", n, totalBytes) + } + return n, nil + } + // Single-stream speed test t.Run("Speed", func(t *testing.T) { url := "http://" + ln.Addr().String() + "/speed" start := time.Now() - resp, err := client.Get(url) - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) - - n, err := io.Copy(io.Discard, resp.Body) - _ = resp.Body.Close() - require.NoError(t, err) - require.EqualValues(t, totalBytes, n, "unexpected byte count") + var n int64 + err := retry.Do( + func() error { + readBytes, err := fetchAndDrain(client, url) + if err != nil { + return err + } + n = readBytes + return nil + }, + retry.Attempts(3), + retry.Delay(100*time.Millisecond), + retry.DelayType(retry.FixedDelay), + ) + require.NoError(t, err, "single-stream GET failed even after retries") elapsed := time.Since(start) sec := elapsed.Seconds() + + // bits/s calc mbps := (float64(n) * 8) / 1_000_000 / sec gbps := (float64(n) * 8) / 1_000_000_000 / sec + // bytes/s calc mbpsBytes := (float64(n)) / 1_000_000 / sec t.Logf("Downloaded %d bytes in %s → %.2f MB/s, %.2f Mbit/s (%.2f Gbit/s)", n, elapsed, mbpsBytes, mbps, gbps) }) - // Parallel speed test: four concurrent streams, each 200 MiB + // Parallel speed test: eight concurrent streams, each 100 MiB t.Run("SpeedParallel", func(t *testing.T) { const numStreams = 8 URL := "http://" + ln.Addr().String() + "/speed" @@ -220,21 +254,25 @@ func TestICXNetwork_Speed(t *testing.T) { for i := 0; i < numStreams; i++ { g.Go(func() error { - resp, err := parallelClient.Get(URL) + var gotBytes int64 + err := retry.Do( + func() error { + readBytes, err := fetchAndDrain(parallelClient, URL) + if err != nil { + return err + } + gotBytes = readBytes + return nil + }, + retry.Attempts(3), + retry.Delay(100*time.Millisecond), + retry.DelayType(retry.FixedDelay), + ) if err != nil { return err } - if resp.StatusCode != http.StatusOK { - _ = resp.Body.Close() - return fmt.Errorf("status: %s", resp.Status) - } - n, err := io.Copy(io.Discard, resp.Body) - _ = resp.Body.Close() - if err != nil { - return err - } - if n != totalBytes { - return fmt.Errorf("unexpected byte count: got %d, want %d", n, totalBytes) + if gotBytes != totalBytes { + return fmt.Errorf("unexpected byte count post-retry: got %d, want %d", gotBytes, totalBytes) } return nil }) @@ -244,6 +282,7 @@ func TestICXNetwork_Speed(t *testing.T) { totalRead := int64(numStreams) * totalBytes elapsed := time.Since(start) sec := elapsed.Seconds() + mbps := (float64(totalRead) * 8) / 1_000_000 / sec gbps := (float64(totalRead) * 8) / 1_000_000_000 / sec mbpsBytes := (float64(totalRead)) / 1_000_000 / sec diff --git a/pkg/tunnel/bifurcate/bifurcate.go b/pkg/tunnel/bifurcate/bifurcate.go index 5b3e9c20..b5001693 100644 --- a/pkg/tunnel/bifurcate/bifurcate.go +++ b/pkg/tunnel/bifurcate/bifurcate.go @@ -1,6 +1,9 @@ package bifurcate import ( + "errors" + "log/slog" + "net" "sync" "github.com/apoxy-dev/icx/geneve" @@ -17,8 +20,9 @@ var messagePool = sync.Pool{ // Bifurcate splits incoming packets from `pc` into geneve and other channels. 
func Bifurcate(pc batchpc.BatchPacketConn) (batchpc.BatchPacketConn, batchpc.BatchPacketConn) { - geneveConn := newChanPacketConn(pc) - otherConn := newChanPacketConn(pc) + var closeConnOnce sync.Once + geneveConn := newChanPacketConn(pc, &closeConnOnce) + otherConn := newChanPacketConn(pc, &closeConnOnce) // Local copies we can nil out when a side is closed. geneveCh := geneveConn.ch @@ -53,17 +57,31 @@ func Bifurcate(pc batchpc.BatchPacketConn) (batchpc.BatchPacketConn, batchpc.Bat n, err := pc.ReadBatch(msgs, 0) if err != nil { - // Return any outstanding pooled messages. + // Recycle any pooled messages we haven't handed off. for i := 0; i < len(pm); i++ { if pm[i] != nil { messagePool.Put(pm[i]) pm[i] = nil } } - _ = geneveConn.Close() - _ = otherConn.Close() - return + + // Bubble the error up to each chanPacketConn. + geneveConn.setErr(err) + otherConn.setErr(err) + + // Only close+exit if this is a permanent close. + if errors.Is(err, net.ErrClosed) { + _ = geneveConn.Close() + _ = otherConn.Close() + return + } + + slog.Warn("Error reading batch from underlying connection", slog.Any("error", err)) + + // Transient error: keep the bifurcator alive. + continue } + if n == 0 { continue } @@ -74,7 +92,7 @@ func Bifurcate(pc batchpc.BatchPacketConn) (batchpc.BatchPacketConn, batchpc.Bat for i := 0; i < n; i++ { m := pm[i] - // msgs[i].Buf has been resized by underlying BatchPacketConn ReadBatch. + // msgs[i].Buf may have been resized by underlying BatchPacketConn ReadBatch. m.Buf = msgs[i].Buf m.Addr = msgs[i].Addr @@ -101,6 +119,7 @@ func Bifurcate(pc batchpc.BatchPacketConn) (batchpc.BatchPacketConn, batchpc.Bat for _, m := range batch { messagePool.Put(m) } + close(ch) ch = nil closed = nil } @@ -122,12 +141,12 @@ func isGeneve(b []byte) bool { return false } - // Only Geneve version 0 is defined + // Only Geneve version 0 is defined. if hdr.Version != 0 { return false } - // Check for valid protocol types (IPv4 or IPv6) or EtherType unknown (out-of-band messages). + // Check for valid protocol types (IPv4 or IPv6) or EtherType 0 (mgmt / oob). if hdr.ProtocolType != uint16(header.IPv4ProtocolNumber) && hdr.ProtocolType != uint16(header.IPv6ProtocolNumber) && hdr.ProtocolType != 0 { diff --git a/pkg/tunnel/bifurcate/bifurcate_test.go b/pkg/tunnel/bifurcate/bifurcate_test.go index 8ca2e8aa..7b0b04d3 100644 --- a/pkg/tunnel/bifurcate/bifurcate_test.go +++ b/pkg/tunnel/bifurcate/bifurcate_test.go @@ -163,13 +163,14 @@ func TestBifurcate_WriteBatchForwardsToUnderlying(t *testing.T) { } } -func TestBifurcate_ClosesBothOnReadError(t *testing.T) { +func TestBifurcate_ClosesBothOnUnderlyingClose(t *testing.T) { mockConn := newMockBatchPacketConn() - // Simulate read error: close the queue so next read fails. - close(mockConn.readQueue) geneveConn, otherConn := bifurcate.Bifurcate(mockConn) + // close the underlying connection + _ = mockConn.Close() + // Give the goroutine a breath to observe the close. time.Sleep(50 * time.Millisecond) @@ -181,6 +182,42 @@ func TestBifurcate_ClosesBothOnReadError(t *testing.T) { require.ErrorIs(t, err, net.ErrClosed) } +func TestBifurcate_BubblesTransientErrorAndContinues(t *testing.T) { + t.Helper() + + mockConn := newMockBatchPacketConn() + transientErr := errors.New("temporary I/O error") + + remote := &net.UDPAddr{IP: net.IPv4(10, 9, 8, 7), Port: 31337} + genevePkt := createGenevePacket(t) + + // First queued result is a transient error (channel stays open). 
+ mockConn.readQueue <- readResult{ + err: transientErr, + } + // Second queued result is a valid Geneve packet. + mockConn.enqueue(genevePkt, remote) + + geneveConn, _ := bifurcate.Bifurcate(mockConn) + + // Give the bifurcator goroutine a moment to observe both queued results: + // 1) record transientErr via setErr(...) + // 2) enqueue the good packet batch onto geneveConn.ch + time.Sleep(50 * time.Millisecond) + + buf := make([]byte, 1024) + + // First read should surface the transient error that was bubbled up. + _, _, err := geneveConn.ReadFrom(buf) + require.ErrorIs(t, err, transientErr) + + // Second read should now succeed and return the real packet. + n, addr, err := geneveConn.ReadFrom(buf) + require.NoError(t, err) + require.Equal(t, remote.String(), addr.String()) + require.True(t, bytes.Equal(buf[:n], genevePkt), "expected geneve packet after transient error") +} + func createGenevePacket(t *testing.T) []byte { h := geneve.Header{ Version: 0, @@ -278,14 +315,17 @@ func (pc *mockBatchPacketConn) ReadBatch(msgs []batchpc.Message, flags int) (int if len(msgs) == 0 { return 0, nil } + // First result: block result, ok := <-pc.readQueue if !ok { - return 0, errors.New("mock read closed") + // underlying permanently closed + return 0, net.ErrClosed } if result.err != nil { return 0, result.err } + n0 := copy(msgs[0].Buf, result.data) msgs[0].Buf = msgs[0].Buf[:n0] msgs[0].Addr = result.addr @@ -296,7 +336,7 @@ func (pc *mockBatchPacketConn) ReadBatch(msgs []batchpc.Message, flags int) (int select { case rr, ok := <-pc.readQueue: if !ok { - return n, errors.New("mock read closed") + return n, net.ErrClosed } if rr.err != nil { return n, rr.err @@ -309,6 +349,7 @@ func (pc *mockBatchPacketConn) ReadBatch(msgs []batchpc.Message, flags int) (int return n, nil } } + return n, nil } diff --git a/pkg/tunnel/bifurcate/chanpc.go b/pkg/tunnel/bifurcate/chanpc.go index c350e462..d993d251 100644 --- a/pkg/tunnel/bifurcate/chanpc.go +++ b/pkg/tunnel/bifurcate/chanpc.go @@ -2,26 +2,34 @@ package bifurcate import ( "net" + "sync" "time" "github.com/apoxy-dev/apoxy/pkg/tunnel/batchpc" ) type chanPacketConn struct { - pc batchpc.BatchPacketConn + pc batchpc.BatchPacketConn + closeConnOnce *sync.Once // Incoming batches from the bifurcator goroutine. - ch chan []*batchpc.Message - closed chan struct{} + ch chan []*batchpc.Message + closedOnce sync.Once + closed chan struct{} // Locally pending batch from the last receive (not yet fully consumed). + pendingMu sync.Mutex pending []*batchpc.Message pendingIndex int + // Last transient error to be surfaced on next Read/ReadBatch. 
+ errMu sync.Mutex + lastErr error } -func newChanPacketConn(pc batchpc.BatchPacketConn) *chanPacketConn { +func newChanPacketConn(pc batchpc.BatchPacketConn, closeConnOnce *sync.Once) *chanPacketConn { return &chanPacketConn{ - ch: make(chan []*batchpc.Message, 1024), - pc: pc, - closed: make(chan struct{}), + ch: make(chan []*batchpc.Message, 1024), + pc: pc, + closeConnOnce: closeConnOnce, + closed: make(chan struct{}), } } @@ -29,6 +37,7 @@ func (pc *chanPacketConn) ReadFrom(p []byte) (int, net.Addr, error) { if err := pc.ensurePendingBlocking(); err != nil { return 0, nil, err } + msg := pc.popOne() defer messagePool.Put(msg) @@ -41,13 +50,15 @@ func (pc *chanPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) { } func (pc *chanPacketConn) Close() error { - select { - case <-pc.closed: - return nil - default: + pc.closedOnce.Do(func() { close(pc.closed) - return nil - } + }) + + var err error + pc.closeConnOnce.Do(func() { + err = pc.pc.Close() + }) + return err } func (pc *chanPacketConn) LocalAddr() net.Addr { @@ -77,10 +88,13 @@ func (pc *chanPacketConn) ReadBatch(msgs []batchpc.Message, flags int) (int, err return 0, err } - // 2) Fill from pending, then non-blocking drain of further batches. fill := func() { - for n < len(msgs) && len(pc.pending) > 0 { + for n < len(msgs) { msg := pc.popOne() + if msg == nil { + // nothing pending anymore + break + } copied := copy(msgs[n].Buf, msg.Buf) msgs[n].Buf = msgs[n].Buf[:copied] msgs[n].Addr = msg.Addr @@ -89,8 +103,10 @@ func (pc *chanPacketConn) ReadBatch(msgs []batchpc.Message, flags int) (int, err } } - fill() // consume current pending + // consume current pending + fill() + // 2) Then non-blocking drain of further batches. for n < len(msgs) { if !pc.tryFillPendingNonBlocking() { break @@ -105,8 +121,28 @@ func (pc *chanPacketConn) WriteBatch(msgs []batchpc.Message, flags int) (int, er return pc.pc.WriteBatch(msgs, flags) } -// popOne pulls one message from pending; assumes pending not empty. +// pendingLenLocked returns len(pending). pendingMu MUST be held by caller. +func (pc *chanPacketConn) pendingLenLocked() int { + return len(pc.pending) +} + +// setPendingLocked sets the pending batch + resets index. pendingMu MUST be held. +func (pc *chanPacketConn) setPendingLocked(batch []*batchpc.Message) { + pc.pending = batch + pc.pendingIndex = 0 +} + +// popOne pulls one message from pending. +// Returns nil if pending is empty. +// Takes the lock internally. func (pc *chanPacketConn) popOne() *batchpc.Message { + pc.pendingMu.Lock() + defer pc.pendingMu.Unlock() + + if len(pc.pending) == 0 { + return nil + } + m := pc.pending[pc.pendingIndex] pc.pendingIndex++ if pc.pendingIndex >= len(pc.pending) { @@ -117,30 +153,79 @@ func (pc *chanPacketConn) popOne() *batchpc.Message { return m } +// ensurePendingBlocking guarantees there's at least one message in pending, +// blocking on pc.ch if needed. Surfaces transient errors first. func (pc *chanPacketConn) ensurePendingBlocking() error { - if len(pc.pending) > 0 { + // Fast path: already have pending locally. + pc.pendingMu.Lock() + if pc.pendingLenLocked() > 0 { + pc.pendingMu.Unlock() return nil } + pc.pendingMu.Unlock() + + // Check if there's a transient error waiting to be reported. 
+ if err := pc.takeErr(); err != nil { + return err + } + select { - case batch := <-pc.ch: - pc.pending = batch - pc.pendingIndex = 0 + case batch, ok := <-pc.ch: + if !ok { + // ch closed -> treat as connection closed + return net.ErrClosed + } + pc.pendingMu.Lock() + pc.setPendingLocked(batch) + pc.pendingMu.Unlock() return nil case <-pc.closed: return net.ErrClosed } } +// tryFillPendingNonBlocking tries to pull a new batch into pending without blocking. +// Returns true if pending now has data. func (pc *chanPacketConn) tryFillPendingNonBlocking() bool { - if len(pc.pending) > 0 { + // Check fast path first. + pc.pendingMu.Lock() + if pc.pendingLenLocked() > 0 { + pc.pendingMu.Unlock() return true } + pc.pendingMu.Unlock() + select { - case batch := <-pc.ch: - pc.pending = batch - pc.pendingIndex = 0 - return true + case batch, ok := <-pc.ch: + if !ok { + return false + } + pc.pendingMu.Lock() + pc.setPendingLocked(batch) + hasData := pc.pendingLenLocked() > 0 + pc.pendingMu.Unlock() + return hasData default: return false } } + +// setErr records a transient error to be surfaced on the next Read/ReadBatch call. +func (pc *chanPacketConn) setErr(err error) { + if err == nil { + return + } + pc.errMu.Lock() + pc.lastErr = err + pc.errMu.Unlock() +} + +// takeErr returns (and clears) the currently stored transient error. +// If there's no pending transient error it returns nil. +func (pc *chanPacketConn) takeErr() error { + pc.errMu.Lock() + defer pc.errMu.Unlock() + err := pc.lastErr + pc.lastErr = nil + return err +}
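
A minimal wiring sketch of how a caller shares one UDP socket between the Geneve data path and the QUIC control path, assuming only the batchpc.New and bifurcate.Bifurcate signatures visible in the hunks above (the package main scaffolding and the log-based error handling are illustrative, not taken from the tree):

    package main

    import (
        "errors"
        "log"
        "net"

        "github.com/apoxy-dev/apoxy/pkg/tunnel/batchpc"
        "github.com/apoxy-dev/apoxy/pkg/tunnel/bifurcate"
    )

    func main() {
        // One UDP socket shared between Geneve (data) and QUIC (control).
        lis, err := net.ListenPacket("udp", ":0")
        if err != nil {
            log.Fatal(err)
        }

        // Wrap the socket so reads and writes can use batched I/O.
        pc, err := batchpc.New("udp", lis)
        if err != nil {
            log.Fatal(err)
        }

        // Split into a Geneve side and an "everything else" (QUIC) side.
        pcGeneve, pcQuic := bifurcate.Bifurcate(pc)
        defer pcGeneve.Close()
        defer pcQuic.Close()

        // With the second commit, a transient read error on the shared socket is
        // surfaced once on the next ReadFrom/ReadBatch of each side; only
        // net.ErrClosed tears the bifurcator down.
        buf := make([]byte, 2048)
        for {
            n, addr, err := pcGeneve.ReadFrom(buf)
            if errors.Is(err, net.ErrClosed) {
                return
            }
            if err != nil {
                log.Printf("transient geneve read error: %v", err)
                continue
            }
            log.Printf("got %d bytes from %s", n, addr)
        }
    }

The design choice this illustrates: callers no longer lose the tunnel on a single failed recvmmsg-style batch read; the error is recorded via setErr, reported once to each consumer, and reads then resume from the shared socket.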