From dbd25eef6a7dc6904663c253ea5369b23c89bcf1 Mon Sep 17 00:00:00 2001 From: Daniel Lavrushin Date: Mon, 9 Mar 2026 17:20:17 +0100 Subject: [PATCH] Implement TUN-based packet processing engine - Added a new engine package with PacketVerdict type and Engine interface. - Introduced nfq package for engine-agnostic packet processing logic. - Created tun package to manage TUN device creation, routing, and packet processing. - Implemented routeManager for setting up and tearing down routing rules for TUN mode. - Enhanced Engine to handle TUN device operations and integrate with existing NFQUEUE logic. - Added methods for reading packets from TUN device and forwarding them based on verdicts. --- src/config/bind.go | 7 - src/config/config.go | 6 + src/config/methods.go | 13 ++ src/config/types.go | 10 + src/engine/engine.go | 18 ++ src/main.go | 100 +++++---- src/nfq/dns.go | 49 +--- src/nfq/inc.go | 9 +- src/nfq/nfq.go | 503 ++---------------------------------------- src/nfq/process.go | 444 +++++++++++++++++++++++++++++++++++++ src/tun/device.go | 57 +++++ src/tun/route.go | 154 +++++++++++++ src/tun/tun.go | 184 +++++++++++++++ 13 files changed, 982 insertions(+), 572 deletions(-) create mode 100644 src/engine/engine.go create mode 100644 src/nfq/process.go create mode 100644 src/tun/device.go create mode 100644 src/tun/route.go create mode 100644 src/tun/tun.go diff --git a/src/config/bind.go b/src/config/bind.go index af995275..f8ac50d9 100644 --- a/src/config/bind.go +++ b/src/config/bind.go @@ -6,13 +6,6 @@ func (c *Config) BindFlags(cmd *cobra.Command) { // Config path cmd.Flags().StringVar(&c.ConfigPath, "config", c.ConfigPath, "Path to config file") - // Queue configuration - cmd.Flags().IntVar(&c.Queue.StartNum, "queue-num", c.Queue.StartNum, "Netfilter queue number") - cmd.Flags().IntVar(&c.Queue.Threads, "threads", c.Queue.Threads, "Number of worker threads") - cmd.Flags().UintVar(&c.Queue.Mark, "mark", c.Queue.Mark, "Packet mark value (default 32768)") - cmd.Flags().BoolVar(&c.Queue.IPv4Enabled, "ipv4", c.Queue.IPv4Enabled, "Enable IPv4 processing") - cmd.Flags().BoolVar(&c.Queue.IPv6Enabled, "ipv6", c.Queue.IPv6Enabled, "Enable IPv6 processing") - // System configuration cmd.Flags().IntVar(&c.System.Tables.MonitorInterval, "tables-monitor-interval", c.System.Tables.MonitorInterval, "Tables monitor interval in seconds (default 10, 0 to disable)") cmd.Flags().BoolVar(&c.System.Tables.SkipSetup, "skip-tables", c.System.Tables.SkipSetup, "Skip iptables/nftables setup on startup") diff --git a/src/config/config.go b/src/config/config.go index 85b316db..bd1fad5c 100644 --- a/src/config/config.go +++ b/src/config/config.go @@ -144,6 +144,7 @@ var DefaultConfig = Config{ ConfigPath: "", Queue: QueueConfig{ + Mode: "nfqueue", StartNum: 537, Mark: 1 << 15, Threads: 4, @@ -163,6 +164,11 @@ var DefaultConfig = Config{ Enabled: false, Size: 88, }, + TUN: TUNConfig{ + DeviceName: "b4tun0", + Address: "10.255.0.1/30", + RouteTable: 100, + }, }, Sets: []*SetConfig{}, diff --git a/src/config/methods.go b/src/config/methods.go index f2d03d1c..b3d5ff73 100644 --- a/src/config/methods.go +++ b/src/config/methods.go @@ -192,6 +192,19 @@ func (c *Config) Validate() error { return fmt.Errorf("threads must be at least 1") } + if c.Queue.Mode == "" { + c.Queue.Mode = "nfqueue" + } + if c.Queue.Mode != "nfqueue" && c.Queue.Mode != "tun" { + return fmt.Errorf("queue mode must be 'nfqueue' or 'tun'") + } + + if c.Queue.Mode == "tun" { + if c.Queue.TUN.OutInterface == "" { + return fmt.Errorf("tun out_interface is required in TUN mode (e.g. eth0, wan0)") + } + } + if c.Queue.StartNum < 0 || c.Queue.StartNum > 65535 { return fmt.Errorf("queue-num must be between 0 and 65535") } diff --git a/src/config/types.go b/src/config/types.go index f537f327..542c4db6 100644 --- a/src/config/types.go +++ b/src/config/types.go @@ -26,6 +26,7 @@ type ApiConfig struct { } type QueueConfig struct { + Mode string `json:"mode" bson:"mode"` // "nfqueue" (default) or "tun" StartNum int `json:"start_num" bson:"start_num"` Threads int `json:"threads" bson:"threads"` Mark uint `json:"mark" bson:"mark"` @@ -36,6 +37,15 @@ type QueueConfig struct { Interfaces []string `json:"interfaces" bson:"interfaces"` Devices DevicesConfig `json:"devices" bson:"devices"` MSSClamp MSSClampConfig `json:"mss_clamp" bson:"mss_clamp"` + TUN TUNConfig `json:"tun" bson:"tun"` +} + +type TUNConfig struct { + DeviceName string `json:"device_name" bson:"device_name"` // TUN device name, default: "b4tun0" + Address string `json:"address" bson:"address"` // TUN device address, default: "10.255.0.1/30" + OutInterface string `json:"out_interface" bson:"out_interface"` // Real outbound interface, e.g. "eth0", "wan0" + OutGateway string `json:"out_gateway" bson:"out_gateway"` // Real gateway IP, auto-detected if empty + RouteTable int `json:"route_table" bson:"route_table"` // Policy routing table number, default: 100 } type DevicesConfig struct { diff --git a/src/engine/engine.go b/src/engine/engine.go new file mode 100644 index 00000000..2f8a60e7 --- /dev/null +++ b/src/engine/engine.go @@ -0,0 +1,18 @@ +package engine + +// PacketVerdict tells the packet source what to do with the original packet. +type PacketVerdict int + +const ( + // VerdictAccept forwards the packet unchanged. + VerdictAccept PacketVerdict = iota + // VerdictDrop suppresses the original packet (modified copies are already sent by the handler). + VerdictDrop +) + +// Engine is the interface for a packet processing backend. +// Both NFQUEUE and TUN implement this interface. +type Engine interface { + Start() error + Stop() +} diff --git a/src/main.go b/src/main.go index a9b606d9..235e2f98 100644 --- a/src/main.go +++ b/src/main.go @@ -22,6 +22,7 @@ import ( "github.com/daniellavrushin/b4/quic" "github.com/daniellavrushin/b4/socks5" "github.com/daniellavrushin/b4/tables" + b4tun "github.com/daniellavrushin/b4/tun" "github.com/spf13/cobra" "github.com/spf13/pflag" ) @@ -123,39 +124,58 @@ func runB4(cmd *cobra.Command, args []string) error { log.Infof("Loaded targets: %d domains, %d IPs across %d sets", totalDomains, totalIps, len(cfg.Sets)) - // Setup iptables/nftables rules - if !cfg.System.Tables.SkipSetup { - log.Tracef("Clearing existing iptables/nftables rules") - tables.ClearRules(&cfg) + pool := nfq.NewPool(&cfg) + + var tunEngine *b4tun.Engine + var tablesMonitor *tables.Monitor - log.Tracef("Adding tables rules") - if err := tables.AddRules(&cfg); err != nil { - metrics.RecordEvent("error", fmt.Sprintf("Failed to add tables rules: %v", err)) - return fmt.Errorf("failed to add tables rules: %w", err) + if cfg.Queue.Mode == "tun" { + // TUN mode: no iptables/nftables needed + log.Infof("Starting TUN engine (device: %s, out: %s, threads: %d)", + cfg.Queue.TUN.DeviceName, cfg.Queue.TUN.OutInterface, cfg.Queue.Threads) + metrics.TablesStatus = "tun" + + tunEngine = b4tun.NewEngine(&cfg, pool) + if err := tunEngine.Start(); err != nil { + metrics.RecordEvent("error", fmt.Sprintf("TUN engine start failed: %v", err)) + metrics.NFQueueStatus = "error" + return fmt.Errorf("TUN engine start failed: %w", err) } - metrics.RecordEvent("info", "Tables rules configured successfully") + + metrics.RecordEvent("info", fmt.Sprintf("TUN engine started with %d threads", cfg.Queue.Threads)) + metrics.NFQueueStatus = "active (tun)" } else { - log.Infof("Skipping tables setup (--skip-tables)") - metrics.TablesStatus = "skipped" - } + // NFQUEUE mode: setup iptables/nftables rules + if !cfg.System.Tables.SkipSetup { + log.Tracef("Clearing existing iptables/nftables rules") + tables.ClearRules(&cfg) + + log.Tracef("Adding tables rules") + if err := tables.AddRules(&cfg); err != nil { + metrics.RecordEvent("error", fmt.Sprintf("Failed to add tables rules: %v", err)) + return fmt.Errorf("failed to add tables rules: %w", err) + } + metrics.RecordEvent("info", "Tables rules configured successfully") + } else { + log.Infof("Skipping tables setup (--skip-tables)") + metrics.TablesStatus = "skipped" + } - // Start netfilter queue pool - log.Infof("Starting netfilter queue pool (queue: %d, threads: %d)", cfg.Queue.StartNum, cfg.Queue.Threads) - pool := nfq.NewPool(&cfg) - if err := pool.Start(); err != nil { - metrics.RecordEvent("error", fmt.Sprintf("NFQueue start failed: %v", err)) - metrics.NFQueueStatus = "error" - return fmt.Errorf("netfilter queue start failed: %w", err) - } + log.Infof("Starting netfilter queue pool (queue: %d, threads: %d)", cfg.Queue.StartNum, cfg.Queue.Threads) + if err := pool.Start(); err != nil { + metrics.RecordEvent("error", fmt.Sprintf("NFQueue start failed: %v", err)) + metrics.NFQueueStatus = "error" + return fmt.Errorf("netfilter queue start failed: %w", err) + } - metrics.RecordEvent("info", fmt.Sprintf("NFQueue started with %d threads", cfg.Queue.Threads)) - metrics.NFQueueStatus = "active" + metrics.RecordEvent("info", fmt.Sprintf("NFQueue started with %d threads", cfg.Queue.Threads)) + metrics.NFQueueStatus = "active" - // Start tables monitor to handle rule restoration if system wipes them - var tablesMonitor *tables.Monitor - if !cfg.System.Tables.SkipSetup && cfg.System.Tables.MonitorInterval > 0 { - tablesMonitor = tables.NewMonitor(&cfg) - tablesMonitor.Start() + // Start tables monitor to handle rule restoration if system wipes them + if !cfg.System.Tables.SkipSetup && cfg.System.Tables.MonitorInterval > 0 { + tablesMonitor = tables.NewMonitor(&cfg) + tablesMonitor.Start() + } } // Start internal web server if configured @@ -185,10 +205,10 @@ func runB4(cmd *cobra.Command, args []string) error { metrics.RecordEvent("info", fmt.Sprintf("Shutdown initiated by signal: %v", sig)) // Perform graceful shutdown with timeout - return gracefulShutdown(&cfg, pool, httpServer, socks5Server, metrics) + return gracefulShutdown(&cfg, pool, tunEngine, httpServer, socks5Server, metrics) } -func gracefulShutdown(cfg *config.Config, pool *nfq.Pool, httpServer *http.Server, socks5Server *socks5.Server, metrics *handler.MetricsCollector) error { +func gracefulShutdown(cfg *config.Config, pool *nfq.Pool, tunEngine *b4tun.Engine, httpServer *http.Server, socks5Server *socks5.Server, metrics *handler.MetricsCollector) error { // Create shutdown context with timeout shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -230,33 +250,37 @@ func gracefulShutdown(cfg *config.Config, pool *nfq.Pool, httpServer *http.Serve log.Infof("Shutting down WebSocket connections...") b4http.Shutdown() - // Stop NFQueue pool + // Stop packet engine wg.Add(1) go func() { defer wg.Done() - log.Infof("Stopping netfilter queue pool...") metrics.NFQueueStatus = "stopping" - // Use a goroutine with timeout for pool.Stop() stopDone := make(chan struct{}) go func() { - pool.Stop() + if tunEngine != nil { + log.Infof("Stopping TUN engine...") + tunEngine.Stop() + } else { + log.Infof("Stopping netfilter queue pool...") + pool.Stop() + } close(stopDone) }() select { case <-stopDone: - log.Infof("Netfilter queue pool stopped") + log.Infof("Packet engine stopped") case <-shutdownCtx.Done(): - log.Errorf("Netfilter queue pool stop timed out") - shutdownErrors <- fmt.Errorf("NFQueue stop timeout") + log.Errorf("Packet engine stop timed out") + shutdownErrors <- fmt.Errorf("engine stop timeout") } quic.Shutdown() }() - // Clean up iptables/nftables rules - if !cfg.System.Tables.SkipSetup { + // Clean up iptables/nftables rules (only in nfqueue mode) + if tunEngine == nil && !cfg.System.Tables.SkipSetup { wg.Add(1) go func() { defer wg.Done() diff --git a/src/nfq/dns.go b/src/nfq/dns.go index 534cc6b0..6da0cae0 100644 --- a/src/nfq/dns.go +++ b/src/nfq/dns.go @@ -6,12 +6,12 @@ import ( "github.com/daniellavrushin/b4/config" "github.com/daniellavrushin/b4/dns" + "github.com/daniellavrushin/b4/engine" "github.com/daniellavrushin/b4/log" "github.com/daniellavrushin/b4/sock" - "github.com/florianl/go-nfqueue" ) -func (w *Worker) processDnsPacket(ipVersion byte, sport uint16, dport uint16, payload []byte, raw []byte, ihl int, id uint32, srcMac string) int { +func (w *Worker) processDnsPacket(ipVersion byte, sport uint16, dport uint16, payload []byte, raw []byte, ihl int, srcMac string) engine.PacketVerdict { if dport == 53 { domain, ok := dns.ParseQueryDomain(payload) @@ -21,19 +21,13 @@ func (w *Worker) processDnsPacket(ipVersion byte, sport uint16, dport uint16, pa targetIP := net.ParseIP(set.DNS.TargetDNS) if targetIP == nil { - if err := w.q.SetVerdict(id, nfqueue.NfAccept); err != nil { - log.Tracef("failed to set verdict on packet %d: %v", id, err) - } - return 0 + return engine.VerdictAccept } if ipVersion == IPv4 { targetDNS := targetIP.To4() if targetDNS == nil { - if err := w.q.SetVerdict(id, nfqueue.NfAccept); err != nil { - log.Tracef("failed to set verdict on packet %d: %v", id, err) - } - return 0 + return engine.VerdictAccept } originalDst := make(net.IP, 4) @@ -49,27 +43,18 @@ func (w *Worker) processDnsPacket(ipVersion byte, sport uint16, dport uint16, pa } else { _ = w.sock.SendIPv4(raw, targetDNS) } - if err := w.q.SetVerdict(id, nfqueue.NfDrop); err != nil { - log.Tracef("failed to set drop verdict on packet %d: %v", id, err) - } log.Infof("DNS redirect: %s -> %s (set: %s)", domain, set.DNS.TargetDNS, set.Name) - return 0 + return engine.VerdictDrop } else { cfg := w.getConfig() if !cfg.Queue.IPv6Enabled { - if err := w.q.SetVerdict(id, nfqueue.NfAccept); err != nil { - log.Tracef("failed to set verdict on packet %d: %v", id, err) - } - return 0 + return engine.VerdictAccept } targetDNS := targetIP.To16() if targetDNS == nil { - if err := w.q.SetVerdict(id, nfqueue.NfAccept); err != nil { - log.Tracef("failed to set verdict on packet %d: %v", id, err) - } - return 0 + return engine.VerdictAccept } originalDst := make(net.IP, 16) @@ -84,11 +69,8 @@ func (w *Worker) processDnsPacket(ipVersion byte, sport uint16, dport uint16, pa } else { _ = w.sock.SendIPv6(raw, targetDNS) } - if err := w.q.SetVerdict(id, nfqueue.NfDrop); err != nil { - log.Tracef("failed to set drop verdict on packet %d: %v", id, err) - } log.Infof("DNS redirect (IPv6): %s -> %s (set: %s)", domain, set.DNS.TargetDNS, set.Name) - return 0 + return engine.VerdictDrop } } } @@ -102,10 +84,7 @@ func (w *Worker) processDnsPacket(ipVersion byte, sport uint16, dport uint16, pa sock.FixUDPChecksum(raw, ihl) dns.DnsNATDelete(net.IP(raw[16:20]), dport) _ = w.sock.SendIPv4(raw, net.IP(raw[16:20])) - if err := w.q.SetVerdict(id, nfqueue.NfDrop); err != nil { - log.Tracef("failed to set drop verdict on packet %d: %v", id, err) - } - return 0 + return engine.VerdictDrop } } else { cfg := w.getConfig() @@ -115,19 +94,13 @@ func (w *Worker) processDnsPacket(ipVersion byte, sport uint16, dport uint16, pa sock.FixUDPChecksumV6(raw) dns.DnsNATDelete(net.IP(raw[24:40]), dport) _ = w.sock.SendIPv6(raw, net.IP(raw[24:40])) - if err := w.q.SetVerdict(id, nfqueue.NfDrop); err != nil { - log.Tracef("failed to set drop verdict on packet %d: %v", id, err) - } - return 0 + return engine.VerdictDrop } } } } - if err := w.q.SetVerdict(id, nfqueue.NfAccept); err != nil { - log.Tracef("failed to set verdict on packet %d: %v", id, err) - } - return 0 + return engine.VerdictAccept } func (w *Worker) sendFragmentedDNSQueryV4(cfg *config.SetConfig, raw []byte, ihl int, dst net.IP) { diff --git a/src/nfq/inc.go b/src/nfq/inc.go index e95e4445..ad3dc006 100644 --- a/src/nfq/inc.go +++ b/src/nfq/inc.go @@ -7,14 +7,14 @@ import ( "time" "github.com/daniellavrushin/b4/config" + "github.com/daniellavrushin/b4/engine" "github.com/daniellavrushin/b4/log" "github.com/daniellavrushin/b4/sock" - "github.com/florianl/go-nfqueue" ) var corruptionStrategies = []string{"badsum", "badseq", "badack", "all"} -func (w *Worker) HandleIncoming(q *nfqueue.Nfqueue, id uint32, v byte, raw []byte, ihl int, src net.IP, dstStr string, dport uint16, srcStr string, sport uint16, payload []byte) int { +func (w *Worker) HandleIncoming(v byte, raw []byte, ihl int, src net.IP, dstStr string, dport uint16, srcStr string, sport uint16, payload []byte) engine.PacketVerdict { incomingSet := connState.GetSetForIncoming(dstStr, dport, srcStr, sport) if incomingSet != nil && incomingSet.TCP.Incoming.Mode != config.ConfigOff { @@ -61,10 +61,7 @@ func (w *Worker) HandleIncoming(q *nfqueue.Nfqueue, id uint32, v byte, raw []byt } } - if err := q.SetVerdict(id, nfqueue.NfAccept); err != nil { - log.Tracef("failed to accept incoming packet %d: %v", id, err) - } - return 0 + return engine.VerdictAccept } func (w *Worker) applyCorruption(fake []byte, ihl int, strategy string) { diff --git a/src/nfq/nfq.go b/src/nfq/nfq.go index 7eedddbc..f12fb9a6 100644 --- a/src/nfq/nfq.go +++ b/src/nfq/nfq.go @@ -3,7 +3,6 @@ package nfq import ( "encoding/binary" "errors" - "fmt" "net" "os" "strings" @@ -11,15 +10,12 @@ import ( "syscall" "time" - "github.com/daniellavrushin/b4/capture" "github.com/daniellavrushin/b4/config" + "github.com/daniellavrushin/b4/engine" "github.com/daniellavrushin/b4/log" "github.com/daniellavrushin/b4/metrics" "github.com/daniellavrushin/b4/quic" - "github.com/daniellavrushin/b4/sni" "github.com/daniellavrushin/b4/sock" - "github.com/daniellavrushin/b4/stun" - "github.com/daniellavrushin/b4/utils" "github.com/florianl/go-nfqueue" ) @@ -56,10 +52,6 @@ func (w *Worker) Start() error { log.Tracef("NFQ bound pid=%d queue=%d", pid, w.qnum) defer w.wg.Done() _ = q.RegisterWithErrorFunc(w.ctx, func(a nfqueue.Attribute) int { - cfg := w.getConfig() - var set *config.SetConfig - - matcher := w.getMatcher() id := *a.PacketID if a.Mark != nil && *a.Mark == uint32(mark) { @@ -82,8 +74,6 @@ func (w *Worker) Start() error { default: } - atomic.AddUint64(&w.packetsProcessed, 1) - if a.PacketID == nil || a.Payload == nil || len(*a.Payload) == 0 { if a.PacketID != nil && q != nil { if err := q.SetVerdict(*a.PacketID, nfqueue.NfAccept); err != nil { @@ -94,483 +84,17 @@ func (w *Worker) Start() error { } raw := *a.Payload - v := raw[0] >> 4 - if v != IPv4 && v != IPv6 { - if err := q.SetVerdict(id, nfqueue.NfAccept); err != nil { - log.Tracef("failed to set verdict on packet %d: %v", id, err) - } - return 0 - } - var proto uint8 - var src, dst net.IP - var ihl int - if v == IPv4 { - if len(raw) < 20 { - if err := q.SetVerdict(id, nfqueue.NfAccept); err != nil { - log.Tracef("failed to set verdict on packet %d: %v", id, err) - } - return 0 - } - ihl = int(raw[0]&0x0f) * 4 - if len(raw) < ihl { - if err := q.SetVerdict(id, nfqueue.NfAccept); err != nil { - log.Tracef("failed to set verdict on packet %d: %v", id, err) - } - return 0 - } - - fragOffset := binary.BigEndian.Uint16(raw[6:8]) & 0x1FFF - moreFragments := (binary.BigEndian.Uint16(raw[6:8]) & 0x2000) != 0 - - if fragOffset != 0 || moreFragments { - if err := q.SetVerdict(id, nfqueue.NfAccept); err != nil { - log.Tracef("failed to accept fragmented IPv4 packet %d: %v", id, err) - } - return 0 - } - - proto = raw[9] - src = net.IP(raw[12:16]) - dst = net.IP(raw[16:20]) - - } else { - if len(raw) < IPv6HeaderLen { - if err := q.SetVerdict(id, nfqueue.NfAccept); err != nil { - log.Tracef("failed to set verdict on packet %d: %v", id, err) - } - return 0 - } - ihl = IPv6HeaderLen - nextHeader := raw[6] - offset := 40 - - for { - switch nextHeader { - case 0, 43, 60: - if len(raw) < offset+2 { - if err := q.SetVerdict(id, nfqueue.NfAccept); err != nil { - log.Tracef("failed to set verdict on packet %d: %v", id, err) - } - return 0 - } - nextHeader = raw[offset] - hdrLen := int(raw[offset+1])*8 + 8 - offset += hdrLen - case 44: - if err := q.SetVerdict(id, nfqueue.NfAccept); err != nil { - log.Tracef("failed to accept fragmented IPv6 packet %d: %v", id, err) - } - return 0 - default: - goto done - } - } - done: - proto = nextHeader - ihl = offset - src = net.IP(raw[8:24]) - dst = net.IP(raw[24:40]) - } - - if src.IsLoopback() || dst.IsLoopback() { - if err := q.SetVerdict(id, nfqueue.NfAccept); err != nil { - log.Tracef("failed to set verdict on packet %d: %v", id, err) - } - return 0 - } - srcStr := src.String() - dstStr := dst.String() - - srcMac := w.getMacByIp(srcStr) - - matched, st := matcher.MatchIPWithSource(dst, srcMac) - if matched { - set = st - } - - if proto == 6 && len(raw) >= ihl+TCPHeaderMinLen { - tcp := raw[ihl:] - if len(tcp) < TCPHeaderMinLen { - if err := q.SetVerdict(id, nfqueue.NfAccept); err != nil { - log.Tracef("failed to set verdict on packet %d: %v", id, err) - } - return 0 - } - datOff := int((tcp[12]>>4)&0x0f) * 4 - if len(tcp) < datOff { - if err := q.SetVerdict(id, nfqueue.NfAccept); err != nil { - log.Tracef("failed to set verdict on packet %d: %v", id, err) - } - return 0 - } - payload := tcp[datOff:] - sport := binary.BigEndian.Uint16(tcp[0:2]) - dport := binary.BigEndian.Uint16(tcp[2:4]) - - if cfg.IsTCPPort(sport) { - return w.HandleIncoming(q, id, v, raw, ihl, src, dstStr, dport, srcStr, sport, payload) - } - - // If IP matched but set has a port filter, verify port matches (AND logic) - if matched && !set.MatchesTCPDPort(dport) { - matched = false - set = nil - } - - // If IP matching didn't find a set, try TCP port-based set matching - if !matched && cfg.IsTCPPort(dport) { - if portMatched, portSet := matcher.MatchTCPPort(dport); portMatched { - matched = true - set = portSet - } - } - - // Packet duplication path: duplicate ALL outgoing TCP packets on configured ports - // without TLS/SNI parsing. Bypasses DPI evasion entirely. - if matched && cfg.IsTCPPort(dport) && set.TCP.Duplicate.Enabled && set.TCP.Duplicate.Count > 0 { - log.Tracef("TCP duplicate to %s:%d (%d copies, set: %s)", dstStr, dport, set.TCP.Duplicate.Count, set.Name) - - m := metrics.GetMetricsCollector() - m.RecordConnection("TCP-DUP", "", srcStr, dstStr, true, srcMac, set.Name) - m.RecordPacket(uint64(len(raw))) - - if !log.IsDiscoveryActive() { - log.Infof(",TCP-DUP,,,%s:%d,%s,%s:%d,%s", srcStr, sport, set.Name, dstStr, dport, srcMac) - } - - if err := q.SetVerdict(id, nfqueue.NfDrop); err != nil { - log.Tracef("failed to set drop verdict on packet %d: %v", id, err) - return 0 - } - - for i := 0; i < set.TCP.Duplicate.Count; i++ { - if v == IPv4 { - _ = w.sock.SendIPv4(raw, dst) - } else { - _ = w.sock.SendIPv6(raw, dst) - } - } - return 0 - } - - tcpFlags := tcp[13] - isSyn := (tcpFlags & 0x02) != 0 - isAck := (tcpFlags & 0x10) != 0 - isRst := (tcpFlags & 0x04) != 0 - if isRst && cfg.IsTCPPort(dport) { - log.Tracef("RST received from %s:%d", dstStr, dport) - } - - if isSyn && !isAck && cfg.IsTCPPort(dport) && matched && !set.TCP.Duplicate.Enabled { - log.Tracef("TCP SYN to %s:%d (set: %s)", dstStr, dport, set.Name) - - metrics := metrics.GetMetricsCollector() - metrics.RecordConnection("TCP-SYN", "", srcStr, dstStr, true, srcMac, set.Name) - - if v == IPv4 { - modsyn := raw - - if set.TCP.SynFake { - w.sendFakeSyn(set, raw, ihl, datOff) - } - - if set.Fragmentation.Strategy != config.ConfigNone && set.Faking.TCPMD5 { - w.sendFakeSynWithMD5(set, raw, ihl, dst) - } + verdict := w.ProcessPacket(raw) - _ = w.sock.SendIPv4(modsyn, dst) - } else { - if set.TCP.SynFake { - w.sendFakeSynV6(set, raw, ihl, datOff) - } - - if set.Fragmentation.Strategy != config.ConfigNone && set.Faking.TCPMD5 { - w.sendFakeSynWithMD5V6(set, raw, dst) - } - - _ = w.sock.SendIPv6(raw, dst) - } - - if err := q.SetVerdict(id, nfqueue.NfDrop); err != nil { - log.Tracef("failed to set drop verdict on packet %d: %v", id, err) - } - return 0 - } - - host := "" - matchedIP := st != nil - matchedSNI := false - ipTarget := "" - sniTarget := "" - - // Show port-matched set name in log - if !matchedIP && matched && set != nil { - ipTarget = set.Name - } - - if cfg.IsTCPPort(dport) && len(payload) > 0 { - log.Tracef("TCP payload to %s: len=%d, first5=%x", dstStr, len(payload), payload[:min(5, len(payload))]) - if len(payload) >= 5 && payload[0] == 0x16 { - log.Tracef("TLS record: type=%x ver=%x%x len=%d", payload[0], payload[1], payload[2], - int(payload[3])<<8|int(payload[4])) - } - connKey := fmt.Sprintf(connKeyFormat, srcStr, sport, dstStr, dport) - - host, _ = sni.ParseTLSClientHelloSNI(payload) - - if captureManager := capture.GetManager(cfg); captureManager != nil { - captureManager.CapturePayload(connKey, host, "tls", payload) - } - - if host != "" { - if mSNI, stSNI := matcher.MatchSNIWithSource(host, srcMac); mSNI { - // If SNI-matched set has a port filter, verify port matches (AND logic) - if stSNI.MatchesTCPDPort(dport) { - matchedSNI = true - matched = true - set = stSNI - matcher.LearnIPToDomain(dst, host, stSNI) - } - } - } - } - - if matchedIP { - ipTarget = st.Name - } - if matchedSNI { - sniTarget = set.Name + switch verdict { + case engine.VerdictDrop: + if err := q.SetVerdict(id, nfqueue.NfDrop); err != nil { + log.Tracef("failed to set drop verdict on packet %d: %v", id, err) } - - if !log.IsDiscoveryActive() { - log.Infof(",TCP,%s,%s,%s:%d,%s,%s:%d,%s", sniTarget, host, srcStr, sport, ipTarget, dstStr, dport, srcMac) - } - - { - m := metrics.GetMetricsCollector() - setName := "" - if matched { - setName = set.Name - } - m.RecordConnection("TCP", host, srcStr, dstStr, matched, srcMac, setName) - m.RecordPacket(uint64(len(raw))) - } - - if matched { - if set.TCP.Incoming.Mode != config.ConfigOff { - connKey := fmt.Sprintf(connKeyFormat, srcStr, sport, dstStr, dport) - connState.RegisterOutgoing(connKey, set) - } - - packetCopy := make([]byte, len(raw)) - copy(packetCopy, raw) - - if set.TCP.DropSACK { - if v == 4 { - packetCopy = sock.StripSACKFromTCP(packetCopy) - } else { - packetCopy = sock.StripSACKFromTCPv6(packetCopy) - } - } - - dstCopy := make(net.IP, len(dst)) - copy(dstCopy, dst) - setCopy := set - - if err := q.SetVerdict(id, nfqueue.NfDrop); err != nil { - log.Tracef("failed to set drop verdict on packet %d: %v", id, err) - return 0 - } - - w.wg.Add(1) - go func(s *config.SetConfig, pkt []byte, d net.IP) { - defer w.wg.Done() - if v == 4 { - w.dropAndInjectTCP(s, pkt, d) - } else { - w.dropAndInjectTCPv6(s, pkt, d) - } - }(setCopy, packetCopy, dstCopy) - return 0 - } - + default: if err := q.SetVerdict(id, nfqueue.NfAccept); err != nil { log.Tracef("failed to set verdict on packet %d: %v", id, err) } - return 0 - } - - if proto == 17 && len(raw) >= ihl+8 { - udp := raw[ihl:] - if len(udp) < 8 { - if err := q.SetVerdict(id, nfqueue.NfAccept); err != nil { - log.Tracef("failed to set verdict on packet %d: %v", id, err) - } - return 0 - } - - payload := udp[8:] - sport := binary.BigEndian.Uint16(udp[0:2]) - dport := binary.BigEndian.Uint16(udp[2:4]) - connKey := fmt.Sprintf(connKeyFormat, srcStr, sport, dstStr, dport) - - if sport == 53 || dport == 53 { - return w.processDnsPacket(v, sport, dport, payload, raw, ihl, id, srcMac) - } - - if utils.IsPrivateIP(dst) { - if err := q.SetVerdict(id, nfqueue.NfAccept); err != nil { - log.Tracef("failed to set verdict on packet %d: %v", id, err) - } - return 0 - } - - matchedIP := st != nil - matchedQUIC := false - isSTUN := false - host := "" - ipTarget := "" - sniTarget := "" - - // If IP matched but set has a port filter, verify port matches (AND logic) - if matchedIP && !st.MatchesUDPDPort(dport) { - matchedIP = false - matched = false - set = nil - } - - if matchedIP { - ipTarget = st.Name - } - - if !matchedIP { - if mLearned, learnedSet, learnedDomain := matcher.MatchLearnedIPWithSource(dst, srcMac); mLearned { - // If learned IP set has a port filter, verify port matches (AND logic) - if learnedSet.MatchesUDPDPort(dport) { - matchedIP = true - matched = true - set = learnedSet - host = learnedDomain - sniTarget = learnedSet.Name - ipTarget = learnedSet.Name - } - } - } - - // If IP matching didn't find a set, try UDP port-based set matching - matchedPort := false - if !matched { - if portMatched, portSet := matcher.MatchUDPPort(dport); portMatched { - matchedPort = true - matched = true - set = portSet - ipTarget = portSet.Name - } - } - - isSTUN = stun.IsSTUNMessage(payload) - - if host == "" { - if h, ok := sni.ParseQUICClientHelloSNI(payload); ok { - host = h - } - } - - if host != "" { - if mSNI, sniSet := matcher.MatchSNIWithSource(host, srcMac); mSNI { - // If SNI-matched set has a port filter, verify port matches (AND logic) - if sniSet.MatchesUDPDPort(dport) { - matchedQUIC = true - set = sniSet - sniTarget = sniSet.Name - matcher.LearnIPToDomain(dst, host, sniSet) - } - } - } - - if !matchedQUIC && (matchedIP || matchedPort) && set.UDP.FilterQUIC == "all" { - if quic.IsInitial(payload) { - matchedQUIC = true - } - } - - if captureManager := capture.GetManager(cfg); captureManager != nil { - captureManager.CapturePayload(connKey, host, "quic", payload) - } - - shouldHandle := (matchedIP || matchedQUIC || matchedPort) && !(isSTUN && set.UDP.FilterSTUN) - - matched = shouldHandle - - if !log.IsDiscoveryActive() { - log.Infof(",UDP,%s,%s,%s:%d,%s,%s:%d,%s", sniTarget, host, srcStr, sport, ipTarget, dstStr, dport, srcMac) - } - - if isSTUN && set != nil && set.UDP.FilterSTUN { - if err := q.SetVerdict(id, nfqueue.NfAccept); err != nil { - log.Tracef("failed to set verdict on packet %d: %v", id, err) - } - return 0 - } - - if !shouldHandle { - m := metrics.GetMetricsCollector() - m.RecordConnection("UDP", host, srcStr, dstStr, false, srcMac, "") - m.RecordPacket(uint64(len(raw))) - if err := q.SetVerdict(id, nfqueue.NfAccept); err != nil { - log.Tracef("failed to set verdict on packet %d: %v", id, err) - } - return 0 - } - - metrics := metrics.GetMetricsCollector() - setName := "" - if matched { - setName = set.Name - } - metrics.RecordConnection("UDP", host, srcStr, dstStr, matched, srcMac, setName) - metrics.RecordPacket(uint64(len(raw))) - - switch set.UDP.Mode { - case "drop": - if err := q.SetVerdict(id, nfqueue.NfDrop); err != nil { - log.Tracef("failed to set drop verdict on packet %d: %v", id, err) - } - return 0 - - case "fake": - packetCopy := make([]byte, len(raw)) - copy(packetCopy, raw) - dstCopy := make(net.IP, len(dst)) - copy(dstCopy, dst) - setCopy := set - - if err := q.SetVerdict(id, nfqueue.NfDrop); err != nil { - log.Tracef("failed to set drop verdict on UDP packet %d: %v", id, err) - return 0 - } - - w.wg.Add(1) - go func(s *config.SetConfig, pkt []byte, d net.IP) { - defer w.wg.Done() - if v == IPv4 { - w.dropAndInjectQUIC(s, pkt, d) - } else { - w.dropAndInjectQUICV6(s, pkt, d) - } - }(setCopy, packetCopy, dstCopy) - return 0 - - default: - if err := q.SetVerdict(id, nfqueue.NfAccept); err != nil { - log.Tracef("failed to set verdict on packet %d: %v", id, err) - } - return 0 - } - } - - if err := q.SetVerdict(id, nfqueue.NfAccept); err != nil { - log.Tracef("failed to set verdict on packet %d: %v", id, err) } return 0 }, func(e error) int { @@ -955,6 +479,19 @@ func (w *Worker) getMacByIp(ip string) string { return "" } +// InitSender initializes the raw socket sender for the worker. +// This is used by the TUN backend which doesn't open an NFQUEUE. +func (w *Worker) InitSender() error { + cfg := w.getConfig() + mark := cfg.Queue.Mark + s, err := sock.NewSenderWithMark(int(mark)) + if err != nil { + return err + } + w.sock = s + return nil +} + func (w *Worker) Stop() { if w.cancel != nil { w.cancel() diff --git a/src/nfq/process.go b/src/nfq/process.go new file mode 100644 index 00000000..7854dc5c --- /dev/null +++ b/src/nfq/process.go @@ -0,0 +1,444 @@ +package nfq + +import ( + "encoding/binary" + "fmt" + "net" + "sync/atomic" + + "github.com/daniellavrushin/b4/capture" + "github.com/daniellavrushin/b4/config" + "github.com/daniellavrushin/b4/engine" + "github.com/daniellavrushin/b4/log" + "github.com/daniellavrushin/b4/metrics" + "github.com/daniellavrushin/b4/quic" + "github.com/daniellavrushin/b4/sni" + "github.com/daniellavrushin/b4/sock" + "github.com/daniellavrushin/b4/stun" + "github.com/daniellavrushin/b4/utils" +) + +// ProcessPacket is the engine-agnostic packet processing logic. +// It takes a raw IP packet and returns a verdict indicating whether the +// original packet should be accepted (forwarded unchanged) or dropped +// (modified copies already sent via raw socket). +func (w *Worker) ProcessPacket(raw []byte) engine.PacketVerdict { + if len(raw) == 0 { + return engine.VerdictAccept + } + + cfg := w.getConfig() + var set *config.SetConfig + + matcher := w.getMatcher() + + atomic.AddUint64(&w.packetsProcessed, 1) + + v := raw[0] >> 4 + if v != IPv4 && v != IPv6 { + return engine.VerdictAccept + } + var proto uint8 + var src, dst net.IP + var ihl int + if v == IPv4 { + if len(raw) < 20 { + return engine.VerdictAccept + } + ihl = int(raw[0]&0x0f) * 4 + if len(raw) < ihl { + return engine.VerdictAccept + } + + fragOffset := binary.BigEndian.Uint16(raw[6:8]) & 0x1FFF + moreFragments := (binary.BigEndian.Uint16(raw[6:8]) & 0x2000) != 0 + + if fragOffset != 0 || moreFragments { + return engine.VerdictAccept + } + + proto = raw[9] + src = net.IP(raw[12:16]) + dst = net.IP(raw[16:20]) + + } else { + if len(raw) < IPv6HeaderLen { + return engine.VerdictAccept + } + ihl = IPv6HeaderLen + nextHeader := raw[6] + offset := 40 + + for { + switch nextHeader { + case 0, 43, 60: + if len(raw) < offset+2 { + return engine.VerdictAccept + } + nextHeader = raw[offset] + hdrLen := int(raw[offset+1])*8 + 8 + offset += hdrLen + case 44: + return engine.VerdictAccept + default: + goto done + } + } + done: + proto = nextHeader + ihl = offset + src = net.IP(raw[8:24]) + dst = net.IP(raw[24:40]) + } + + if src.IsLoopback() || dst.IsLoopback() { + return engine.VerdictAccept + } + srcStr := src.String() + dstStr := dst.String() + + srcMac := w.getMacByIp(srcStr) + + matched, st := matcher.MatchIPWithSource(dst, srcMac) + if matched { + set = st + } + + if proto == 6 && len(raw) >= ihl+TCPHeaderMinLen { + tcp := raw[ihl:] + if len(tcp) < TCPHeaderMinLen { + return engine.VerdictAccept + } + datOff := int((tcp[12]>>4)&0x0f) * 4 + if len(tcp) < datOff { + return engine.VerdictAccept + } + payload := tcp[datOff:] + sport := binary.BigEndian.Uint16(tcp[0:2]) + dport := binary.BigEndian.Uint16(tcp[2:4]) + + if cfg.IsTCPPort(sport) { + return w.HandleIncoming(v, raw, ihl, src, dstStr, dport, srcStr, sport, payload) + } + + // If IP matched but set has a port filter, verify port matches (AND logic) + if matched && !set.MatchesTCPDPort(dport) { + matched = false + set = nil + } + + // If IP matching didn't find a set, try TCP port-based set matching + if !matched && cfg.IsTCPPort(dport) { + if portMatched, portSet := matcher.MatchTCPPort(dport); portMatched { + matched = true + set = portSet + } + } + + // Packet duplication path: duplicate ALL outgoing TCP packets on configured ports + // without TLS/SNI parsing. Bypasses DPI evasion entirely. + if matched && cfg.IsTCPPort(dport) && set.TCP.Duplicate.Enabled && set.TCP.Duplicate.Count > 0 { + log.Tracef("TCP duplicate to %s:%d (%d copies, set: %s)", dstStr, dport, set.TCP.Duplicate.Count, set.Name) + + m := metrics.GetMetricsCollector() + m.RecordConnection("TCP-DUP", "", srcStr, dstStr, true, srcMac, set.Name) + m.RecordPacket(uint64(len(raw))) + + if !log.IsDiscoveryActive() { + log.Infof(",TCP-DUP,,,%s:%d,%s,%s:%d,%s", srcStr, sport, set.Name, dstStr, dport, srcMac) + } + + for i := 0; i < set.TCP.Duplicate.Count; i++ { + if v == IPv4 { + _ = w.sock.SendIPv4(raw, dst) + } else { + _ = w.sock.SendIPv6(raw, dst) + } + } + return engine.VerdictDrop + } + + tcpFlags := tcp[13] + isSyn := (tcpFlags & 0x02) != 0 + isAck := (tcpFlags & 0x10) != 0 + isRst := (tcpFlags & 0x04) != 0 + if isRst && cfg.IsTCPPort(dport) { + log.Tracef("RST received from %s:%d", dstStr, dport) + } + + if isSyn && !isAck && cfg.IsTCPPort(dport) && matched && !set.TCP.Duplicate.Enabled { + log.Tracef("TCP SYN to %s:%d (set: %s)", dstStr, dport, set.Name) + + m := metrics.GetMetricsCollector() + m.RecordConnection("TCP-SYN", "", srcStr, dstStr, true, srcMac, set.Name) + + if v == IPv4 { + modsyn := raw + + if set.TCP.SynFake { + w.sendFakeSyn(set, raw, ihl, datOff) + } + + if set.Fragmentation.Strategy != config.ConfigNone && set.Faking.TCPMD5 { + w.sendFakeSynWithMD5(set, raw, ihl, dst) + } + + _ = w.sock.SendIPv4(modsyn, dst) + } else { + if set.TCP.SynFake { + w.sendFakeSynV6(set, raw, ihl, datOff) + } + + if set.Fragmentation.Strategy != config.ConfigNone && set.Faking.TCPMD5 { + w.sendFakeSynWithMD5V6(set, raw, dst) + } + + _ = w.sock.SendIPv6(raw, dst) + } + + return engine.VerdictDrop + } + + host := "" + matchedIP := st != nil + matchedSNI := false + ipTarget := "" + sniTarget := "" + + // Show port-matched set name in log + if !matchedIP && matched && set != nil { + ipTarget = set.Name + } + + if cfg.IsTCPPort(dport) && len(payload) > 0 { + log.Tracef("TCP payload to %s: len=%d, first5=%x", dstStr, len(payload), payload[:min(5, len(payload))]) + if len(payload) >= 5 && payload[0] == 0x16 { + log.Tracef("TLS record: type=%x ver=%x%x len=%d", payload[0], payload[1], payload[2], + int(payload[3])<<8|int(payload[4])) + } + connKey := fmt.Sprintf(connKeyFormat, srcStr, sport, dstStr, dport) + + host, _ = sni.ParseTLSClientHelloSNI(payload) + + if captureManager := capture.GetManager(cfg); captureManager != nil { + captureManager.CapturePayload(connKey, host, "tls", payload) + } + + if host != "" { + if mSNI, stSNI := matcher.MatchSNIWithSource(host, srcMac); mSNI { + // If SNI-matched set has a port filter, verify port matches (AND logic) + if stSNI.MatchesTCPDPort(dport) { + matchedSNI = true + matched = true + set = stSNI + matcher.LearnIPToDomain(dst, host, stSNI) + } + } + } + } + + if matchedIP { + ipTarget = st.Name + } + if matchedSNI { + sniTarget = set.Name + } + + if !log.IsDiscoveryActive() { + log.Infof(",TCP,%s,%s,%s:%d,%s,%s:%d,%s", sniTarget, host, srcStr, sport, ipTarget, dstStr, dport, srcMac) + } + + { + m := metrics.GetMetricsCollector() + setName := "" + if matched { + setName = set.Name + } + m.RecordConnection("TCP", host, srcStr, dstStr, matched, srcMac, setName) + m.RecordPacket(uint64(len(raw))) + } + + if matched { + if set.TCP.Incoming.Mode != config.ConfigOff { + connKey := fmt.Sprintf(connKeyFormat, srcStr, sport, dstStr, dport) + connState.RegisterOutgoing(connKey, set) + } + + packetCopy := make([]byte, len(raw)) + copy(packetCopy, raw) + + if set.TCP.DropSACK { + if v == 4 { + packetCopy = sock.StripSACKFromTCP(packetCopy) + } else { + packetCopy = sock.StripSACKFromTCPv6(packetCopy) + } + } + + dstCopy := make(net.IP, len(dst)) + copy(dstCopy, dst) + setCopy := set + + w.wg.Add(1) + go func(s *config.SetConfig, pkt []byte, d net.IP) { + defer w.wg.Done() + if v == 4 { + w.dropAndInjectTCP(s, pkt, d) + } else { + w.dropAndInjectTCPv6(s, pkt, d) + } + }(setCopy, packetCopy, dstCopy) + return engine.VerdictDrop + } + + return engine.VerdictAccept + } + + if proto == 17 && len(raw) >= ihl+8 { + udp := raw[ihl:] + if len(udp) < 8 { + return engine.VerdictAccept + } + + payload := udp[8:] + sport := binary.BigEndian.Uint16(udp[0:2]) + dport := binary.BigEndian.Uint16(udp[2:4]) + connKey := fmt.Sprintf(connKeyFormat, srcStr, sport, dstStr, dport) + + if sport == 53 || dport == 53 { + return w.processDnsPacket(v, sport, dport, payload, raw, ihl, srcMac) + } + + if utils.IsPrivateIP(dst) { + return engine.VerdictAccept + } + + matchedIP := st != nil + matchedQUIC := false + isSTUN := false + host := "" + ipTarget := "" + sniTarget := "" + + // If IP matched but set has a port filter, verify port matches (AND logic) + if matchedIP && !st.MatchesUDPDPort(dport) { + matchedIP = false + matched = false + set = nil + } + + if matchedIP { + ipTarget = st.Name + } + + if !matchedIP { + if mLearned, learnedSet, learnedDomain := matcher.MatchLearnedIPWithSource(dst, srcMac); mLearned { + // If learned IP set has a port filter, verify port matches (AND logic) + if learnedSet.MatchesUDPDPort(dport) { + matchedIP = true + matched = true + set = learnedSet + host = learnedDomain + sniTarget = learnedSet.Name + ipTarget = learnedSet.Name + } + } + } + + // If IP matching didn't find a set, try UDP port-based set matching + matchedPort := false + if !matched { + if portMatched, portSet := matcher.MatchUDPPort(dport); portMatched { + matchedPort = true + matched = true + set = portSet + ipTarget = portSet.Name + } + } + + isSTUN = stun.IsSTUNMessage(payload) + + if host == "" { + if h, ok := sni.ParseQUICClientHelloSNI(payload); ok { + host = h + } + } + + if host != "" { + if mSNI, sniSet := matcher.MatchSNIWithSource(host, srcMac); mSNI { + // If SNI-matched set has a port filter, verify port matches (AND logic) + if sniSet.MatchesUDPDPort(dport) { + matchedQUIC = true + set = sniSet + sniTarget = sniSet.Name + matcher.LearnIPToDomain(dst, host, sniSet) + } + } + } + + if !matchedQUIC && (matchedIP || matchedPort) && set.UDP.FilterQUIC == "all" { + if quic.IsInitial(payload) { + matchedQUIC = true + } + } + + if captureManager := capture.GetManager(cfg); captureManager != nil { + captureManager.CapturePayload(connKey, host, "quic", payload) + } + + shouldHandle := (matchedIP || matchedQUIC || matchedPort) && !(isSTUN && set.UDP.FilterSTUN) + + matched = shouldHandle + + if !log.IsDiscoveryActive() { + log.Infof(",UDP,%s,%s,%s:%d,%s,%s:%d,%s", sniTarget, host, srcStr, sport, ipTarget, dstStr, dport, srcMac) + } + + if isSTUN && set != nil && set.UDP.FilterSTUN { + return engine.VerdictAccept + } + + if !shouldHandle { + m := metrics.GetMetricsCollector() + m.RecordConnection("UDP", host, srcStr, dstStr, false, srcMac, "") + m.RecordPacket(uint64(len(raw))) + return engine.VerdictAccept + } + + m := metrics.GetMetricsCollector() + setName := "" + if matched { + setName = set.Name + } + m.RecordConnection("UDP", host, srcStr, dstStr, matched, srcMac, setName) + m.RecordPacket(uint64(len(raw))) + + switch set.UDP.Mode { + case "drop": + return engine.VerdictDrop + + case "fake": + packetCopy := make([]byte, len(raw)) + copy(packetCopy, raw) + dstCopy := make(net.IP, len(dst)) + copy(dstCopy, dst) + setCopy := set + + w.wg.Add(1) + go func(s *config.SetConfig, pkt []byte, d net.IP) { + defer w.wg.Done() + if v == IPv4 { + w.dropAndInjectQUIC(s, pkt, d) + } else { + w.dropAndInjectQUICV6(s, pkt, d) + } + }(setCopy, packetCopy, dstCopy) + return engine.VerdictDrop + + default: + return engine.VerdictAccept + } + } + + return engine.VerdictAccept +} diff --git a/src/tun/device.go b/src/tun/device.go new file mode 100644 index 00000000..8ad604eb --- /dev/null +++ b/src/tun/device.go @@ -0,0 +1,57 @@ +package tun + +import ( + "fmt" + "os" + "unsafe" + + "golang.org/x/sys/unix" +) + +const ( + tunDevice = "/dev/net/tun" + ifnamsiz = 16 + iffTun = 0x0001 + iffNoPi = 0x1000 + tunsetiff = 0x400454ca +) + +// ifreqFlags matches the struct ifreq for ioctl TUNSETIFF. +type ifreqFlags struct { + Name [ifnamsiz]byte + Flags uint16 + _ [22]byte // padding to match sizeof(struct ifreq) +} + +// openTUN creates a TUN device with the given name and returns the file descriptor. +func openTUN(name string) (*os.File, string, error) { + fd, err := unix.Open(tunDevice, unix.O_RDWR|unix.O_CLOEXEC, 0) + if err != nil { + return nil, "", fmt.Errorf("open %s: %w", tunDevice, err) + } + + var ifr ifreqFlags + copy(ifr.Name[:], name) + ifr.Flags = iffTun | iffNoPi + + _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(tunsetiff), uintptr(unsafe.Pointer(&ifr))) + if errno != 0 { + unix.Close(fd) + return nil, "", fmt.Errorf("ioctl TUNSETIFF: %w", errno) + } + + // Extract actual device name (kernel may have assigned tunN if name was empty) + actualName := "" + for i, b := range ifr.Name { + if b == 0 { + actualName = string(ifr.Name[:i]) + break + } + } + if actualName == "" { + actualName = name + } + + file := os.NewFile(uintptr(fd), tunDevice) + return file, actualName, nil +} diff --git a/src/tun/route.go b/src/tun/route.go new file mode 100644 index 00000000..2618346a --- /dev/null +++ b/src/tun/route.go @@ -0,0 +1,154 @@ +package tun + +import ( + "fmt" + "os/exec" + "strings" + + "github.com/daniellavrushin/b4/log" +) + +// routeManager handles setting up and tearing down routing rules for TUN mode. +type routeManager struct { + tunName string + tunAddr string // e.g. "10.255.0.1/30" + outIface string // e.g. "eth0" + outGateway string // e.g. "192.168.1.1" + mark uint + routeTable int + savedDefault string // original default route for restoration +} + +func newRouteManager(tunName, tunAddr, outIface, outGateway string, mark uint, routeTable int) *routeManager { + return &routeManager{ + tunName: tunName, + tunAddr: tunAddr, + outIface: outIface, + outGateway: outGateway, + mark: mark, + routeTable: routeTable, + } +} + +// setup configures routing so traffic flows through the TUN device, +// while b4's own outbound packets (marked with fwmark) bypass the TUN. +func (r *routeManager) setup() error { + // Save current default route for restoration + out, err := run("ip", "route", "show", "default") + if err != nil { + return fmt.Errorf("failed to read current default route: %w", err) + } + r.savedDefault = strings.TrimSpace(out) + log.Infof("TUN: saved default route: %s", r.savedDefault) + + // Auto-detect gateway if not specified + if r.outGateway == "" { + gw := extractGateway(r.savedDefault) + if gw == "" { + return fmt.Errorf("could not auto-detect gateway from default route: %s", r.savedDefault) + } + r.outGateway = gw + log.Infof("TUN: auto-detected gateway: %s", r.outGateway) + } + + // 1. Configure TUN device + if _, err := run("ip", "addr", "add", r.tunAddr, "dev", r.tunName); err != nil { + return fmt.Errorf("ip addr add: %w", err) + } + if _, err := run("ip", "link", "set", r.tunName, "up"); err != nil { + return fmt.Errorf("ip link set up: %w", err) + } + if _, err := run("ip", "link", "set", r.tunName, "mtu", "1500"); err != nil { + log.Warnf("TUN: failed to set MTU: %v", err) + } + + // 2. Policy routing: marked packets use a separate table that routes via the real interface + markStr := fmt.Sprintf("0x%x", r.mark) + tableStr := fmt.Sprintf("%d", r.routeTable) + + // Clean up stale rules/routes from a previous run that didn't shut down cleanly + run("ip", "rule", "del", "fwmark", markStr, "lookup", tableStr) + run("ip", "route", "flush", "table", tableStr) + + if _, err := run("ip", "rule", "add", "fwmark", markStr, "lookup", tableStr, "priority", "100"); err != nil { + return fmt.Errorf("ip rule add: %w", err) + } + if _, err := run("ip", "route", "add", "default", "via", r.outGateway, "dev", r.outIface, "table", tableStr); err != nil { + return fmt.Errorf("ip route add table: %w", err) + } + + // 3. Replace default route to go through TUN, preserving the original source IP + // so the kernel doesn't assign the TUN address (e.g. 10.255.0.1) as source + srcIP := extractField(r.savedDefault, "src") + if srcIP != "" { + if _, err := run("ip", "route", "replace", "default", "dev", r.tunName, "src", srcIP); err != nil { + return fmt.Errorf("ip route replace default: %w", err) + } + } else { + if _, err := run("ip", "route", "replace", "default", "dev", r.tunName); err != nil { + return fmt.Errorf("ip route replace default: %w", err) + } + } + + log.Infof("TUN: routing configured (tun=%s, out=%s via %s, mark=0x%x, table=%d)", + r.tunName, r.outIface, r.outGateway, r.mark, r.routeTable) + + return nil +} + +// teardown restores the original routing configuration. +func (r *routeManager) teardown() { + markStr := fmt.Sprintf("0x%x", r.mark) + tableStr := fmt.Sprintf("%d", r.routeTable) + + // Restore original default route + if r.savedDefault != "" { + args := append([]string{"ip", "route", "replace"}, strings.Fields(r.savedDefault)...) + if _, err := run(args...); err != nil { + log.Errorf("TUN: failed to restore default route: %v", err) + } else { + log.Infof("TUN: restored default route: %s", r.savedDefault) + } + } + + // Clean up policy routing + if _, err := run("ip", "rule", "del", "fwmark", markStr, "lookup", tableStr); err != nil { + log.Warnf("TUN: failed to delete ip rule: %v", err) + } + if _, err := run("ip", "route", "flush", "table", tableStr); err != nil { + log.Warnf("TUN: failed to flush route table %s: %v", tableStr, err) + } + + // Remove TUN device + if _, err := run("ip", "link", "del", r.tunName); err != nil { + log.Warnf("TUN: failed to delete %s: %v", r.tunName, err) + } + + log.Infof("TUN: routing teardown complete") +} + +// extractField parses a route line for a keyword and returns the next token. +// e.g. extractField("default via 1.2.3.4 dev eth0 src 10.0.0.1", "via") => "1.2.3.4" +func extractField(routeLine, keyword string) string { + parts := strings.Fields(routeLine) + for i, p := range parts { + if p == keyword && i+1 < len(parts) { + return parts[i+1] + } + } + return "" +} + +// extractGateway parses "default via X.X.X.X dev Y" to get the gateway IP. +func extractGateway(routeLine string) string { + return extractField(routeLine, "via") +} + +func run(args ...string) (string, error) { + cmd := exec.Command(args[0], args[1:]...) + out, err := cmd.CombinedOutput() + if err != nil { + return string(out), fmt.Errorf("%s: %w (%s)", strings.Join(args, " "), err, strings.TrimSpace(string(out))) + } + return string(out), nil +} diff --git a/src/tun/tun.go b/src/tun/tun.go new file mode 100644 index 00000000..33f23871 --- /dev/null +++ b/src/tun/tun.go @@ -0,0 +1,184 @@ +package tun + +import ( + "os" + "sync" + + "github.com/daniellavrushin/b4/config" + "github.com/daniellavrushin/b4/engine" + "github.com/daniellavrushin/b4/log" + "github.com/daniellavrushin/b4/nfq" + "github.com/daniellavrushin/b4/sock" +) + +const tunBufSize = 65536 + +// Engine implements the TUN-based packet processing backend. +// It creates a TUN device, routes traffic through it, and processes +// packets using the same Worker.ProcessPacket logic as NFQUEUE mode. +type Engine struct { + cfg *config.Config + pool *nfq.Pool + tunFile *os.File + tunName string + routes *routeManager + sender *sock.Sender + wg sync.WaitGroup + quit chan struct{} +} + +// NewEngine creates a new TUN engine. It reuses the nfq.Pool for packet +// processing workers and the same matching/evasion logic. +func NewEngine(cfg *config.Config, pool *nfq.Pool) *Engine { + return &Engine{ + cfg: cfg, + pool: pool, + quit: make(chan struct{}), + } +} + +// Start opens the TUN device, sets up routing, and starts read loops. +func (e *Engine) Start() error { + tunCfg := &e.cfg.Queue.TUN + + // Initialize sender for each worker (without opening NFQUEUE) + for _, w := range e.pool.Workers { + if err := w.InitSender(); err != nil { + return err + } + } + + // Clean up stale TUN device from a previous unclean shutdown + run("ip", "link", "del", tunCfg.DeviceName) + + // Open TUN device + f, name, err := openTUN(tunCfg.DeviceName) + if err != nil { + return err + } + e.tunFile = f + e.tunName = name + log.Infof("TUN: opened device %s", name) + + // Create a sender for forwarding unmatched packets + sender, err := sock.NewSenderWithMark(int(e.cfg.Queue.Mark)) + if err != nil { + e.tunFile.Close() + return err + } + e.sender = sender + + // Setup routing + e.routes = newRouteManager( + name, + tunCfg.Address, + tunCfg.OutInterface, + tunCfg.OutGateway, + e.cfg.Queue.Mark, + tunCfg.RouteTable, + ) + if err := e.routes.setup(); err != nil { + e.sender.Close() + e.tunFile.Close() + return err + } + + // Start reader goroutines (one per worker thread for parallelism) + threads := e.cfg.Queue.Threads + if threads < 1 { + threads = 1 + } + for i := 0; i < threads; i++ { + e.wg.Add(1) + go e.readLoop(i) + } + + log.Infof("TUN: started %d reader threads", threads) + return nil +} + +// readLoop reads packets from the TUN device and processes them. +func (e *Engine) readLoop(workerIdx int) { + defer e.wg.Done() + + worker := e.pool.Workers[workerIdx%len(e.pool.Workers)] + buf := make([]byte, tunBufSize) + + for { + select { + case <-e.quit: + return + default: + } + + n, err := e.tunFile.Read(buf) + if err != nil { + select { + case <-e.quit: + return + default: + } + log.Errorf("TUN: read error: %v", err) + continue + } + + if n == 0 { + continue + } + + // Make a copy since buf is reused + raw := make([]byte, n) + copy(raw, buf[:n]) + + verdict := worker.ProcessPacket(raw) + + if verdict == engine.VerdictAccept { + // Forward the packet unchanged via raw socket (marked, bypasses TUN) + e.forwardPacket(raw) + } + // VerdictDrop: ProcessPacket already sent modified copies via raw socket + } +} + +// forwardPacket sends an unmodified packet out via the real interface. +func (e *Engine) forwardPacket(raw []byte) { + if len(raw) == 0 { + return + } + v := raw[0] >> 4 + switch v { + case 4: + if len(raw) < 20 { + return + } + dst := raw[16:20] + _ = e.sender.SendIPv4(raw, dst) + case 6: + if len(raw) < 40 { + return + } + dst := raw[24:40] + _ = e.sender.SendIPv6(raw, dst) + } +} + +// Stop tears down routing and closes the TUN device. +func (e *Engine) Stop() { + close(e.quit) + + // Close TUN fd to unblock readers + if e.tunFile != nil { + e.tunFile.Close() + } + + e.wg.Wait() + + if e.routes != nil { + e.routes.teardown() + } + if e.sender != nil { + e.sender.Close() + } + + log.Infof("TUN: engine stopped") +}