From a68a8c7e245812806ad6d00345af2778c333806f Mon Sep 17 00:00:00 2001 From: Dream95 Date: Fri, 19 Sep 2025 12:04:15 +0000 Subject: [PATCH] feat: add ipmask Signed-off-by: Dream95 --- cmd/cmd.go | 21 ++----------------- cmd/loadBpf.go | 10 +++++---- cmd/proxy.c | 8 ++++++-- cmd/proxy_arm64_bpfel.go | 17 ++++++++-------- cmd/proxy_x86_bpfel.go | 17 ++++++++-------- common/ip.go | 44 ++++++++++++++++++++++++++++++++++++++++ 6 files changed, 76 insertions(+), 41 deletions(-) create mode 100644 common/ip.go diff --git a/cmd/cmd.go b/cmd/cmd.go index 005e41b..9b1b462 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -1,11 +1,9 @@ package main import ( - "encoding/binary" "fmt" "gotproxy/common" "log" - "net" "os" "strconv" @@ -39,11 +37,12 @@ var rootCmd = &cobra.Command{ return } - ip, err := ipStrToUnit32() + ip, mask, err := common.ParseIPWithMask(ipStr) if err != nil { log.Fatal(err) } Options.Ip4 = ip + Options.Ip4Mask = mask if proxyPid == 0 { StartProxy() @@ -60,22 +59,6 @@ var rootCmd = &cobra.Command{ }, } -func ipStrToUnit32() (uint32, error) { - if ipStr == "" { - return 0, nil - } - ip := net.ParseIP(ipStr) - if ip == nil { - return 0, fmt.Errorf("invalid ip: %s", ipStr) - } - - ip = ip.To4() - if ip == nil { - return 0, fmt.Errorf("is not ipV4: %s", ipStr) - } - return binary.LittleEndian.Uint32(ip), nil -} - func Execute() { if err := rootCmd.Execute(); err != nil { fmt.Println(err) diff --git a/cmd/loadBpf.go b/cmd/loadBpf.go index 3af1833..cd3f530 100644 --- a/cmd/loadBpf.go +++ b/cmd/loadBpf.go @@ -30,6 +30,7 @@ type Options struct { Command string Pids []uint64 Ip4 uint32 + Ip4Mask uint8 } func LoadBpf(options *Options) { @@ -91,10 +92,11 @@ func LoadBpf(options *Options) { pid = options.ProxyPid } config := proxyConfig{ - ProxyPort: options.ProxyPort, - ProxyPid: pid, - FilterByPid: len(options.Pids) > 0, - FilterIp: options.Ip4, + ProxyPort: options.ProxyPort, + ProxyPid: pid, + FilterByPid: len(options.Pids) > 0, + FilterIp: options.Ip4, + FilterIpMask: options.Ip4Mask, } stringToInt8Array(config.Command[:], options.Command) err = objs.proxyMaps.MapConfig.Update(&key, &config, ebpf.UpdateAny) diff --git a/cmd/proxy.c b/cmd/proxy.c index 64ebe47..94d0b46 100644 --- a/cmd/proxy.c +++ b/cmd/proxy.c @@ -20,6 +20,7 @@ struct Config { __u16 proxy_port; __u64 proxy_pid; __u32 filter_ip; + __u8 filter_ip_mask; bool filter_by_pid; char command[TASK_COMM_LEN]; }; @@ -97,8 +98,11 @@ int cg_connect4(struct bpf_sock_addr *ctx) { if (!match_process(conf)) return 1; - if (conf->filter_ip){ - if (ctx->user_ip4 != conf->filter_ip) { + if (conf->filter_ip) + { + __u32 mask = 0xFFFFFFFF >> (32 - conf->filter_ip_mask); + if ((ctx->user_ip4 & mask) != (conf->filter_ip & mask)) + { BPF_LOG_DEBUG("not match ip\n"); return 1; } diff --git a/cmd/proxy_arm64_bpfel.go b/cmd/proxy_arm64_bpfel.go index 1202173..c395e54 100644 --- a/cmd/proxy_arm64_bpfel.go +++ b/cmd/proxy_arm64_bpfel.go @@ -14,14 +14,15 @@ import ( ) type proxyConfig struct { - _ structs.HostLayout - ProxyPort uint16 - _ [6]byte - ProxyPid uint64 - FilterIp uint32 - FilterByPid bool - Command [16]int8 - _ [3]byte + _ structs.HostLayout + ProxyPort uint16 + _ [6]byte + ProxyPid uint64 + FilterIp uint32 + FilterIpMask uint8 + FilterByPid bool + Command [16]int8 + _ [2]byte } type proxySocket struct { diff --git a/cmd/proxy_x86_bpfel.go b/cmd/proxy_x86_bpfel.go index 4a889fc..a714d14 100644 --- a/cmd/proxy_x86_bpfel.go +++ b/cmd/proxy_x86_bpfel.go @@ -14,14 +14,15 @@ import ( ) type proxyConfig struct { - _ structs.HostLayout - ProxyPort uint16 - _ [6]byte - ProxyPid uint64 - FilterIp uint32 - FilterByPid bool - Command [16]int8 - _ [3]byte + _ structs.HostLayout + ProxyPort uint16 + _ [6]byte + ProxyPid uint64 + FilterIp uint32 + FilterIpMask uint8 + FilterByPid bool + Command [16]int8 + _ [2]byte } type proxySocket struct { diff --git a/common/ip.go b/common/ip.go new file mode 100644 index 0000000..a2a2dee --- /dev/null +++ b/common/ip.go @@ -0,0 +1,44 @@ +package common + +import ( + "encoding/binary" + "fmt" + "net" + "strings" +) + +func ParseIPWithMask(ipStr string) (uint32, uint8, error) { + if ipStr == "" { + return 0, 0, nil + } + if strings.Contains(ipStr, "/") { + _, ipNet, err := net.ParseCIDR(ipStr) + if err != nil { + return 0, 0, fmt.Errorf("invalid CIDR: %s", ipStr) + } + + ip4 := ipNet.IP.To4() + if ip4 == nil { + return 0, 0, fmt.Errorf("not an IPv4 CIDR: %s", ipStr) + } + + maskSize, _ := ipNet.Mask.Size() + ipVal := binary.LittleEndian.Uint32(ip4) + return ipVal, uint8(maskSize), nil + } else { + ip := net.ParseIP(ipStr) + if ip == nil { + return 0, 0, fmt.Errorf("invalid IP: %s", ipStr) + } + + ip4 := ip.To4() + if ip4 == nil { + return 0, 0, fmt.Errorf("not an IPv4 address: %s", ipStr) + } + + ipVal := binary.LittleEndian.Uint32(ip4) + + fmt.Printf("Parsed IP: %s, Mask=32\n", ipStr) + return ipVal, 32, nil + } +}