Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkg/cmd/tunnel/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ func init() {
tunnelRunCmd.Flags().BoolVar(&autoCreate, "auto", false, "Automatically create TunnelNode if it doesn't exist.")
tunnelRunCmd.Flags().StringVar(&healthAddr, "health-addr", ":8080", "Listen address for health endpoint (default: :8080).")
tunnelRunCmd.Flags().StringVar(&metricsAddr, "metrics-addr", ":8081", "Listen address for metrics endpoint (default: :8081).")
tunnelRunCmd.Flags().StringVar(&overridePort, "port", "", "Override destination port for forwarded packets (empty string preserves original port).")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Port should be a uint


tunnelCmd.AddCommand(createCmd)
tunnelCmd.AddCommand(getCmd)
Expand Down
2 changes: 2 additions & 0 deletions pkg/cmd/tunnel/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ var (
autoCreate bool
healthAddr string
metricsAddr string
overridePort string

preserveDefaultGwDsts []netip.Prefix
)
Expand Down Expand Up @@ -264,6 +265,7 @@ func (t *tunnelNodeReconciler) run(ctx context.Context, tn *corev1alpha.TunnelNo
tunnel.WithMode(tunnelMode),
tunnel.WithPreserveDefaultGatewayDestinations(preserveDefaultGwDsts),
tunnel.WithSocksListenAddr(socksListenAddr),
tunnel.WithOverridePort(overridePort),
)
if err != nil {
return fmt.Errorf("unable to build client router: %w", err)
Expand Down
4 changes: 2 additions & 2 deletions pkg/netstack/icx_network.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,10 +395,10 @@ func (net *ICXNetwork) ForwardTo(ctx context.Context, upstream network.Network)
return fmt.Errorf("failed to enable promiscuous mode: %v", tcpipErr)
}

tcpForwarder := TCPForwarder(ctx, net.stack, upstream)
tcpForwarder := TCPForwarder(ctx, net.stack, upstream, "")
net.stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder)

udpForwarder := UDPForwarder(ctx, net.stack, upstream)
udpForwarder := UDPForwarder(ctx, net.stack, upstream, "")
net.stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder)

return nil
Expand Down
24 changes: 17 additions & 7 deletions pkg/netstack/tcp_forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ import (
type ProtocolHandler func(stack.TransportEndpointID, *stack.PacketBuffer) bool

// TCPForwarder forwards TCP connections to an upstream network.
func TCPForwarder(ctx context.Context, ipstack *stack.Stack, upstream network.Network) ProtocolHandler {
// If overridePort is not empty, all connections will be forwarded to that port instead of the original destination port.
func TCPForwarder(ctx context.Context, ipstack *stack.Stack, upstream network.Network, overridePort string) ProtocolHandler {
tcpForwarder := tcp.NewForwarder(
ipstack,
0, /* rcvWnd (0 - default) */
65535, /* maxInFlight */
tcpHandler(ctx, upstream),
tcpHandler(ctx, upstream, overridePort),
)

return tcpForwarder.HandlePacket
Expand All @@ -51,15 +52,24 @@ func Unmap4in6(addr netip.Addr) netip.Addr {
return v4addr
}

func tcpHandler(ctx context.Context, upstream network.Network) func(req *tcp.ForwarderRequest) {
func tcpHandler(ctx context.Context, upstream network.Network, overridePort string) func(req *tcp.ForwarderRequest) {
return func(req *tcp.ForwarderRequest) {
reqDetails := req.ID()

srcAddrPort := netip.AddrPortFrom(addrFromNetstackIP(reqDetails.RemoteAddress), reqDetails.RemotePort)
dstAddrPort := netip.AddrPortFrom(
Unmap4in6(addrFromNetstackIP(reqDetails.LocalAddress)),
reqDetails.LocalPort,
)

// Extract IPv4 address from IPv6-mapped address
dstAddr := Unmap4in6(addrFromNetstackIP(reqDetails.LocalAddress))
dstPort := reqDetails.LocalPort

Comment on lines +60 to +64
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Extract IPv4 address from IPv6-mapped address
dstAddr := Unmap4in6(addrFromNetstackIP(reqDetails.LocalAddress))
dstPort := reqDetails.LocalPort
dstAddrPort := netip.AddrPortFrom(
Unmap4in6(addrFromNetstackIP(reqDetails.LocalAddress)),
reqDetails.LocalPort,
)
if overridePort != 0 {
logger.Debug("Overriding port", "original_port", dstAddrPort.Port, "ovveride_port", overridePort)
}

These comments are not needed

// Override port if specified
if overridePort != "" {
if port, err := netip.ParseAddrPort("127.0.0.1:" + overridePort); err == nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can just strconv.Atoi instead - this is confusing (why is there 127.0.0.1 in there?) Also it swallows the error

dstPort = port.Port()
}
}

dstAddrPort := netip.AddrPortFrom(dstAddr, dstPort)

logger := slog.With(
slog.String("src", srcAddrPort.String()),
Expand Down
2 changes: 1 addition & 1 deletion pkg/netstack/tcp_forwarder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func TestTCPForwarder(t *testing.T) {
}()

// Setup the server stack to forward TCP packets to the hosts loopback interface.
serverStack.SetTransportProtocolHandler(tcp.ProtocolNumber, netstack.TCPForwarder(ctx, serverStack.Stack, network.Loopback()))
serverStack.SetTransportProtocolHandler(tcp.ProtocolNumber, netstack.TCPForwarder(ctx, serverStack.Stack, network.Loopback(), ""))

// Generate a large blob of random blob to send to the client.
blob := make([]byte, 1<<20)
Expand Down
9 changes: 7 additions & 2 deletions pkg/netstack/tun_device.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,11 @@ func (tun *TunDevice) LocalAddresses() ([]netip.Prefix, error) {

// ForwardTo forwards all inbound traffic to the upstream network.
func (tun *TunDevice) ForwardTo(ctx context.Context, upstream network.Network) error {
return tun.ForwardToWithOptions(ctx, upstream, "")
}

// ForwardToWithOptions forwards all inbound traffic to the upstream network with optional port override.
func (tun *TunDevice) ForwardToWithOptions(ctx context.Context, upstream network.Network, overridePort string) error {
// Allow outgoing packets to have a source address different from the address
// assigned to the NIC.
if tcpipErr := tun.stack.SetSpoofing(tun.nicID, true); tcpipErr != nil {
Expand All @@ -307,10 +312,10 @@ func (tun *TunDevice) ForwardTo(ctx context.Context, upstream network.Network) e
return fmt.Errorf("failed to enable promiscuous mode: %v", tcpipErr)
}

tcpForwarder := TCPForwarder(ctx, tun.stack, upstream)
tcpForwarder := TCPForwarder(ctx, tun.stack, upstream, overridePort)
tun.stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder)

udpForwarder := UDPForwarder(ctx, tun.stack, upstream)
udpForwarder := UDPForwarder(ctx, tun.stack, upstream, overridePort)
tun.stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder)

return nil
Expand Down
24 changes: 17 additions & 7 deletions pkg/netstack/udp_forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ var udpBuffPool = sync.Pool{
}

// UDPForwarder forwards UDP packets to an upstream network.
func UDPForwarder(ctx context.Context, ipstack *stack.Stack, upstream network.Network) ProtocolHandler {
// If overridePort is not empty, all packets will be forwarded to that port instead of the original destination port.
func UDPForwarder(ctx context.Context, ipstack *stack.Stack, upstream network.Network, overridePort string) ProtocolHandler {
udpForwarder := udp.NewForwarder(
ipstack,
udpHandler(ctx, upstream),
udpHandler(ctx, upstream, overridePort),
)

return udpForwarder.HandlePacket
Expand Down Expand Up @@ -77,15 +78,24 @@ func copyPackets(ctx context.Context, src, dst net.Conn, once bool, extend func(
}
}

func udpHandler(ctx context.Context, upstream network.Network) func(req *udp.ForwarderRequest) {
func udpHandler(ctx context.Context, upstream network.Network, overridePort string) func(req *udp.ForwarderRequest) {
return func(req *udp.ForwarderRequest) {
reqDetails := req.ID()

srcAddrPort := netip.AddrPortFrom(addrFromNetstackIP(reqDetails.RemoteAddress), reqDetails.RemotePort)
dstAddrPort := netip.AddrPortFrom(
Unmap4in6(addrFromNetstackIP(reqDetails.LocalAddress)),
reqDetails.LocalPort,
)

// Extract IPv4 address from IPv6-mapped address
dstAddr := Unmap4in6(addrFromNetstackIP(reqDetails.LocalAddress))
dstPort := reqDetails.LocalPort

// Override port if specified
if overridePort != "" {
if port, err := netip.ParseAddrPort("127.0.0.1:" + overridePort); err == nil {
dstPort = port.Port()
}
}

dstAddrPort := netip.AddrPortFrom(dstAddr, dstPort)

logger := slog.With(
slog.String("src", srcAddrPort.String()),
Expand Down
6 changes: 3 additions & 3 deletions pkg/netstack/udp_forwarder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func TestUDPForwarder(t *testing.T) {
}()

// Setup the server stack to forward UDP packets to the hosts loopback interface.
serverStack.SetTransportProtocolHandler(udp.ProtocolNumber, netstack.UDPForwarder(ctx, serverStack.Stack, network.Loopback()))
serverStack.SetTransportProtocolHandler(udp.ProtocolNumber, netstack.UDPForwarder(ctx, serverStack.Stack, network.Loopback(), ""))

// Generate test data
testData := make([]byte, 1024)
Expand Down Expand Up @@ -144,7 +144,7 @@ func TestUDPForwarderMultipleSessions(t *testing.T) {
}()

// Setup the server stack to forward UDP packets
serverStack.SetTransportProtocolHandler(udp.ProtocolNumber, netstack.UDPForwarder(ctx, serverStack.Stack, network.Loopback()))
serverStack.SetTransportProtocolHandler(udp.ProtocolNumber, netstack.UDPForwarder(ctx, serverStack.Stack, network.Loopback(), ""))

// Start multiple UDP servers on different ports
numServers := 3
Expand Down Expand Up @@ -242,7 +242,7 @@ func TestUDPForwarderTimeout(t *testing.T) {
}()

// Setup the server stack to forward UDP packets
serverStack.SetTransportProtocolHandler(udp.ProtocolNumber, netstack.UDPForwarder(ctx, serverStack.Stack, network.Loopback()))
serverStack.SetTransportProtocolHandler(udp.ProtocolNumber, netstack.UDPForwarder(ctx, serverStack.Stack, network.Loopback(), ""))

// Start a UDP server
udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
Expand Down
12 changes: 12 additions & 0 deletions pkg/tunnel/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ type tunnelClientOptions struct {
// Userspace options
socksListenAddr string
preserveDefaultGwDsts []netip.Prefix
overridePort string
}

func defaultClientOptions() *tunnelClientOptions {
Expand Down Expand Up @@ -141,6 +142,14 @@ func WithPreserveDefaultGatewayDestinations(dsts []netip.Prefix) TunnelClientOpt
}
}

// WithOverridePort sets the destination port override for forwarded packets.
// If empty string, the original destination port is preserved.
func WithOverridePort(port string) TunnelClientOption {
return func(o *tunnelClientOptions) {
o.overridePort = port
}
}

// BuildClientRouter builds a router for the client tunnel side using provided
// options and sane defaults.
func BuildClientRouter(opts ...TunnelClientOption) (router.Router, error) {
Expand Down Expand Up @@ -174,6 +183,9 @@ func BuildClientRouter(opts ...TunnelClientOption) (router.Router, error) {
if options.socksListenAddr != "" {
routerOpts = append(routerOpts, router.WithSocksListenAddr(options.socksListenAddr))
}
if options.overridePort != "" {
routerOpts = append(routerOpts, router.WithOverridePort(options.overridePort))
}

switch options.mode {
case TunnelClientModeKernel:
Expand Down
6 changes: 4 additions & 2 deletions pkg/tunnel/router/client_netstack.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type NetstackRouter struct {
resolveConf *network.ResolveConfig
socksListenAddr string
cksumRecalc bool
overridePort string

closeOnce sync.Once
}
Expand Down Expand Up @@ -63,6 +64,7 @@ func NewNetstackRouter(opts ...Option) (*NetstackRouter, error) {
resolveConf: options.resolveConf,
socksListenAddr: options.socksListenAddr,
cksumRecalc: options.cksumRecalc,
overridePort: options.overridePort,
}, nil
}

Expand Down Expand Up @@ -99,10 +101,10 @@ func (r *NetstackRouter) Start(ctx context.Context) error {

slog.Info("Forwarding all inbound traffic to loopback interface")

if err := r.tunDev.ForwardTo(ctx, network.Filtered(&network.FilteredNetworkConfig{
if err := r.tunDev.ForwardToWithOptions(ctx, network.Filtered(&network.FilteredNetworkConfig{
DeniedPorts: []uint16{uint16(socksListenPort)},
Upstream: network.Host(),
})); err != nil {
}), r.overridePort); err != nil {
return fmt.Errorf("failed to forward to loopback: %w", err)
}

Expand Down
9 changes: 9 additions & 0 deletions pkg/tunnel/router/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type routerOptions struct {
sourcePortHashing bool
pc batchpc.BatchPacketConn
egressGateway bool
overridePort string
}

func defaultOptions() *routerOptions {
Expand Down Expand Up @@ -134,3 +135,11 @@ func WithEgressGateway(enable bool) Option {
o.egressGateway = enable
}
}

// WithOverridePort sets the destination port override for forwarded packets.
// If empty string, the original destination port is preserved.
func WithOverridePort(port string) Option {
return func(o *routerOptions) {
o.overridePort = port
}
}