diff --git a/pkg/cmd/tunnel/cmd.go b/pkg/cmd/tunnel/cmd.go index c96dbac2..c216baca 100644 --- a/pkg/cmd/tunnel/cmd.go +++ b/pkg/cmd/tunnel/cmd.go @@ -225,6 +225,7 @@ func init() { tunnelRunCmd.Flags().BoolVar(&autoCreate, "auto", false, "Automatically create TunnelNode if it doesn't exist.") tunnelRunCmd.Flags().StringVar(&healthAddr, "health-addr", ":8080", "Listen address for health endpoint (default: :8080).") tunnelRunCmd.Flags().StringVar(&metricsAddr, "metrics-addr", ":8081", "Listen address for metrics endpoint (default: :8081).") + tunnelRunCmd.Flags().StringVar(&overridePort, "port", "", "Override destination port for forwarded packets (empty string preserves original port).") tunnelCmd.AddCommand(createCmd) tunnelCmd.AddCommand(getCmd) diff --git a/pkg/cmd/tunnel/run.go b/pkg/cmd/tunnel/run.go index 2b6f55ca..f47bcec6 100644 --- a/pkg/cmd/tunnel/run.go +++ b/pkg/cmd/tunnel/run.go @@ -75,6 +75,7 @@ var ( autoCreate bool healthAddr string metricsAddr string + overridePort string preserveDefaultGwDsts []netip.Prefix ) @@ -264,6 +265,7 @@ func (t *tunnelNodeReconciler) run(ctx context.Context, tn *corev1alpha.TunnelNo tunnel.WithMode(tunnelMode), tunnel.WithPreserveDefaultGatewayDestinations(preserveDefaultGwDsts), tunnel.WithSocksListenAddr(socksListenAddr), + tunnel.WithOverridePort(overridePort), ) if err != nil { return fmt.Errorf("unable to build client router: %w", err) diff --git a/pkg/netstack/icx_network.go b/pkg/netstack/icx_network.go index 295ba2b5..e78f365c 100644 --- a/pkg/netstack/icx_network.go +++ b/pkg/netstack/icx_network.go @@ -395,10 +395,10 @@ func (net *ICXNetwork) ForwardTo(ctx context.Context, upstream network.Network) return fmt.Errorf("failed to enable promiscuous mode: %v", tcpipErr) } - tcpForwarder := TCPForwarder(ctx, net.stack, upstream) + tcpForwarder := TCPForwarder(ctx, net.stack, upstream, "") net.stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder) - udpForwarder := UDPForwarder(ctx, net.stack, upstream) + udpForwarder := UDPForwarder(ctx, net.stack, upstream, "") net.stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder) return nil diff --git a/pkg/netstack/tcp_forwarder.go b/pkg/netstack/tcp_forwarder.go index 8ad125ff..3eea08d3 100644 --- a/pkg/netstack/tcp_forwarder.go +++ b/pkg/netstack/tcp_forwarder.go @@ -19,12 +19,13 @@ import ( type ProtocolHandler func(stack.TransportEndpointID, *stack.PacketBuffer) bool // TCPForwarder forwards TCP connections to an upstream network. -func TCPForwarder(ctx context.Context, ipstack *stack.Stack, upstream network.Network) ProtocolHandler { +// If overridePort is not empty, all connections will be forwarded to that port instead of the original destination port. +func TCPForwarder(ctx context.Context, ipstack *stack.Stack, upstream network.Network, overridePort string) ProtocolHandler { tcpForwarder := tcp.NewForwarder( ipstack, 0, /* rcvWnd (0 - default) */ 65535, /* maxInFlight */ - tcpHandler(ctx, upstream), + tcpHandler(ctx, upstream, overridePort), ) return tcpForwarder.HandlePacket @@ -51,15 +52,24 @@ func Unmap4in6(addr netip.Addr) netip.Addr { return v4addr } -func tcpHandler(ctx context.Context, upstream network.Network) func(req *tcp.ForwarderRequest) { +func tcpHandler(ctx context.Context, upstream network.Network, overridePort string) func(req *tcp.ForwarderRequest) { return func(req *tcp.ForwarderRequest) { reqDetails := req.ID() srcAddrPort := netip.AddrPortFrom(addrFromNetstackIP(reqDetails.RemoteAddress), reqDetails.RemotePort) - dstAddrPort := netip.AddrPortFrom( - Unmap4in6(addrFromNetstackIP(reqDetails.LocalAddress)), - reqDetails.LocalPort, - ) + + // Extract IPv4 address from IPv6-mapped address + dstAddr := Unmap4in6(addrFromNetstackIP(reqDetails.LocalAddress)) + dstPort := reqDetails.LocalPort + + // Override port if specified + if overridePort != "" { + if port, err := netip.ParseAddrPort("127.0.0.1:" + overridePort); err == nil { + dstPort = port.Port() + } + } + + dstAddrPort := netip.AddrPortFrom(dstAddr, dstPort) logger := slog.With( slog.String("src", srcAddrPort.String()), diff --git a/pkg/netstack/tcp_forwarder_test.go b/pkg/netstack/tcp_forwarder_test.go index 28143290..bb0cadff 100644 --- a/pkg/netstack/tcp_forwarder_test.go +++ b/pkg/netstack/tcp_forwarder_test.go @@ -50,7 +50,7 @@ func TestTCPForwarder(t *testing.T) { }() // Setup the server stack to forward TCP packets to the hosts loopback interface. - serverStack.SetTransportProtocolHandler(tcp.ProtocolNumber, netstack.TCPForwarder(ctx, serverStack.Stack, network.Loopback())) + serverStack.SetTransportProtocolHandler(tcp.ProtocolNumber, netstack.TCPForwarder(ctx, serverStack.Stack, network.Loopback(), "")) // Generate a large blob of random blob to send to the client. blob := make([]byte, 1<<20) diff --git a/pkg/netstack/tun_device.go b/pkg/netstack/tun_device.go index 83b8bb2d..5f7c01d7 100644 --- a/pkg/netstack/tun_device.go +++ b/pkg/netstack/tun_device.go @@ -295,6 +295,11 @@ func (tun *TunDevice) LocalAddresses() ([]netip.Prefix, error) { // ForwardTo forwards all inbound traffic to the upstream network. func (tun *TunDevice) ForwardTo(ctx context.Context, upstream network.Network) error { + return tun.ForwardToWithOptions(ctx, upstream, "") +} + +// ForwardToWithOptions forwards all inbound traffic to the upstream network with optional port override. +func (tun *TunDevice) ForwardToWithOptions(ctx context.Context, upstream network.Network, overridePort string) error { // Allow outgoing packets to have a source address different from the address // assigned to the NIC. if tcpipErr := tun.stack.SetSpoofing(tun.nicID, true); tcpipErr != nil { @@ -307,10 +312,10 @@ func (tun *TunDevice) ForwardTo(ctx context.Context, upstream network.Network) e return fmt.Errorf("failed to enable promiscuous mode: %v", tcpipErr) } - tcpForwarder := TCPForwarder(ctx, tun.stack, upstream) + tcpForwarder := TCPForwarder(ctx, tun.stack, upstream, overridePort) tun.stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder) - udpForwarder := UDPForwarder(ctx, tun.stack, upstream) + udpForwarder := UDPForwarder(ctx, tun.stack, upstream, overridePort) tun.stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder) return nil diff --git a/pkg/netstack/udp_forwarder.go b/pkg/netstack/udp_forwarder.go index 9c9a7a51..5fcb86e2 100644 --- a/pkg/netstack/udp_forwarder.go +++ b/pkg/netstack/udp_forwarder.go @@ -28,10 +28,11 @@ var udpBuffPool = sync.Pool{ } // UDPForwarder forwards UDP packets to an upstream network. -func UDPForwarder(ctx context.Context, ipstack *stack.Stack, upstream network.Network) ProtocolHandler { +// If overridePort is not empty, all packets will be forwarded to that port instead of the original destination port. +func UDPForwarder(ctx context.Context, ipstack *stack.Stack, upstream network.Network, overridePort string) ProtocolHandler { udpForwarder := udp.NewForwarder( ipstack, - udpHandler(ctx, upstream), + udpHandler(ctx, upstream, overridePort), ) return udpForwarder.HandlePacket @@ -77,15 +78,24 @@ func copyPackets(ctx context.Context, src, dst net.Conn, once bool, extend func( } } -func udpHandler(ctx context.Context, upstream network.Network) func(req *udp.ForwarderRequest) { +func udpHandler(ctx context.Context, upstream network.Network, overridePort string) func(req *udp.ForwarderRequest) { return func(req *udp.ForwarderRequest) { reqDetails := req.ID() srcAddrPort := netip.AddrPortFrom(addrFromNetstackIP(reqDetails.RemoteAddress), reqDetails.RemotePort) - dstAddrPort := netip.AddrPortFrom( - Unmap4in6(addrFromNetstackIP(reqDetails.LocalAddress)), - reqDetails.LocalPort, - ) + + // Extract IPv4 address from IPv6-mapped address + dstAddr := Unmap4in6(addrFromNetstackIP(reqDetails.LocalAddress)) + dstPort := reqDetails.LocalPort + + // Override port if specified + if overridePort != "" { + if port, err := netip.ParseAddrPort("127.0.0.1:" + overridePort); err == nil { + dstPort = port.Port() + } + } + + dstAddrPort := netip.AddrPortFrom(dstAddr, dstPort) logger := slog.With( slog.String("src", srcAddrPort.String()), diff --git a/pkg/netstack/udp_forwarder_test.go b/pkg/netstack/udp_forwarder_test.go index 7de6c7fe..d7af8abc 100644 --- a/pkg/netstack/udp_forwarder_test.go +++ b/pkg/netstack/udp_forwarder_test.go @@ -47,7 +47,7 @@ func TestUDPForwarder(t *testing.T) { }() // Setup the server stack to forward UDP packets to the hosts loopback interface. - serverStack.SetTransportProtocolHandler(udp.ProtocolNumber, netstack.UDPForwarder(ctx, serverStack.Stack, network.Loopback())) + serverStack.SetTransportProtocolHandler(udp.ProtocolNumber, netstack.UDPForwarder(ctx, serverStack.Stack, network.Loopback(), "")) // Generate test data testData := make([]byte, 1024) @@ -144,7 +144,7 @@ func TestUDPForwarderMultipleSessions(t *testing.T) { }() // Setup the server stack to forward UDP packets - serverStack.SetTransportProtocolHandler(udp.ProtocolNumber, netstack.UDPForwarder(ctx, serverStack.Stack, network.Loopback())) + serverStack.SetTransportProtocolHandler(udp.ProtocolNumber, netstack.UDPForwarder(ctx, serverStack.Stack, network.Loopback(), "")) // Start multiple UDP servers on different ports numServers := 3 @@ -242,7 +242,7 @@ func TestUDPForwarderTimeout(t *testing.T) { }() // Setup the server stack to forward UDP packets - serverStack.SetTransportProtocolHandler(udp.ProtocolNumber, netstack.UDPForwarder(ctx, serverStack.Stack, network.Loopback())) + serverStack.SetTransportProtocolHandler(udp.ProtocolNumber, netstack.UDPForwarder(ctx, serverStack.Stack, network.Loopback(), "")) // Start a UDP server udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") diff --git a/pkg/tunnel/client.go b/pkg/tunnel/client.go index 245eff66..54c54795 100644 --- a/pkg/tunnel/client.go +++ b/pkg/tunnel/client.go @@ -66,6 +66,7 @@ type tunnelClientOptions struct { // Userspace options socksListenAddr string preserveDefaultGwDsts []netip.Prefix + overridePort string } func defaultClientOptions() *tunnelClientOptions { @@ -141,6 +142,14 @@ func WithPreserveDefaultGatewayDestinations(dsts []netip.Prefix) TunnelClientOpt } } +// WithOverridePort sets the destination port override for forwarded packets. +// If empty string, the original destination port is preserved. +func WithOverridePort(port string) TunnelClientOption { + return func(o *tunnelClientOptions) { + o.overridePort = port + } +} + // BuildClientRouter builds a router for the client tunnel side using provided // options and sane defaults. func BuildClientRouter(opts ...TunnelClientOption) (router.Router, error) { @@ -174,6 +183,9 @@ func BuildClientRouter(opts ...TunnelClientOption) (router.Router, error) { if options.socksListenAddr != "" { routerOpts = append(routerOpts, router.WithSocksListenAddr(options.socksListenAddr)) } + if options.overridePort != "" { + routerOpts = append(routerOpts, router.WithOverridePort(options.overridePort)) + } switch options.mode { case TunnelClientModeKernel: diff --git a/pkg/tunnel/router/client_netstack.go b/pkg/tunnel/router/client_netstack.go index ef78a951..d574473d 100644 --- a/pkg/tunnel/router/client_netstack.go +++ b/pkg/tunnel/router/client_netstack.go @@ -32,6 +32,7 @@ type NetstackRouter struct { resolveConf *network.ResolveConfig socksListenAddr string cksumRecalc bool + overridePort string closeOnce sync.Once } @@ -63,6 +64,7 @@ func NewNetstackRouter(opts ...Option) (*NetstackRouter, error) { resolveConf: options.resolveConf, socksListenAddr: options.socksListenAddr, cksumRecalc: options.cksumRecalc, + overridePort: options.overridePort, }, nil } @@ -99,10 +101,10 @@ func (r *NetstackRouter) Start(ctx context.Context) error { slog.Info("Forwarding all inbound traffic to loopback interface") - if err := r.tunDev.ForwardTo(ctx, network.Filtered(&network.FilteredNetworkConfig{ + if err := r.tunDev.ForwardToWithOptions(ctx, network.Filtered(&network.FilteredNetworkConfig{ DeniedPorts: []uint16{uint16(socksListenPort)}, Upstream: network.Host(), - })); err != nil { + }), r.overridePort); err != nil { return fmt.Errorf("failed to forward to loopback: %w", err) } diff --git a/pkg/tunnel/router/options.go b/pkg/tunnel/router/options.go index 30d5fc85..b7a2608c 100644 --- a/pkg/tunnel/router/options.go +++ b/pkg/tunnel/router/options.go @@ -25,6 +25,7 @@ type routerOptions struct { sourcePortHashing bool pc batchpc.BatchPacketConn egressGateway bool + overridePort string } func defaultOptions() *routerOptions { @@ -134,3 +135,11 @@ func WithEgressGateway(enable bool) Option { o.egressGateway = enable } } + +// WithOverridePort sets the destination port override for forwarded packets. +// If empty string, the original destination port is preserved. +func WithOverridePort(port string) Option { + return func(o *routerOptions) { + o.overridePort = port + } +}