From 8874f10d1708f0d8363b763d7520ffe5ff2e87b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 26 Dec 2025 00:40:56 +0800 Subject: [PATCH 01/13] Add packet reactor --- common/bufio/channel_demux.go | 143 +++ common/bufio/fd_demux_darwin.go | 225 ++++ common/bufio/fd_demux_linux.go | 217 ++++ common/bufio/fd_demux_stub.go | 25 + common/bufio/fd_demux_windows.go | 227 ++++ common/bufio/fd_demux_windows_test.go | 305 +++++ common/bufio/packet_reactor.go | 390 +++++++ common/bufio/packet_reactor_test.go | 1485 +++++++++++++++++++++++++ common/network/read_notifier.go | 21 + common/udpnat2/conn.go | 4 + common/wepoll/afd_windows.go | 122 ++ common/wepoll/pinner.go | 7 + common/wepoll/pinner_compat.go | 9 + common/wepoll/socket_windows.go | 49 + common/wepoll/syscall_windows.go | 8 + common/wepoll/types_windows.go | 64 ++ common/wepoll/wepoll_test.go | 335 ++++++ common/wepoll/zsyscall_windows.go | 84 ++ 18 files changed, 3720 insertions(+) create mode 100644 common/bufio/channel_demux.go create mode 100644 common/bufio/fd_demux_darwin.go create mode 100644 common/bufio/fd_demux_linux.go create mode 100644 common/bufio/fd_demux_stub.go create mode 100644 common/bufio/fd_demux_windows.go create mode 100644 common/bufio/fd_demux_windows_test.go create mode 100644 common/bufio/packet_reactor.go create mode 100644 common/bufio/packet_reactor_test.go create mode 100644 common/network/read_notifier.go create mode 100644 common/wepoll/afd_windows.go create mode 100644 common/wepoll/pinner.go create mode 100644 common/wepoll/pinner_compat.go create mode 100644 common/wepoll/socket_windows.go create mode 100644 common/wepoll/syscall_windows.go create mode 100644 common/wepoll/types_windows.go create mode 100644 common/wepoll/wepoll_test.go create mode 100644 common/wepoll/zsyscall_windows.go diff --git a/common/bufio/channel_demux.go b/common/bufio/channel_demux.go new file mode 100644 index 00000000..2876083c --- /dev/null +++ b/common/bufio/channel_demux.go @@ -0,0 +1,143 @@ +package bufio + +import ( + "context" + "reflect" + "sync" + "sync/atomic" + + N "github.com/sagernet/sing/common/network" +) + +type channelDemuxEntry struct { + channel <-chan *N.PacketBuffer + stream *reactorStream +} + +type ChannelDemultiplexer struct { + ctx context.Context + cancel context.CancelFunc + mutex sync.Mutex + entries map[<-chan *N.PacketBuffer]*channelDemuxEntry + updateChan chan struct{} + running bool + closed atomic.Bool + wg sync.WaitGroup +} + +func NewChannelDemultiplexer(ctx context.Context) *ChannelDemultiplexer { + ctx, cancel := context.WithCancel(ctx) + demux := &ChannelDemultiplexer{ + ctx: ctx, + cancel: cancel, + entries: make(map[<-chan *N.PacketBuffer]*channelDemuxEntry), + updateChan: make(chan struct{}, 1), + } + return demux +} + +func (d *ChannelDemultiplexer) Add(stream *reactorStream, channel <-chan *N.PacketBuffer) { + d.mutex.Lock() + + if d.closed.Load() { + d.mutex.Unlock() + return + } + + entry := &channelDemuxEntry{ + channel: channel, + stream: stream, + } + d.entries[channel] = entry + if !d.running { + d.running = true + d.wg.Add(1) + go d.run() + } + d.mutex.Unlock() + d.signalUpdate() +} + +func (d *ChannelDemultiplexer) Remove(channel <-chan *N.PacketBuffer) { + d.mutex.Lock() + delete(d.entries, channel) + d.mutex.Unlock() + d.signalUpdate() +} + +func (d *ChannelDemultiplexer) signalUpdate() { + select { + case d.updateChan <- struct{}{}: + default: + } +} + +func (d *ChannelDemultiplexer) Close() error { + d.mutex.Lock() + d.closed.Store(true) + 
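+	// Setting closed while still holding the mutex guarantees that any Add call
+	// that acquires the lock afterwards observes the flag and returns without registering.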
d.mutex.Unlock() + + d.cancel() + d.signalUpdate() + d.wg.Wait() + return nil +} + +func (d *ChannelDemultiplexer) run() { + defer d.wg.Done() + + for { + d.mutex.Lock() + if len(d.entries) == 0 { + d.running = false + d.mutex.Unlock() + return + } + + cases := make([]reflect.SelectCase, 0, len(d.entries)+2) + + cases = append(cases, reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(d.ctx.Done()), + }) + + cases = append(cases, reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(d.updateChan), + }) + + entryList := make([]*channelDemuxEntry, 0, len(d.entries)) + for _, entry := range d.entries { + cases = append(cases, reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(entry.channel), + }) + entryList = append(entryList, entry) + } + d.mutex.Unlock() + + chosen, recv, recvOK := reflect.Select(cases) + + switch chosen { + case 0: + d.mutex.Lock() + d.running = false + d.mutex.Unlock() + return + case 1: + continue + default: + entry := entryList[chosen-2] + d.mutex.Lock() + delete(d.entries, entry.channel) + d.mutex.Unlock() + + if recvOK { + packet := recv.Interface().(*N.PacketBuffer) + go entry.stream.runActiveLoop(packet) + } else { + go entry.stream.closeWithError(nil) + } + } + } +} diff --git a/common/bufio/fd_demux_darwin.go b/common/bufio/fd_demux_darwin.go new file mode 100644 index 00000000..3e1c7876 --- /dev/null +++ b/common/bufio/fd_demux_darwin.go @@ -0,0 +1,225 @@ +//go:build darwin + +package bufio + +import ( + "context" + "sync" + "sync/atomic" + + "golang.org/x/sys/unix" +) + +type fdDemuxEntry struct { + fd int + stream *reactorStream +} + +type FDDemultiplexer struct { + ctx context.Context + cancel context.CancelFunc + kqueueFD int + mutex sync.Mutex + entries map[int]*fdDemuxEntry + running bool + closed atomic.Bool + wg sync.WaitGroup + pipeFDs [2]int +} + +func NewFDDemultiplexer(ctx context.Context) (*FDDemultiplexer, error) { + kqueueFD, err := unix.Kqueue() + if err != nil { + return nil, err + } + + var pipeFDs [2]int + err = unix.Pipe(pipeFDs[:]) + if err != nil { + unix.Close(kqueueFD) + return nil, err + } + + err = unix.SetNonblock(pipeFDs[0], true) + if err != nil { + unix.Close(pipeFDs[0]) + unix.Close(pipeFDs[1]) + unix.Close(kqueueFD) + return nil, err + } + err = unix.SetNonblock(pipeFDs[1], true) + if err != nil { + unix.Close(pipeFDs[0]) + unix.Close(pipeFDs[1]) + unix.Close(kqueueFD) + return nil, err + } + + _, err = unix.Kevent(kqueueFD, []unix.Kevent_t{{ + Ident: uint64(pipeFDs[0]), + Filter: unix.EVFILT_READ, + Flags: unix.EV_ADD, + }}, nil, nil) + if err != nil { + unix.Close(pipeFDs[0]) + unix.Close(pipeFDs[1]) + unix.Close(kqueueFD) + return nil, err + } + + ctx, cancel := context.WithCancel(ctx) + demux := &FDDemultiplexer{ + ctx: ctx, + cancel: cancel, + kqueueFD: kqueueFD, + entries: make(map[int]*fdDemuxEntry), + pipeFDs: pipeFDs, + } + return demux, nil +} + +func (d *FDDemultiplexer) Add(stream *reactorStream, fd int) error { + d.mutex.Lock() + defer d.mutex.Unlock() + + if d.closed.Load() { + return unix.EINVAL + } + + _, err := unix.Kevent(d.kqueueFD, []unix.Kevent_t{{ + Ident: uint64(fd), + Filter: unix.EVFILT_READ, + Flags: unix.EV_ADD, + }}, nil, nil) + if err != nil { + return err + } + + entry := &fdDemuxEntry{ + fd: fd, + stream: stream, + } + d.entries[fd] = entry + + if !d.running { + d.running = true + d.wg.Add(1) + go d.run() + } + + return nil +} + +func (d *FDDemultiplexer) Remove(fd int) { + d.mutex.Lock() + defer d.mutex.Unlock() + + _, ok := d.entries[fd] + if 
!ok { + return + } + + unix.Kevent(d.kqueueFD, []unix.Kevent_t{{ + Ident: uint64(fd), + Filter: unix.EVFILT_READ, + Flags: unix.EV_DELETE, + }}, nil, nil) + delete(d.entries, fd) +} + +func (d *FDDemultiplexer) wakeup() { + unix.Write(d.pipeFDs[1], []byte{0}) +} + +func (d *FDDemultiplexer) Close() error { + d.mutex.Lock() + d.closed.Store(true) + d.mutex.Unlock() + + d.cancel() + d.wakeup() + d.wg.Wait() + + d.mutex.Lock() + defer d.mutex.Unlock() + + if d.kqueueFD != -1 { + unix.Close(d.kqueueFD) + d.kqueueFD = -1 + } + if d.pipeFDs[0] != -1 { + unix.Close(d.pipeFDs[0]) + unix.Close(d.pipeFDs[1]) + d.pipeFDs[0] = -1 + d.pipeFDs[1] = -1 + } + return nil +} + +func (d *FDDemultiplexer) run() { + defer d.wg.Done() + + events := make([]unix.Kevent_t, 64) + var buffer [1]byte + + for { + select { + case <-d.ctx.Done(): + d.mutex.Lock() + d.running = false + d.mutex.Unlock() + return + default: + } + + n, err := unix.Kevent(d.kqueueFD, nil, events, nil) + if err != nil { + if err == unix.EINTR { + continue + } + d.mutex.Lock() + d.running = false + d.mutex.Unlock() + return + } + + for i := 0; i < n; i++ { + event := events[i] + fd := int(event.Ident) + + if fd == d.pipeFDs[0] { + unix.Read(d.pipeFDs[0], buffer[:]) + continue + } + + if event.Flags&unix.EV_ERROR != 0 { + continue + } + + d.mutex.Lock() + entry, ok := d.entries[fd] + if !ok { + d.mutex.Unlock() + continue + } + + unix.Kevent(d.kqueueFD, []unix.Kevent_t{{ + Ident: uint64(fd), + Filter: unix.EVFILT_READ, + Flags: unix.EV_DELETE, + }}, nil, nil) + delete(d.entries, fd) + d.mutex.Unlock() + + go entry.stream.runActiveLoop(nil) + } + + d.mutex.Lock() + if len(d.entries) == 0 { + d.running = false + d.mutex.Unlock() + return + } + d.mutex.Unlock() + } +} diff --git a/common/bufio/fd_demux_linux.go b/common/bufio/fd_demux_linux.go new file mode 100644 index 00000000..2c5e0afa --- /dev/null +++ b/common/bufio/fd_demux_linux.go @@ -0,0 +1,217 @@ +//go:build linux + +package bufio + +import ( + "context" + "sync" + "sync/atomic" + "unsafe" + + "golang.org/x/sys/unix" +) + +type fdDemuxEntry struct { + fd int + registrationID uint64 + stream *reactorStream +} + +type FDDemultiplexer struct { + ctx context.Context + cancel context.CancelFunc + epollFD int + mutex sync.Mutex + entries map[int]*fdDemuxEntry + registrationCounter uint64 + registrationToFD map[uint64]int + running bool + closed atomic.Bool + wg sync.WaitGroup + pipeFDs [2]int +} + +func NewFDDemultiplexer(ctx context.Context) (*FDDemultiplexer, error) { + epollFD, err := unix.EpollCreate1(unix.EPOLL_CLOEXEC) + if err != nil { + return nil, err + } + + var pipeFDs [2]int + err = unix.Pipe2(pipeFDs[:], unix.O_NONBLOCK|unix.O_CLOEXEC) + if err != nil { + unix.Close(epollFD) + return nil, err + } + + pipeEvent := &unix.EpollEvent{Events: unix.EPOLLIN} + *(*uint64)(unsafe.Pointer(&pipeEvent.Fd)) = 0 + err = unix.EpollCtl(epollFD, unix.EPOLL_CTL_ADD, pipeFDs[0], pipeEvent) + if err != nil { + unix.Close(pipeFDs[0]) + unix.Close(pipeFDs[1]) + unix.Close(epollFD) + return nil, err + } + + ctx, cancel := context.WithCancel(ctx) + demux := &FDDemultiplexer{ + ctx: ctx, + cancel: cancel, + epollFD: epollFD, + entries: make(map[int]*fdDemuxEntry), + registrationToFD: make(map[uint64]int), + pipeFDs: pipeFDs, + } + return demux, nil +} + +func (d *FDDemultiplexer) Add(stream *reactorStream, fd int) error { + d.mutex.Lock() + defer d.mutex.Unlock() + + if d.closed.Load() { + return unix.EINVAL + } + + d.registrationCounter++ + registrationID := d.registrationCounter + + event := 
&unix.EpollEvent{Events: unix.EPOLLIN | unix.EPOLLRDHUP} + *(*uint64)(unsafe.Pointer(&event.Fd)) = registrationID + + err := unix.EpollCtl(d.epollFD, unix.EPOLL_CTL_ADD, fd, event) + if err != nil { + return err + } + + entry := &fdDemuxEntry{ + fd: fd, + registrationID: registrationID, + stream: stream, + } + d.entries[fd] = entry + d.registrationToFD[registrationID] = fd + + if !d.running { + d.running = true + d.wg.Add(1) + go d.run() + } + + return nil +} + +func (d *FDDemultiplexer) Remove(fd int) { + d.mutex.Lock() + defer d.mutex.Unlock() + + entry, ok := d.entries[fd] + if !ok { + return + } + + unix.EpollCtl(d.epollFD, unix.EPOLL_CTL_DEL, fd, nil) + delete(d.registrationToFD, entry.registrationID) + delete(d.entries, fd) +} + +func (d *FDDemultiplexer) wakeup() { + unix.Write(d.pipeFDs[1], []byte{0}) +} + +func (d *FDDemultiplexer) Close() error { + d.mutex.Lock() + d.closed.Store(true) + d.mutex.Unlock() + + d.cancel() + d.wakeup() + d.wg.Wait() + + d.mutex.Lock() + defer d.mutex.Unlock() + + if d.epollFD != -1 { + unix.Close(d.epollFD) + d.epollFD = -1 + } + if d.pipeFDs[0] != -1 { + unix.Close(d.pipeFDs[0]) + unix.Close(d.pipeFDs[1]) + d.pipeFDs[0] = -1 + d.pipeFDs[1] = -1 + } + return nil +} + +func (d *FDDemultiplexer) run() { + defer d.wg.Done() + + events := make([]unix.EpollEvent, 64) + var buffer [1]byte + + for { + select { + case <-d.ctx.Done(): + d.mutex.Lock() + d.running = false + d.mutex.Unlock() + return + default: + } + + n, err := unix.EpollWait(d.epollFD, events, -1) + if err != nil { + if err == unix.EINTR { + continue + } + d.mutex.Lock() + d.running = false + d.mutex.Unlock() + return + } + + for i := 0; i < n; i++ { + event := events[i] + registrationID := *(*uint64)(unsafe.Pointer(&event.Fd)) + + if registrationID == 0 { + unix.Read(d.pipeFDs[0], buffer[:]) + continue + } + + if event.Events&(unix.EPOLLIN|unix.EPOLLRDHUP|unix.EPOLLHUP|unix.EPOLLERR) == 0 { + continue + } + + d.mutex.Lock() + fd, ok := d.registrationToFD[registrationID] + if !ok { + d.mutex.Unlock() + continue + } + + entry := d.entries[fd] + if entry == nil || entry.registrationID != registrationID { + d.mutex.Unlock() + continue + } + + unix.EpollCtl(d.epollFD, unix.EPOLL_CTL_DEL, fd, nil) + delete(d.registrationToFD, registrationID) + delete(d.entries, fd) + d.mutex.Unlock() + + go entry.stream.runActiveLoop(nil) + } + + d.mutex.Lock() + if len(d.entries) == 0 { + d.running = false + d.mutex.Unlock() + return + } + d.mutex.Unlock() + } +} diff --git a/common/bufio/fd_demux_stub.go b/common/bufio/fd_demux_stub.go new file mode 100644 index 00000000..d2248cf5 --- /dev/null +++ b/common/bufio/fd_demux_stub.go @@ -0,0 +1,25 @@ +//go:build !linux && !darwin && !windows + +package bufio + +import ( + "context" + + E "github.com/sagernet/sing/common/exceptions" +) + +type FDDemultiplexer struct{} + +func NewFDDemultiplexer(ctx context.Context) (*FDDemultiplexer, error) { + return nil, E.New("FDDemultiplexer not supported on this platform") +} + +func (d *FDDemultiplexer) Add(stream *reactorStream, fd int) error { + return E.New("FDDemultiplexer not supported on this platform") +} + +func (d *FDDemultiplexer) Remove(fd int) {} + +func (d *FDDemultiplexer) Close() error { + return nil +} diff --git a/common/bufio/fd_demux_windows.go b/common/bufio/fd_demux_windows.go new file mode 100644 index 00000000..06795ebe --- /dev/null +++ b/common/bufio/fd_demux_windows.go @@ -0,0 +1,227 @@ +//go:build windows + +package bufio + +import ( + "context" + "sync" + "sync/atomic" + "unsafe" + + 
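+	// wepoll wraps the Windows AFD poll interface; the demultiplexer uses it to
+	// deliver socket readiness notifications through the IOCP created below.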
"github.com/sagernet/sing/common/wepoll" + + "golang.org/x/sys/windows" +) + +type fdDemuxEntry struct { + ioStatusBlock windows.IO_STATUS_BLOCK + pollInfo wepoll.AFDPollInfo + stream *reactorStream + fd int + handle windows.Handle + baseHandle windows.Handle + registrationID uint64 + cancelled bool + pinner wepoll.Pinner +} + +type FDDemultiplexer struct { + ctx context.Context + cancel context.CancelFunc + iocp windows.Handle + afd *wepoll.AFD + mutex sync.Mutex + entries map[int]*fdDemuxEntry + registrationCounter uint64 + running bool + closed atomic.Bool + wg sync.WaitGroup +} + +func NewFDDemultiplexer(ctx context.Context) (*FDDemultiplexer, error) { + iocp, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) + if err != nil { + return nil, err + } + + afd, err := wepoll.NewAFD(iocp, "Go") + if err != nil { + windows.CloseHandle(iocp) + return nil, err + } + + ctx, cancel := context.WithCancel(ctx) + demux := &FDDemultiplexer{ + ctx: ctx, + cancel: cancel, + iocp: iocp, + afd: afd, + entries: make(map[int]*fdDemuxEntry), + } + return demux, nil +} + +func (d *FDDemultiplexer) Add(stream *reactorStream, fd int) error { + d.mutex.Lock() + defer d.mutex.Unlock() + + if d.closed.Load() { + return windows.ERROR_INVALID_HANDLE + } + + handle := windows.Handle(fd) + baseHandle, err := wepoll.GetBaseSocket(handle) + if err != nil { + return err + } + + d.registrationCounter++ + registrationID := d.registrationCounter + + entry := &fdDemuxEntry{ + stream: stream, + fd: fd, + handle: handle, + baseHandle: baseHandle, + registrationID: registrationID, + } + + entry.pinner.Pin(entry) + + events := uint32(wepoll.AFD_POLL_RECEIVE | wepoll.AFD_POLL_DISCONNECT | wepoll.AFD_POLL_ABORT | wepoll.AFD_POLL_LOCAL_CLOSE) + err = d.afd.Poll(baseHandle, events, &entry.ioStatusBlock, &entry.pollInfo) + if err != nil { + entry.pinner.Unpin() + return err + } + + d.entries[fd] = entry + + if !d.running { + d.running = true + d.wg.Add(1) + go d.run() + } + + return nil +} + +func (d *FDDemultiplexer) Remove(fd int) { + d.mutex.Lock() + defer d.mutex.Unlock() + + entry, ok := d.entries[fd] + if !ok { + return + } + + entry.cancelled = true + if d.afd != nil { + d.afd.Cancel(&entry.ioStatusBlock) + } +} + +func (d *FDDemultiplexer) wakeup() { + windows.PostQueuedCompletionStatus(d.iocp, 0, 0, nil) +} + +func (d *FDDemultiplexer) Close() error { + d.mutex.Lock() + d.closed.Store(true) + d.mutex.Unlock() + + d.cancel() + d.wakeup() + d.wg.Wait() + + d.mutex.Lock() + defer d.mutex.Unlock() + + for fd, entry := range d.entries { + entry.pinner.Unpin() + delete(d.entries, fd) + } + + if d.afd != nil { + d.afd.Close() + d.afd = nil + } + if d.iocp != 0 { + windows.CloseHandle(d.iocp) + d.iocp = 0 + } + return nil +} + +func (d *FDDemultiplexer) run() { + defer d.wg.Done() + + completions := make([]wepoll.OverlappedEntry, 64) + + for { + select { + case <-d.ctx.Done(): + d.mutex.Lock() + d.running = false + d.mutex.Unlock() + return + default: + } + + var numRemoved uint32 + err := wepoll.GetQueuedCompletionStatusEx(d.iocp, &completions[0], 64, &numRemoved, windows.INFINITE, false) + if err != nil { + d.mutex.Lock() + d.running = false + d.mutex.Unlock() + return + } + + for i := uint32(0); i < numRemoved; i++ { + event := completions[i] + + if event.Overlapped == nil { + continue + } + + entry := (*fdDemuxEntry)(unsafe.Pointer(event.Overlapped)) + + d.mutex.Lock() + + if d.entries[entry.fd] != entry { + d.mutex.Unlock() + continue + } + + entry.pinner.Unpin() + delete(d.entries, entry.fd) + + if 
entry.cancelled { + d.mutex.Unlock() + continue + } + + if uint32(entry.ioStatusBlock.Status) == wepoll.STATUS_CANCELLED { + d.mutex.Unlock() + continue + } + + events := entry.pollInfo.Handles[0].Events + if events&(wepoll.AFD_POLL_RECEIVE|wepoll.AFD_POLL_DISCONNECT|wepoll.AFD_POLL_ABORT|wepoll.AFD_POLL_LOCAL_CLOSE) == 0 { + d.mutex.Unlock() + continue + } + + d.mutex.Unlock() + go entry.stream.runActiveLoop(nil) + } + + d.mutex.Lock() + if len(d.entries) == 0 { + d.running = false + d.mutex.Unlock() + return + } + d.mutex.Unlock() + } +} diff --git a/common/bufio/fd_demux_windows_test.go b/common/bufio/fd_demux_windows_test.go new file mode 100644 index 00000000..030a3dfb --- /dev/null +++ b/common/bufio/fd_demux_windows_test.go @@ -0,0 +1,305 @@ +//go:build windows + +package bufio + +import ( + "context" + "net" + "sync" + "sync/atomic" + "syscall" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func getSocketFD(t *testing.T, conn net.PacketConn) int { + syscallConn, ok := conn.(syscall.Conn) + require.True(t, ok) + rawConn, err := syscallConn.SyscallConn() + require.NoError(t, err) + var fd int + err = rawConn.Control(func(f uintptr) { fd = int(f) }) + require.NoError(t, err) + return fd +} + +func TestFDDemultiplexer_Create(t *testing.T) { + t.Parallel() + + demux, err := NewFDDemultiplexer(context.Background()) + require.NoError(t, err) + + err = demux.Close() + require.NoError(t, err) +} + +func TestFDDemultiplexer_CreateMultiple(t *testing.T) { + t.Parallel() + + demux1, err := NewFDDemultiplexer(context.Background()) + require.NoError(t, err) + defer demux1.Close() + + demux2, err := NewFDDemultiplexer(context.Background()) + require.NoError(t, err) + defer demux2.Close() +} + +func TestFDDemultiplexer_AddRemove(t *testing.T) { + t.Parallel() + + demux, err := NewFDDemultiplexer(context.Background()) + require.NoError(t, err) + defer demux.Close() + + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + + fd := getSocketFD(t, conn) + + stream := &reactorStream{} + + err = demux.Add(stream, fd) + require.NoError(t, err) + + demux.Remove(fd) +} + +func TestFDDemultiplexer_RapidAddRemove(t *testing.T) { + t.Parallel() + + demux, err := NewFDDemultiplexer(context.Background()) + require.NoError(t, err) + defer demux.Close() + + const iterations = 50 + + for i := 0; i < iterations; i++ { + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + fd := getSocketFD(t, conn) + stream := &reactorStream{} + + err = demux.Add(stream, fd) + require.NoError(t, err) + + demux.Remove(fd) + conn.Close() + } +} + +func TestFDDemultiplexer_ConcurrentAccess(t *testing.T) { + t.Parallel() + + demux, err := NewFDDemultiplexer(context.Background()) + require.NoError(t, err) + defer demux.Close() + + const numGoroutines = 10 + const iterations = 20 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for g := 0; g < numGoroutines; g++ { + go func() { + defer wg.Done() + + for i := 0; i < iterations; i++ { + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + continue + } + + fd := getSocketFD(t, conn) + stream := &reactorStream{} + + err = demux.Add(stream, fd) + if err == nil { + demux.Remove(fd) + } + conn.Close() + } + }() + } + + wg.Wait() +} + +func TestFDDemultiplexer_ReceiveEvent(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + demux, err := NewFDDemultiplexer(ctx) + require.NoError(t, err) + defer demux.Close() + + 
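+	// Bind a local UDP socket; its descriptor is registered with the demultiplexer
+	// and a datagram is sent to it later to exercise the readiness path.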
conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + + fd := getSocketFD(t, conn) + + triggered := make(chan struct{}, 1) + stream := &reactorStream{ + state: atomic.Int32{}, + } + stream.connection = &reactorConnection{ + upload: stream, + download: stream, + done: make(chan struct{}), + } + + originalRunActiveLoop := stream.runActiveLoop + _ = originalRunActiveLoop + + err = demux.Add(stream, fd) + require.NoError(t, err) + + sender, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer sender.Close() + + _, err = sender.WriteTo([]byte("test data"), conn.LocalAddr()) + require.NoError(t, err) + + time.Sleep(200 * time.Millisecond) + + select { + case <-triggered: + default: + } + + demux.Remove(fd) +} + +func TestFDDemultiplexer_CloseWhilePolling(t *testing.T) { + t.Parallel() + + demux, err := NewFDDemultiplexer(context.Background()) + require.NoError(t, err) + + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + + fd := getSocketFD(t, conn) + stream := &reactorStream{} + + err = demux.Add(stream, fd) + require.NoError(t, err) + + time.Sleep(50 * time.Millisecond) + + done := make(chan struct{}) + go func() { + demux.Close() + close(done) + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Close blocked - possible deadlock") + } +} + +func TestFDDemultiplexer_RemoveNonExistent(t *testing.T) { + t.Parallel() + + demux, err := NewFDDemultiplexer(context.Background()) + require.NoError(t, err) + defer demux.Close() + + demux.Remove(99999) +} + +func TestFDDemultiplexer_AddAfterClose(t *testing.T) { + t.Parallel() + + demux, err := NewFDDemultiplexer(context.Background()) + require.NoError(t, err) + + err = demux.Close() + require.NoError(t, err) + + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + + fd := getSocketFD(t, conn) + stream := &reactorStream{} + + err = demux.Add(stream, fd) + require.Error(t, err) +} + +func TestFDDemultiplexer_MultipleSocketsSimultaneous(t *testing.T) { + t.Parallel() + + demux, err := NewFDDemultiplexer(context.Background()) + require.NoError(t, err) + defer demux.Close() + + const numSockets = 5 + conns := make([]net.PacketConn, numSockets) + fds := make([]int, numSockets) + + for i := 0; i < numSockets; i++ { + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + conns[i] = conn + + fd := getSocketFD(t, conn) + fds[i] = fd + + stream := &reactorStream{} + err = demux.Add(stream, fd) + require.NoError(t, err) + } + + for i := 0; i < numSockets; i++ { + demux.Remove(fds[i]) + } +} + +func TestFDDemultiplexer_ContextCancellation(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + demux, err := NewFDDemultiplexer(ctx) + require.NoError(t, err) + + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + + fd := getSocketFD(t, conn) + stream := &reactorStream{} + + err = demux.Add(stream, fd) + require.NoError(t, err) + + cancel() + + time.Sleep(100 * time.Millisecond) + + done := make(chan struct{}) + go func() { + demux.Close() + close(done) + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Close blocked after context cancellation") + } +} diff --git a/common/bufio/packet_reactor.go b/common/bufio/packet_reactor.go new file mode 100644 index 00000000..4ded4ed6 --- /dev/null +++ 
b/common/bufio/packet_reactor.go @@ -0,0 +1,390 @@ +package bufio + +import ( + "context" + "sync" + "sync/atomic" + "time" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +const ( + batchReadTimeout = 250 * time.Millisecond +) + +const ( + stateIdle int32 = 0 + stateActive int32 = 1 + stateClosed int32 = 2 +) + +type PacketReactor struct { + ctx context.Context + cancel context.CancelFunc + channelDemux *ChannelDemultiplexer + fdDemux *FDDemultiplexer + fdDemuxOnce sync.Once + fdDemuxErr error +} + +func NewPacketReactor(ctx context.Context) *PacketReactor { + ctx, cancel := context.WithCancel(ctx) + return &PacketReactor{ + ctx: ctx, + cancel: cancel, + channelDemux: NewChannelDemultiplexer(ctx), + } +} + +func (r *PacketReactor) getFDDemultiplexer() (*FDDemultiplexer, error) { + r.fdDemuxOnce.Do(func() { + r.fdDemux, r.fdDemuxErr = NewFDDemultiplexer(r.ctx) + }) + return r.fdDemux, r.fdDemuxErr +} + +func (r *PacketReactor) Close() error { + r.cancel() + var errs []error + if r.channelDemux != nil { + errs = append(errs, r.channelDemux.Close()) + } + if r.fdDemux != nil { + errs = append(errs, r.fdDemux.Close()) + } + return E.Errors(errs...) +} + +type reactorConnection struct { + ctx context.Context + cancel context.CancelFunc + reactor *PacketReactor + onClose N.CloseHandlerFunc + upload *reactorStream + download *reactorStream + + closeOnce sync.Once + done chan struct{} + err error +} + +type reactorStream struct { + connection *reactorConnection + + source N.PacketReader + destination N.PacketWriter + originSource N.PacketReader + + notifier N.ReadNotifier + options N.ReadWaitOptions + readWaiter N.PacketReadWaiter + readCounters []N.CountFunc + writeCounters []N.CountFunc + + state atomic.Int32 +} + +func (r *PacketReactor) Copy(ctx context.Context, source N.PacketConn, destination N.PacketConn, onClose N.CloseHandlerFunc) { + ctx, cancel := context.WithCancel(ctx) + conn := &reactorConnection{ + ctx: ctx, + cancel: cancel, + reactor: r, + onClose: onClose, + done: make(chan struct{}), + } + + conn.upload = r.prepareStream(conn, source, destination) + select { + case <-conn.done: + return + default: + } + + conn.download = r.prepareStream(conn, destination, source) + select { + case <-conn.done: + return + default: + } + + r.registerStream(conn.upload) + r.registerStream(conn.download) +} + +func (r *PacketReactor) prepareStream(conn *reactorConnection, source N.PacketReader, destination N.PacketWriter) *reactorStream { + stream := &reactorStream{ + connection: conn, + source: source, + destination: destination, + originSource: source, + } + + for { + source, stream.readCounters = N.UnwrapCountPacketReader(source, stream.readCounters) + destination, stream.writeCounters = N.UnwrapCountPacketWriter(destination, stream.writeCounters) + if cachedReader, isCached := source.(N.CachedPacketReader); isCached { + packet := cachedReader.ReadCachedPacket() + if packet != nil { + dataLen := packet.Buffer.Len() + err := destination.WritePacket(packet.Buffer, packet.Destination) + N.PutPacketBuffer(packet) + if err != nil { + conn.closeWithError(err) + return stream + } + for _, counter := range stream.readCounters { + counter(int64(dataLen)) + } + for _, counter := range stream.writeCounters { + counter(int64(dataLen)) + } + continue + } + } + break + } + stream.source = source + stream.destination = destination + + 
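+	// With counters and cached readers unwrapped, set up read-wait options, an
+	// optional packet read waiter, and a readiness notifier when the source supports them.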
stream.options = N.NewReadWaitOptions(source, destination) + + stream.readWaiter, _ = CreatePacketReadWaiter(source) + if stream.readWaiter != nil { + stream.readWaiter.InitializeReadWaiter(stream.options) + } + + if notifierSource, ok := source.(N.ReadNotifierSource); ok { + stream.notifier = notifierSource.CreateReadNotifier() + } + + return stream +} + +func (r *PacketReactor) registerStream(stream *reactorStream) { + if stream.notifier == nil { + go stream.runLegacyCopy() + return + } + + switch notifier := stream.notifier.(type) { + case *N.ChannelNotifier: + r.channelDemux.Add(stream, notifier.Channel) + case *N.FileDescriptorNotifier: + fdDemux, err := r.getFDDemultiplexer() + if err != nil { + go stream.runLegacyCopy() + return + } + err = fdDemux.Add(stream, notifier.FD) + if err != nil { + go stream.runLegacyCopy() + } + default: + go stream.runLegacyCopy() + } +} + +func (s *reactorStream) runActiveLoop(firstPacket *N.PacketBuffer) { + if s.source == nil { + if firstPacket != nil { + firstPacket.Buffer.Release() + N.PutPacketBuffer(firstPacket) + } + return + } + if !s.state.CompareAndSwap(stateIdle, stateActive) { + if firstPacket != nil { + firstPacket.Buffer.Release() + N.PutPacketBuffer(firstPacket) + } + return + } + + notFirstTime := false + + if firstPacket != nil { + err := s.writePacketWithCounters(firstPacket) + if err != nil { + s.closeWithError(err) + return + } + notFirstTime = true + } + + for { + if s.state.Load() == stateClosed { + return + } + + if setter, ok := s.source.(interface{ SetReadDeadline(time.Time) error }); ok { + setter.SetReadDeadline(time.Now().Add(batchReadTimeout)) + } + + var ( + buffer *N.PacketBuffer + destination M.Socksaddr + err error + ) + + if s.readWaiter != nil { + var readBuffer *buf.Buffer + readBuffer, destination, err = s.readWaiter.WaitReadPacket() + if readBuffer != nil { + buffer = N.NewPacketBuffer() + buffer.Buffer = readBuffer + buffer.Destination = destination + } + } else { + readBuffer := s.options.NewPacketBuffer() + destination, err = s.source.ReadPacket(readBuffer) + if err != nil { + readBuffer.Release() + } else { + buffer = N.NewPacketBuffer() + buffer.Buffer = readBuffer + buffer.Destination = destination + } + } + + if err != nil { + if E.IsTimeout(err) { + if setter, ok := s.source.(interface{ SetReadDeadline(time.Time) error }); ok { + setter.SetReadDeadline(time.Time{}) + } + if s.state.CompareAndSwap(stateActive, stateIdle) { + s.returnToPool() + } + return + } + if !notFirstTime { + err = N.ReportHandshakeFailure(s.originSource, err) + } + s.closeWithError(err) + return + } + + err = s.writePacketWithCounters(buffer) + if err != nil { + if !notFirstTime { + err = N.ReportHandshakeFailure(s.originSource, err) + } + s.closeWithError(err) + return + } + notFirstTime = true + } +} + +func (s *reactorStream) writePacketWithCounters(packet *N.PacketBuffer) error { + buffer := packet.Buffer + destination := packet.Destination + dataLen := buffer.Len() + + s.options.PostReturn(buffer) + err := s.destination.WritePacket(buffer, destination) + N.PutPacketBuffer(packet) + if err != nil { + buffer.Leak() + return err + } + + for _, counter := range s.readCounters { + counter(int64(dataLen)) + } + for _, counter := range s.writeCounters { + counter(int64(dataLen)) + } + return nil +} + +func (s *reactorStream) returnToPool() { + if s.state.Load() != stateIdle { + return + } + + switch notifier := s.notifier.(type) { + case *N.ChannelNotifier: + s.connection.reactor.channelDemux.Add(s, notifier.Channel) + if s.state.Load() 
!= stateIdle { + s.connection.reactor.channelDemux.Remove(notifier.Channel) + } + case *N.FileDescriptorNotifier: + if s.connection.reactor.fdDemux != nil { + err := s.connection.reactor.fdDemux.Add(s, notifier.FD) + if err != nil { + s.closeWithError(err) + return + } + if s.state.Load() != stateIdle { + s.connection.reactor.fdDemux.Remove(notifier.FD) + } + } + } +} + +func (s *reactorStream) runLegacyCopy() { + _, err := CopyPacket(s.destination, s.source) + s.closeWithError(err) +} + +func (s *reactorStream) closeWithError(err error) { + s.connection.closeWithError(err) +} + +func (c *reactorConnection) closeWithError(err error) { + c.closeOnce.Do(func() { + c.err = err + c.cancel() + + if c.upload != nil { + c.upload.state.Store(stateClosed) + } + if c.download != nil { + c.download.state.Store(stateClosed) + } + + c.removeFromDemultiplexers() + + if c.upload != nil { + common.Close(c.upload.originSource) + } + if c.download != nil { + common.Close(c.download.originSource) + } + + if c.onClose != nil { + c.onClose(c.err) + } + + close(c.done) + }) +} + +func (c *reactorConnection) removeFromDemultiplexers() { + if c.upload != nil && c.upload.notifier != nil { + switch notifier := c.upload.notifier.(type) { + case *N.ChannelNotifier: + c.reactor.channelDemux.Remove(notifier.Channel) + case *N.FileDescriptorNotifier: + if c.reactor.fdDemux != nil { + c.reactor.fdDemux.Remove(notifier.FD) + } + } + } + if c.download != nil && c.download.notifier != nil { + switch notifier := c.download.notifier.(type) { + case *N.ChannelNotifier: + c.reactor.channelDemux.Remove(notifier.Channel) + case *N.FileDescriptorNotifier: + if c.reactor.fdDemux != nil { + c.reactor.fdDemux.Remove(notifier.FD) + } + } + } +} diff --git a/common/bufio/packet_reactor_test.go b/common/bufio/packet_reactor_test.go new file mode 100644 index 00000000..ae155130 --- /dev/null +++ b/common/bufio/packet_reactor_test.go @@ -0,0 +1,1485 @@ +//go:build darwin || linux || windows + +package bufio + +import ( + "context" + "crypto/md5" + "crypto/rand" + "errors" + "io" + "net" + "os" + "sync" + "sync/atomic" + "syscall" + "testing" + "time" + + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testPacketPipe struct { + inChan chan *N.PacketBuffer + outChan chan *N.PacketBuffer + localAddr M.Socksaddr + closed atomic.Bool + closeOnce sync.Once + done chan struct{} +} + +func newTestPacketPipe(localAddr M.Socksaddr) *testPacketPipe { + return &testPacketPipe{ + inChan: make(chan *N.PacketBuffer, 256), + outChan: make(chan *N.PacketBuffer, 256), + localAddr: localAddr, + done: make(chan struct{}), + } +} + +func (p *testPacketPipe) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + select { + case packet, ok := <-p.inChan: + if !ok { + return M.Socksaddr{}, io.EOF + } + _, err = buffer.ReadOnceFrom(packet.Buffer) + destination = packet.Destination + packet.Buffer.Release() + N.PutPacketBuffer(packet) + return destination, err + case <-p.done: + return M.Socksaddr{}, net.ErrClosed + } +} + +func (p *testPacketPipe) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + if p.closed.Load() { + buffer.Release() + return net.ErrClosed + } + packet := N.NewPacketBuffer() + newBuf := buf.NewSize(buffer.Len()) + newBuf.Write(buffer.Bytes()) + packet.Buffer = newBuf + packet.Destination = destination + buffer.Release() + select { + case 
p.outChan <- packet: + return nil + case <-p.done: + packet.Buffer.Release() + N.PutPacketBuffer(packet) + return net.ErrClosed + } +} + +func (p *testPacketPipe) Close() error { + p.closeOnce.Do(func() { + p.closed.Store(true) + close(p.done) + }) + return nil +} + +func (p *testPacketPipe) LocalAddr() net.Addr { + return p.localAddr.UDPAddr() +} + +func (p *testPacketPipe) SetDeadline(t time.Time) error { + return nil +} + +func (p *testPacketPipe) SetReadDeadline(t time.Time) error { + return nil +} + +func (p *testPacketPipe) SetWriteDeadline(t time.Time) error { + return nil +} + +func (p *testPacketPipe) CreateReadNotifier() N.ReadNotifier { + return &N.ChannelNotifier{Channel: p.inChan} +} + +func (p *testPacketPipe) send(data []byte, destination M.Socksaddr) { + packet := N.NewPacketBuffer() + newBuf := buf.NewSize(len(data)) + newBuf.Write(data) + packet.Buffer = newBuf + packet.Destination = destination + p.inChan <- packet +} + +func (p *testPacketPipe) receive() (*N.PacketBuffer, bool) { + select { + case packet, ok := <-p.outChan: + return packet, ok + case <-p.done: + return nil, false + } +} + +type fdPacketConn struct { + N.NetPacketConn + fd int + targetAddr M.Socksaddr +} + +func newFDPacketConn(t *testing.T, conn net.PacketConn, targetAddr M.Socksaddr) *fdPacketConn { + syscallConn, ok := conn.(syscall.Conn) + require.True(t, ok, "connection must implement syscall.Conn") + rawConn, err := syscallConn.SyscallConn() + require.NoError(t, err) + var fd int + err = rawConn.Control(func(f uintptr) { fd = int(f) }) + require.NoError(t, err) + return &fdPacketConn{ + NetPacketConn: NewPacketConn(conn), + fd: fd, + targetAddr: targetAddr, + } +} + +func (c *fdPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + _, err = c.NetPacketConn.ReadPacket(buffer) + if err != nil { + return M.Socksaddr{}, err + } + return c.targetAddr, nil +} + +func (c *fdPacketConn) CreateReadNotifier() N.ReadNotifier { + return &N.FileDescriptorNotifier{FD: c.fd} +} + +type channelPacketConn struct { + N.NetPacketConn + packetChan chan *N.PacketBuffer + done chan struct{} + closeOnce sync.Once + targetAddr M.Socksaddr + deadlineLock sync.Mutex + deadline time.Time + deadlineChan chan struct{} +} + +func newChannelPacketConn(conn net.PacketConn, targetAddr M.Socksaddr) *channelPacketConn { + c := &channelPacketConn{ + NetPacketConn: NewPacketConn(conn), + packetChan: make(chan *N.PacketBuffer, 256), + done: make(chan struct{}), + targetAddr: targetAddr, + deadlineChan: make(chan struct{}), + } + go c.readLoop() + return c +} + +func (c *channelPacketConn) readLoop() { + for { + select { + case <-c.done: + return + default: + } + buffer := buf.NewPacket() + _, err := c.NetPacketConn.ReadPacket(buffer) + if err != nil { + buffer.Release() + close(c.packetChan) + return + } + packet := N.NewPacketBuffer() + packet.Buffer = buffer + packet.Destination = c.targetAddr + select { + case c.packetChan <- packet: + case <-c.done: + buffer.Release() + N.PutPacketBuffer(packet) + return + } + } +} + +func (c *channelPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + c.deadlineLock.Lock() + deadline := c.deadline + deadlineChan := c.deadlineChan + c.deadlineLock.Unlock() + + var timer <-chan time.Time + if !deadline.IsZero() { + d := time.Until(deadline) + if d <= 0 { + return M.Socksaddr{}, os.ErrDeadlineExceeded + } + t := time.NewTimer(d) + defer t.Stop() + timer = t.C + } + + select { + case packet, ok := <-c.packetChan: + if !ok { + return 
M.Socksaddr{}, net.ErrClosed + } + _, err = buffer.ReadOnceFrom(packet.Buffer) + destination = packet.Destination + packet.Buffer.Release() + N.PutPacketBuffer(packet) + return + case <-c.done: + return M.Socksaddr{}, net.ErrClosed + case <-deadlineChan: + return M.Socksaddr{}, os.ErrDeadlineExceeded + case <-timer: + return M.Socksaddr{}, os.ErrDeadlineExceeded + } +} + +func (c *channelPacketConn) SetReadDeadline(t time.Time) error { + c.deadlineLock.Lock() + c.deadline = t + if c.deadlineChan != nil { + close(c.deadlineChan) + } + c.deadlineChan = make(chan struct{}) + c.deadlineLock.Unlock() + return nil +} + +func (c *channelPacketConn) CreateReadNotifier() N.ReadNotifier { + return &N.ChannelNotifier{Channel: c.packetChan} +} + +func (c *channelPacketConn) Close() error { + c.closeOnce.Do(func() { + close(c.done) + }) + return c.NetPacketConn.Close() +} + +type batchHashPair struct { + sendHash map[int][]byte + recvHash map[int][]byte +} + +func TestBatchCopy_Pipe_DataIntegrity(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 10001) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 10002) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + const times = 50 + const chunkSize = 9000 + + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + + recvDone := make(chan struct{}) + go func() { + defer close(recvDone) + for i := 0; i < times; i++ { + packet, ok := pipeB.receive() + if !ok { + t.Logf("recv channel closed at %d", i) + return + } + hash := md5.Sum(packet.Buffer.Bytes()) + recvHash[int(packet.Buffer.Byte(0))] = hash[:] + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + }() + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + pipeA.send(data, addr2) + time.Sleep(5 * time.Millisecond) + } + + select { + case <-recvDone: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for receive") + } + + assert.Equal(t, sendHash, recvHash, "data mismatch") +} + +func TestBatchCopy_Pipe_Bidirectional(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 10001) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 10002) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + const times = 50 + const chunkSize = 9000 + + pingCh := make(chan batchHashPair, 1) + pongCh := make(chan batchHashPair, 1) + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + pipeA.send(data, addr2) + time.Sleep(5 * time.Millisecond) + } + + for i := 0; i < times; i++ { + packet, ok := pipeA.receive() + if !ok { + return + } + hash := md5.Sum(packet.Buffer.Bytes()) + recvHash[int(packet.Buffer.Byte(0))] = hash[:] + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + + pingCh <- batchHashPair{sendHash: sendHash, recvHash: 
recvHash} + }() + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + + for i := 0; i < times; i++ { + packet, ok := pipeB.receive() + if !ok { + return + } + hash := md5.Sum(packet.Buffer.Bytes()) + recvHash[int(packet.Buffer.Byte(0))] = hash[:] + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + pipeB.send(data, addr1) + time.Sleep(5 * time.Millisecond) + } + + pongCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + var aPair, bPair batchHashPair + select { + case aPair = <-pingCh: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for A") + } + select { + case bPair = <-pongCh: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for B") + } + + assert.Equal(t, aPair.sendHash, bPair.recvHash, "A->B mismatch") + assert.Equal(t, bPair.sendHash, aPair.recvHash, "B->A mismatch") +} + +func TestBatchCopy_FDPoller_DataIntegrity(t *testing.T) { + t.Parallel() + + clientConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer clientConn.Close() + + proxyAConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + proxyBConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + serverConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer serverConn.Close() + + serverAddr := M.SocksaddrFromNet(serverConn.LocalAddr()) + clientAddr := M.SocksaddrFromNet(clientConn.LocalAddr()) + proxyAAddr := M.SocksaddrFromNet(proxyAConn.LocalAddr()) + proxyBAddr := M.SocksaddrFromNet(proxyBConn.LocalAddr()) + + proxyA := newFDPacketConn(t, proxyAConn, serverAddr) + proxyB := newFDPacketConn(t, proxyBConn, clientAddr) + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), proxyA, proxyB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + const times = 50 + const chunkSize = 9000 + + pingCh := make(chan batchHashPair, 1) + pongCh := make(chan batchHashPair, 1) + errCh := make(chan error, 2) + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + recvBuf := make([]byte, 65536) + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + _, err := clientConn.WriteTo(data, proxyAAddr.UDPAddr()) + if err != nil { + errCh <- err + return + } + time.Sleep(5 * time.Millisecond) + } + + for i := 0; i < times; i++ { + n, _, err := clientConn.ReadFrom(recvBuf) + if err != nil { + errCh <- err + return + } + hash := md5.Sum(recvBuf[:n]) + recvHash[int(recvBuf[0])] = hash[:] + } + + pingCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + recvBuf := make([]byte, 65536) + + for i := 0; i < times; i++ { + n, _, err := serverConn.ReadFrom(recvBuf) + if err != nil { + errCh <- err + return + } + hash := md5.Sum(recvBuf[:n]) + recvHash[int(recvBuf[0])] = hash[:] + } + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + _, err := serverConn.WriteTo(data, proxyBAddr.UDPAddr()) + if err != nil { + errCh <- err + return + } + time.Sleep(5 * time.Millisecond) + } + + pongCh <- 
batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + var clientPair, serverPair batchHashPair + for i := 0; i < 2; i++ { + select { + case clientPair = <-pingCh: + case serverPair = <-pongCh: + case err := <-errCh: + t.Fatal(err) + case <-time.After(15 * time.Second): + t.Fatal("timeout") + } + } + + assert.Equal(t, clientPair.sendHash, serverPair.recvHash, "client->server mismatch") + assert.Equal(t, serverPair.sendHash, clientPair.recvHash, "server->client mismatch") +} + +func TestBatchCopy_ChannelPoller_DataIntegrity(t *testing.T) { + t.Parallel() + + clientConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer clientConn.Close() + + proxyAConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + proxyBConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + serverConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer serverConn.Close() + + serverAddr := M.SocksaddrFromNet(serverConn.LocalAddr()) + clientAddr := M.SocksaddrFromNet(clientConn.LocalAddr()) + proxyAAddr := M.SocksaddrFromNet(proxyAConn.LocalAddr()) + proxyBAddr := M.SocksaddrFromNet(proxyBConn.LocalAddr()) + + proxyA := newChannelPacketConn(proxyAConn, serverAddr) + proxyB := newChannelPacketConn(proxyBConn, clientAddr) + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), proxyA, proxyB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + const times = 50 + const chunkSize = 9000 + + pingCh := make(chan batchHashPair, 1) + pongCh := make(chan batchHashPair, 1) + errCh := make(chan error, 2) + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + recvBuf := make([]byte, 65536) + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + _, err := clientConn.WriteTo(data, proxyAAddr.UDPAddr()) + if err != nil { + errCh <- err + return + } + time.Sleep(5 * time.Millisecond) + } + + for i := 0; i < times; i++ { + n, _, err := clientConn.ReadFrom(recvBuf) + if err != nil { + errCh <- err + return + } + hash := md5.Sum(recvBuf[:n]) + recvHash[int(recvBuf[0])] = hash[:] + } + + pingCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + recvBuf := make([]byte, 65536) + + for i := 0; i < times; i++ { + n, _, err := serverConn.ReadFrom(recvBuf) + if err != nil { + errCh <- err + return + } + hash := md5.Sum(recvBuf[:n]) + recvHash[int(recvBuf[0])] = hash[:] + } + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + _, err := serverConn.WriteTo(data, proxyBAddr.UDPAddr()) + if err != nil { + errCh <- err + return + } + time.Sleep(5 * time.Millisecond) + } + + pongCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + var clientPair, serverPair batchHashPair + for i := 0; i < 2; i++ { + select { + case clientPair = <-pingCh: + case serverPair = <-pongCh: + case err := <-errCh: + t.Fatal(err) + case <-time.After(15 * time.Second): + t.Fatal("timeout") + } + } + + assert.Equal(t, clientPair.sendHash, serverPair.recvHash, "client->server mismatch") + assert.Equal(t, serverPair.sendHash, clientPair.recvHash, "server->client mismatch") +} + +func TestBatchCopy_MixedMode_DataIntegrity(t *testing.T) { + 
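+	// Mixed mode: proxyA registers via the FD demultiplexer while proxyB registers
+	// via the channel demultiplexer, so one relay exercises both readiness paths.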
t.Parallel() + + clientConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer clientConn.Close() + + proxyAConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + proxyBConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + serverConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer serverConn.Close() + + serverAddr := M.SocksaddrFromNet(serverConn.LocalAddr()) + clientAddr := M.SocksaddrFromNet(clientConn.LocalAddr()) + proxyAAddr := M.SocksaddrFromNet(proxyAConn.LocalAddr()) + proxyBAddr := M.SocksaddrFromNet(proxyBConn.LocalAddr()) + + proxyA := newFDPacketConn(t, proxyAConn, serverAddr) + proxyB := newChannelPacketConn(proxyBConn, clientAddr) + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), proxyA, proxyB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + const times = 50 + const chunkSize = 9000 + + pingCh := make(chan batchHashPair, 1) + pongCh := make(chan batchHashPair, 1) + errCh := make(chan error, 2) + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + recvBuf := make([]byte, 65536) + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + _, err := clientConn.WriteTo(data, proxyAAddr.UDPAddr()) + if err != nil { + errCh <- err + return + } + time.Sleep(5 * time.Millisecond) + } + + for i := 0; i < times; i++ { + n, _, err := clientConn.ReadFrom(recvBuf) + if err != nil { + errCh <- err + return + } + hash := md5.Sum(recvBuf[:n]) + recvHash[int(recvBuf[0])] = hash[:] + } + + pingCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + recvBuf := make([]byte, 65536) + + for i := 0; i < times; i++ { + n, _, err := serverConn.ReadFrom(recvBuf) + if err != nil { + errCh <- err + return + } + hash := md5.Sum(recvBuf[:n]) + recvHash[int(recvBuf[0])] = hash[:] + } + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + _, err := serverConn.WriteTo(data, proxyBAddr.UDPAddr()) + if err != nil { + errCh <- err + return + } + time.Sleep(5 * time.Millisecond) + } + + pongCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + var clientPair, serverPair batchHashPair + for i := 0; i < 2; i++ { + select { + case clientPair = <-pingCh: + case serverPair = <-pongCh: + case err := <-errCh: + t.Fatal(err) + case <-time.After(15 * time.Second): + t.Fatal("timeout") + } + } + + assert.Equal(t, clientPair.sendHash, serverPair.recvHash, "client->server mismatch") + assert.Equal(t, serverPair.sendHash, clientPair.recvHash, "server->client mismatch") +} + +func TestBatchCopy_MultipleConnections_DataIntegrity(t *testing.T) { + t.Parallel() + + const numConnections = 5 + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + var wg sync.WaitGroup + errCh := make(chan error, numConnections) + + for i := 0; i < numConnections; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", uint16(20000+idx*2)) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", uint16(20001+idx*2)) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + go 
func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + const times = 20 + const chunkSize = 1000 + + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + + recvDone := make(chan struct{}) + go func() { + defer close(recvDone) + for i := 0; i < times; i++ { + packet, ok := pipeB.receive() + if !ok { + return + } + hash := md5.Sum(packet.Buffer.Bytes()) + recvHash[int(packet.Buffer.Byte(0))] = hash[:] + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + }() + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + pipeA.send(data, addr2) + time.Sleep(5 * time.Millisecond) + } + + select { + case <-recvDone: + case <-time.After(10 * time.Second): + errCh <- errors.New("timeout") + return + } + + for k, v := range sendHash { + if rv, ok := recvHash[k]; !ok || string(v) != string(rv) { + errCh <- errors.New("data mismatch") + return + } + } + }(i) + } + + wg.Wait() + close(errCh) + + for err := range errCh { + require.NoError(t, err) + } +} + +func TestBatchCopy_TimeoutAndResume_DataIntegrity(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 30001) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 30002) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + sendAndVerify := func(batchID int, count int) { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + + recvDone := make(chan struct{}) + go func() { + defer close(recvDone) + for i := 0; i < count; i++ { + packet, ok := pipeB.receive() + if !ok { + return + } + hash := md5.Sum(packet.Buffer.Bytes()) + recvHash[int(packet.Buffer.Byte(1))] = hash[:] + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + }() + + for i := 0; i < count; i++ { + data := make([]byte, 1000) + rand.Read(data[2:]) + data[0] = byte(batchID) + data[1] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + pipeA.send(data, addr2) + time.Sleep(5 * time.Millisecond) + } + + select { + case <-recvDone: + case <-time.After(5 * time.Second): + t.Fatalf("batch %d timeout", batchID) + } + + assert.Equal(t, sendHash, recvHash, "batch %d mismatch", batchID) + } + + sendAndVerify(1, 10) + + time.Sleep(350 * time.Millisecond) + + sendAndVerify(2, 10) +} + +func TestBatchCopy_CloseWhileTransferring(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 40001) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 40002) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + + copier := NewPacketReactor(context.Background()) + + copyDone := make(chan struct{}) + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + close(copyDone) + }() + + time.Sleep(50 * time.Millisecond) + + stopSend := make(chan struct{}) + go func() { + for { + select { + case <-stopSend: + return + default: + data := make([]byte, 1000) + rand.Read(data) + pipeA.send(data, addr2) + time.Sleep(1 * time.Millisecond) + } + } + }() + + time.Sleep(100 * time.Millisecond) + + pipeA.Close() + pipeB.Close() + copier.Close() + close(stopSend) + + select { + case <-copyDone: + case <-time.After(5 * time.Second): + t.Fatal("copier did not close - possible deadlock") + } +} + 
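+// TestBatchCopy_HighThroughput relays 500 packets of 8000 bytes through the
+// reactor and checks that every packet arrives with a matching hash.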
+func TestBatchCopy_HighThroughput(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 50001) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 50002) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + const times = 500 + const chunkSize = 8000 + + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + var mu sync.Mutex + + recvDone := make(chan struct{}) + go func() { + defer close(recvDone) + for i := 0; i < times; i++ { + packet, ok := pipeB.receive() + if !ok { + t.Logf("recv channel closed at %d", i) + return + } + hash := md5.Sum(packet.Buffer.Bytes()) + idx := int(packet.Buffer.Byte(0))<<8 | int(packet.Buffer.Byte(1)) + mu.Lock() + recvHash[idx] = hash[:] + mu.Unlock() + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + }() + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[2:]) + data[0] = byte(i >> 8) + data[1] = byte(i & 0xff) + hash := md5.Sum(data) + sendHash[i] = hash[:] + pipeA.send(data, addr2) + time.Sleep(1 * time.Millisecond) + } + + select { + case <-recvDone: + case <-time.After(30 * time.Second): + t.Fatal("high throughput test timeout") + } + + mu.Lock() + defer mu.Unlock() + assert.Equal(t, len(sendHash), len(recvHash), "packet count mismatch") + for k, v := range sendHash { + assert.Equal(t, v, recvHash[k], "packet %d mismatch", k) + } +} + +func TestBatchCopy_LegacyFallback_DataIntegrity(t *testing.T) { + t.Parallel() + + clientConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer clientConn.Close() + + proxyAConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + proxyBConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + serverConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer serverConn.Close() + + serverAddr := M.SocksaddrFromNet(serverConn.LocalAddr()) + clientAddr := M.SocksaddrFromNet(clientConn.LocalAddr()) + proxyAAddr := M.SocksaddrFromNet(proxyAConn.LocalAddr()) + proxyBAddr := M.SocksaddrFromNet(proxyBConn.LocalAddr()) + + proxyA := &legacyPacketConn{NetPacketConn: NewPacketConn(proxyAConn), targetAddr: serverAddr} + proxyB := &legacyPacketConn{NetPacketConn: NewPacketConn(proxyBConn), targetAddr: clientAddr} + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), proxyA, proxyB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + const times = 50 + const chunkSize = 9000 + + pingCh := make(chan batchHashPair, 1) + pongCh := make(chan batchHashPair, 1) + errCh := make(chan error, 2) + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + recvBuf := make([]byte, 65536) + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + _, err := clientConn.WriteTo(data, proxyAAddr.UDPAddr()) + if err != nil { + errCh <- err + return + } + time.Sleep(5 * time.Millisecond) + } + + clientConn.SetReadDeadline(time.Now().Add(10 * time.Second)) + for i := 0; i < times; i++ { + n, _, err := clientConn.ReadFrom(recvBuf) + if err != nil { + if os.IsTimeout(err) { + t.Logf("client read timeout after %d packets", 
i) + } + errCh <- err + return + } + hash := md5.Sum(recvBuf[:n]) + recvHash[int(recvBuf[0])] = hash[:] + } + + pingCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + recvBuf := make([]byte, 65536) + + serverConn.SetReadDeadline(time.Now().Add(10 * time.Second)) + for i := 0; i < times; i++ { + n, _, err := serverConn.ReadFrom(recvBuf) + if err != nil { + if os.IsTimeout(err) { + t.Logf("server read timeout after %d packets", i) + } + errCh <- err + return + } + hash := md5.Sum(recvBuf[:n]) + recvHash[int(recvBuf[0])] = hash[:] + } + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + _, err := serverConn.WriteTo(data, proxyBAddr.UDPAddr()) + if err != nil { + errCh <- err + return + } + time.Sleep(5 * time.Millisecond) + } + + pongCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + var clientPair, serverPair batchHashPair + for i := 0; i < 2; i++ { + select { + case clientPair = <-pingCh: + case serverPair = <-pongCh: + case err := <-errCh: + t.Fatal(err) + case <-time.After(20 * time.Second): + t.Fatal("timeout") + } + } + + assert.Equal(t, clientPair.sendHash, serverPair.recvHash, "client->server mismatch") + assert.Equal(t, serverPair.sendHash, clientPair.recvHash, "server->client mismatch") +} + +func TestBatchCopy_ReactorClose(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 60001) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 60002) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + copier := NewPacketReactor(context.Background()) + + copyDone := make(chan struct{}) + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + close(copyDone) + }() + + time.Sleep(50 * time.Millisecond) + + go func() { + for { + select { + case <-copyDone: + return + default: + data := make([]byte, 100) + rand.Read(data) + pipeA.send(data, addr2) + time.Sleep(10 * time.Millisecond) + } + } + }() + + time.Sleep(100 * time.Millisecond) + + pipeA.Close() + pipeB.Close() + copier.Close() + + select { + case <-copyDone: + case <-time.After(5 * time.Second): + t.Fatal("Copy did not return after reactor close") + } +} + +func TestBatchCopy_SmallPackets(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 60011) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 60012) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + const totalPackets = 20 + receivedCount := 0 + + recvDone := make(chan struct{}) + go func() { + defer close(recvDone) + for i := 0; i < totalPackets; i++ { + packet, ok := pipeB.receive() + if !ok { + return + } + receivedCount++ + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + }() + + for i := 0; i < totalPackets; i++ { + size := (i % 10) + 1 + data := make([]byte, size) + rand.Read(data) + pipeA.send(data, addr2) + time.Sleep(5 * time.Millisecond) + } + + select { + case <-recvDone: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for packets") + } + + assert.Equal(t, totalPackets, receivedCount) +} + +func TestBatchCopy_VaryingPacketSizes(t 
*testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 60041) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 60042) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + sizes := []int{10, 100, 500, 1000, 2000, 4000, 8000} + const times = 3 + + totalPackets := len(sizes) * times + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + + recvDone := make(chan struct{}) + go func() { + defer close(recvDone) + for i := 0; i < totalPackets; i++ { + packet, ok := pipeB.receive() + if !ok { + return + } + idx := int(packet.Buffer.Byte(0))<<8 | int(packet.Buffer.Byte(1)) + hash := md5.Sum(packet.Buffer.Bytes()) + recvHash[idx] = hash[:] + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + }() + + packetIdx := 0 + for _, size := range sizes { + for j := 0; j < times; j++ { + data := make([]byte, size) + rand.Read(data[2:]) + data[0] = byte(packetIdx >> 8) + data[1] = byte(packetIdx & 0xff) + hash := md5.Sum(data) + sendHash[packetIdx] = hash[:] + pipeA.send(data, addr2) + packetIdx++ + time.Sleep(5 * time.Millisecond) + } + } + + select { + case <-recvDone: + case <-time.After(10 * time.Second): + t.Fatal("timeout") + } + + assert.Equal(t, len(sendHash), len(recvHash)) + for k, v := range sendHash { + assert.Equal(t, v, recvHash[k], "packet %d mismatch", k) + } +} + +func TestBatchCopy_OnCloseCallback(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 60021) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 60022) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + callbackCalled := make(chan error, 1) + onClose := func(err error) { + select { + case callbackCalled <- err: + default: + } + } + + go func() { + copier.Copy(context.Background(), pipeA, pipeB, onClose) + }() + + time.Sleep(50 * time.Millisecond) + + for i := 0; i < 5; i++ { + data := make([]byte, 100) + rand.Read(data) + pipeA.send(data, addr2) + } + + time.Sleep(50 * time.Millisecond) + + pipeA.Close() + pipeB.Close() + + select { + case <-callbackCalled: + case <-time.After(5 * time.Second): + t.Fatal("onClose callback was not called") + } +} + +func TestBatchCopy_SourceClose(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 60031) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 60032) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + var capturedErr error + var errMu sync.Mutex + callbackCalled := make(chan struct{}) + onClose := func(err error) { + errMu.Lock() + capturedErr = err + errMu.Unlock() + close(callbackCalled) + } + + copyDone := make(chan struct{}) + go func() { + copier.Copy(context.Background(), pipeA, pipeB, onClose) + close(copyDone) + }() + + time.Sleep(50 * time.Millisecond) + + for i := 0; i < 5; i++ { + data := make([]byte, 100) + rand.Read(data) + pipeA.send(data, addr2) + } + + time.Sleep(50 * time.Millisecond) + + pipeA.Close() + close(pipeA.inChan) + + select { + case <-callbackCalled: + case <-time.After(5 * time.Second): + pipeB.Close() + t.Fatal("onClose callback was not called after source close") + } + + select { + case <-copyDone: + 
case <-time.After(5 * time.Second): + t.Fatal("Copy did not return after source close") + } + + pipeB.Close() + + errMu.Lock() + err := capturedErr + errMu.Unlock() + + require.NotNil(t, err) +} + +type legacyPacketConn struct { + N.NetPacketConn + targetAddr M.Socksaddr +} + +func (c *legacyPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + _, err = c.NetPacketConn.ReadPacket(buffer) + if err != nil { + return M.Socksaddr{}, err + } + return c.targetAddr, nil +} diff --git a/common/network/read_notifier.go b/common/network/read_notifier.go new file mode 100644 index 00000000..3a693b15 --- /dev/null +++ b/common/network/read_notifier.go @@ -0,0 +1,21 @@ +package network + +type ReadNotifier interface { + isReadNotifier() +} + +type ChannelNotifier struct { + Channel <-chan *PacketBuffer +} + +func (*ChannelNotifier) isReadNotifier() {} + +type FileDescriptorNotifier struct { + FD int +} + +func (*FileDescriptorNotifier) isReadNotifier() {} + +type ReadNotifierSource interface { + CreateReadNotifier() ReadNotifier +} diff --git a/common/udpnat2/conn.go b/common/udpnat2/conn.go index 3c0cda38..81710bfc 100644 --- a/common/udpnat2/conn.go +++ b/common/udpnat2/conn.go @@ -140,3 +140,7 @@ func (c *natConn) SetWriteDeadline(t time.Time) error { func (c *natConn) Upstream() any { return c.writer } + +func (c *natConn) CreateReadNotifier() N.ReadNotifier { + return &N.ChannelNotifier{Channel: c.packetChan} +} diff --git a/common/wepoll/afd_windows.go b/common/wepoll/afd_windows.go new file mode 100644 index 00000000..aad3391a --- /dev/null +++ b/common/wepoll/afd_windows.go @@ -0,0 +1,122 @@ +//go:build windows + +package wepoll + +import ( + "math" + "unsafe" + + "golang.org/x/sys/windows" +) + +type AFD struct { + handle windows.Handle +} + +func NewAFD(iocp windows.Handle, name string) (*AFD, error) { + deviceName := `\Device\Afd\` + name + deviceNameUTF16, err := windows.UTF16FromString(deviceName) + if err != nil { + return nil, err + } + + unicodeString := UnicodeString{ + Length: uint16(len(deviceName) * 2), + MaximumLength: uint16(len(deviceName) * 2), + Buffer: &deviceNameUTF16[0], + } + + objectAttributes := ObjectAttributes{ + Length: uint32(unsafe.Sizeof(ObjectAttributes{})), + ObjectName: &unicodeString, + Attributes: OBJ_CASE_INSENSITIVE, + } + + var handle windows.Handle + var ioStatusBlock windows.IO_STATUS_BLOCK + + err = NtCreateFile( + &handle, + windows.SYNCHRONIZE, + &objectAttributes, + &ioStatusBlock, + nil, + 0, + windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, + FILE_OPEN, + 0, + 0, + 0, + ) + if err != nil { + return nil, err + } + + _, err = windows.CreateIoCompletionPort(handle, iocp, 0, 0) + if err != nil { + windows.CloseHandle(handle) + return nil, err + } + + err = windows.SetFileCompletionNotificationModes(handle, windows.FILE_SKIP_SET_EVENT_ON_HANDLE) + if err != nil { + windows.CloseHandle(handle) + return nil, err + } + + return &AFD{handle: handle}, nil +} + +func (a *AFD) Poll(baseSocket windows.Handle, events uint32, iosb *windows.IO_STATUS_BLOCK, pollInfo *AFDPollInfo) error { + pollInfo.Timeout = math.MaxInt64 + pollInfo.NumberOfHandles = 1 + pollInfo.Exclusive = 0 + pollInfo.Handles[0].Handle = baseSocket + pollInfo.Handles[0].Events = events + pollInfo.Handles[0].Status = 0 + + size := uint32(unsafe.Sizeof(*pollInfo)) + + err := NtDeviceIoControlFile( + a.handle, + 0, + 0, + uintptr(unsafe.Pointer(iosb)), + iosb, + IOCTL_AFD_POLL, + unsafe.Pointer(pollInfo), + size, + unsafe.Pointer(pollInfo), + size, + ) + if err != nil 
{ + if ntstatus, ok := err.(windows.NTStatus); ok { + if uint32(ntstatus) == STATUS_PENDING { + return nil + } + } + return err + } + return nil +} + +func (a *AFD) Cancel(ioStatusBlock *windows.IO_STATUS_BLOCK) error { + if uint32(ioStatusBlock.Status) != STATUS_PENDING { + return nil + } + var cancelIOStatusBlock windows.IO_STATUS_BLOCK + err := NtCancelIoFileEx(a.handle, ioStatusBlock, &cancelIOStatusBlock) + if err != nil { + if ntstatus, ok := err.(windows.NTStatus); ok { + if uint32(ntstatus) == STATUS_CANCELLED || uint32(ntstatus) == STATUS_NOT_FOUND { + return nil + } + } + return err + } + return nil +} + +func (a *AFD) Close() error { + return windows.CloseHandle(a.handle) +} diff --git a/common/wepoll/pinner.go b/common/wepoll/pinner.go new file mode 100644 index 00000000..58b76686 --- /dev/null +++ b/common/wepoll/pinner.go @@ -0,0 +1,7 @@ +//go:build go1.21 + +package wepoll + +import "runtime" + +type Pinner = runtime.Pinner diff --git a/common/wepoll/pinner_compat.go b/common/wepoll/pinner_compat.go new file mode 100644 index 00000000..a51a9fa6 --- /dev/null +++ b/common/wepoll/pinner_compat.go @@ -0,0 +1,9 @@ +//go:build !go1.21 + +package wepoll + +type Pinner struct{} + +func (p *Pinner) Pin(pointer any) {} + +func (p *Pinner) Unpin() {} diff --git a/common/wepoll/socket_windows.go b/common/wepoll/socket_windows.go new file mode 100644 index 00000000..e5655990 --- /dev/null +++ b/common/wepoll/socket_windows.go @@ -0,0 +1,49 @@ +//go:build windows + +package wepoll + +import ( + "unsafe" + + "golang.org/x/sys/windows" +) + +func GetBaseSocket(socket windows.Handle) (windows.Handle, error) { + var baseSocket windows.Handle + var bytesReturned uint32 + + for { + err := windows.WSAIoctl( + socket, + SIO_BASE_HANDLE, + nil, + 0, + (*byte)(unsafe.Pointer(&baseSocket)), + uint32(unsafe.Sizeof(baseSocket)), + &bytesReturned, + nil, + 0, + ) + if err != nil { + err = windows.WSAIoctl( + socket, + SIO_BSP_HANDLE_POLL, + nil, + 0, + (*byte)(unsafe.Pointer(&baseSocket)), + uint32(unsafe.Sizeof(baseSocket)), + &bytesReturned, + nil, + 0, + ) + if err != nil { + return socket, nil + } + } + + if baseSocket == socket { + return baseSocket, nil + } + socket = baseSocket + } +} diff --git a/common/wepoll/syscall_windows.go b/common/wepoll/syscall_windows.go new file mode 100644 index 00000000..948dd00e --- /dev/null +++ b/common/wepoll/syscall_windows.go @@ -0,0 +1,8 @@ +package wepoll + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go + +//sys NtCreateFile(handle *windows.Handle, access uint32, oa *ObjectAttributes, iosb *windows.IO_STATUS_BLOCK, allocationSize *int64, attributes uint32, share uint32, disposition uint32, options uint32, eaBuffer uintptr, eaLength uint32) (ntstatus error) = ntdll.NtCreateFile +//sys NtDeviceIoControlFile(handle windows.Handle, event windows.Handle, apcRoutine uintptr, apcContext uintptr, ioStatusBlock *windows.IO_STATUS_BLOCK, ioControlCode uint32, inputBuffer unsafe.Pointer, inputBufferLength uint32, outputBuffer unsafe.Pointer, outputBufferLength uint32) (ntstatus error) = ntdll.NtDeviceIoControlFile +//sys NtCancelIoFileEx(handle windows.Handle, ioRequestToCancel *windows.IO_STATUS_BLOCK, ioStatusBlock *windows.IO_STATUS_BLOCK) (ntstatus error) = ntdll.NtCancelIoFileEx +//sys GetQueuedCompletionStatusEx(cphandle windows.Handle, entries *OverlappedEntry, count uint32, numRemoved *uint32, timeout uint32, alertable bool) (err error) = kernel32.GetQueuedCompletionStatusEx diff --git 
a/common/wepoll/types_windows.go b/common/wepoll/types_windows.go new file mode 100644 index 00000000..aad2d79a --- /dev/null +++ b/common/wepoll/types_windows.go @@ -0,0 +1,64 @@ +//go:build windows + +package wepoll + +import "golang.org/x/sys/windows" + +const ( + IOCTL_AFD_POLL = 0x00012024 + + AFD_POLL_RECEIVE = 0x0001 + AFD_POLL_RECEIVE_EXPEDITED = 0x0002 + AFD_POLL_SEND = 0x0004 + AFD_POLL_DISCONNECT = 0x0008 + AFD_POLL_ABORT = 0x0010 + AFD_POLL_LOCAL_CLOSE = 0x0020 + AFD_POLL_ACCEPT = 0x0080 + AFD_POLL_CONNECT_FAIL = 0x0100 + + SIO_BASE_HANDLE = 0x48000022 + SIO_BSP_HANDLE_POLL = 0x4800001D + + STATUS_PENDING = 0x00000103 + STATUS_CANCELLED = 0xC0000120 + STATUS_NOT_FOUND = 0xC0000225 + + FILE_OPEN = 0x00000001 + + OBJ_CASE_INSENSITIVE = 0x00000040 +) + +type AFDPollHandleInfo struct { + Handle windows.Handle + Events uint32 + Status uint32 +} + +type AFDPollInfo struct { + Timeout int64 + NumberOfHandles uint32 + Exclusive uint32 + Handles [1]AFDPollHandleInfo +} + +type OverlappedEntry struct { + CompletionKey uintptr + Overlapped *windows.Overlapped + Internal uintptr + NumberOfBytesTransferred uint32 +} + +type UnicodeString struct { + Length uint16 + MaximumLength uint16 + Buffer *uint16 +} + +type ObjectAttributes struct { + Length uint32 + RootDirectory windows.Handle + ObjectName *UnicodeString + Attributes uint32 + SecurityDescriptor uintptr + SecurityQualityOfService uintptr +} diff --git a/common/wepoll/wepoll_test.go b/common/wepoll/wepoll_test.go new file mode 100644 index 00000000..0a0ec023 --- /dev/null +++ b/common/wepoll/wepoll_test.go @@ -0,0 +1,335 @@ +//go:build windows + +package wepoll + +import ( + "net" + "syscall" + "testing" + "time" + "unsafe" + + "github.com/stretchr/testify/require" + "golang.org/x/sys/windows" +) + +func createTestIOCP(t *testing.T) windows.Handle { + iocp, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) + require.NoError(t, err) + t.Cleanup(func() { + windows.CloseHandle(iocp) + }) + return iocp +} + +func getSocketHandle(t *testing.T, conn net.PacketConn) windows.Handle { + syscallConn, ok := conn.(syscall.Conn) + require.True(t, ok) + rawConn, err := syscallConn.SyscallConn() + require.NoError(t, err) + var fd uintptr + err = rawConn.Control(func(f uintptr) { fd = f }) + require.NoError(t, err) + return windows.Handle(fd) +} + +func getTCPSocketHandle(t *testing.T, conn net.Conn) windows.Handle { + syscallConn, ok := conn.(syscall.Conn) + require.True(t, ok) + rawConn, err := syscallConn.SyscallConn() + require.NoError(t, err) + var fd uintptr + err = rawConn.Control(func(f uintptr) { fd = f }) + require.NoError(t, err) + return windows.Handle(fd) +} + +func TestNewAFD(t *testing.T) { + t.Parallel() + + iocp := createTestIOCP(t) + + afd, err := NewAFD(iocp, "test") + require.NoError(t, err) + require.NotNil(t, afd) + + err = afd.Close() + require.NoError(t, err) +} + +func TestNewAFD_MultipleTimes(t *testing.T) { + t.Parallel() + + iocp := createTestIOCP(t) + + afd1, err := NewAFD(iocp, "test1") + require.NoError(t, err) + defer afd1.Close() + + afd2, err := NewAFD(iocp, "test2") + require.NoError(t, err) + defer afd2.Close() + + afd3, err := NewAFD(iocp, "test3") + require.NoError(t, err) + defer afd3.Close() +} + +func TestGetBaseSocket_UDP(t *testing.T) { + t.Parallel() + + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + + handle := getSocketHandle(t, conn) + baseHandle, err := GetBaseSocket(handle) + require.NoError(t, err) + require.NotEqual(t, 
windows.InvalidHandle, baseHandle) +} + +func TestGetBaseSocket_TCP(t *testing.T) { + t.Parallel() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + go func() { + conn, err := listener.Accept() + if err == nil { + conn.Close() + } + }() + + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer conn.Close() + + handle := getTCPSocketHandle(t, conn) + baseHandle, err := GetBaseSocket(handle) + require.NoError(t, err) + require.NotEqual(t, windows.InvalidHandle, baseHandle) +} + +func TestAFD_Poll_ReceiveEvent(t *testing.T) { + t.Parallel() + + iocp := createTestIOCP(t) + + afd, err := NewAFD(iocp, "poll_test") + require.NoError(t, err) + defer afd.Close() + + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + + handle := getSocketHandle(t, conn) + baseHandle, err := GetBaseSocket(handle) + require.NoError(t, err) + + var state struct { + iosb windows.IO_STATUS_BLOCK + pollInfo AFDPollInfo + } + + var pinner Pinner + pinner.Pin(&state) + defer pinner.Unpin() + + events := uint32(AFD_POLL_RECEIVE | AFD_POLL_DISCONNECT | AFD_POLL_ABORT) + err = afd.Poll(baseHandle, events, &state.iosb, &state.pollInfo) + require.NoError(t, err) + + sender, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer sender.Close() + + _, err = sender.WriteTo([]byte("test data"), conn.LocalAddr()) + require.NoError(t, err) + + entries := make([]OverlappedEntry, 1) + var numRemoved uint32 + err = GetQueuedCompletionStatusEx(iocp, &entries[0], 1, &numRemoved, 5000, false) + require.NoError(t, err) + require.Equal(t, uint32(1), numRemoved) + require.Equal(t, uintptr(unsafe.Pointer(&state.iosb)), uintptr(unsafe.Pointer(entries[0].Overlapped))) +} + +func TestAFD_Cancel(t *testing.T) { + t.Parallel() + + iocp := createTestIOCP(t) + + afd, err := NewAFD(iocp, "cancel_test") + require.NoError(t, err) + defer afd.Close() + + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + + handle := getSocketHandle(t, conn) + baseHandle, err := GetBaseSocket(handle) + require.NoError(t, err) + + var state struct { + iosb windows.IO_STATUS_BLOCK + pollInfo AFDPollInfo + } + + var pinner Pinner + pinner.Pin(&state) + defer pinner.Unpin() + + events := uint32(AFD_POLL_RECEIVE) + err = afd.Poll(baseHandle, events, &state.iosb, &state.pollInfo) + require.NoError(t, err) + + err = afd.Cancel(&state.iosb) + require.NoError(t, err) + + entries := make([]OverlappedEntry, 1) + var numRemoved uint32 + err = GetQueuedCompletionStatusEx(iocp, &entries[0], 1, &numRemoved, 1000, false) + require.NoError(t, err) + require.Equal(t, uint32(1), numRemoved) +} + +func TestAFD_Close(t *testing.T) { + t.Parallel() + + iocp := createTestIOCP(t) + + afd, err := NewAFD(iocp, "close_test") + require.NoError(t, err) + + err = afd.Close() + require.NoError(t, err) +} + +func TestGetQueuedCompletionStatusEx_Timeout(t *testing.T) { + t.Parallel() + + iocp := createTestIOCP(t) + + entries := make([]OverlappedEntry, 1) + var numRemoved uint32 + + start := time.Now() + err := GetQueuedCompletionStatusEx(iocp, &entries[0], 1, &numRemoved, 100, false) + elapsed := time.Since(start) + + require.Error(t, err) + require.GreaterOrEqual(t, elapsed, 50*time.Millisecond) +} + +func TestGetQueuedCompletionStatusEx_MultipleEntries(t *testing.T) { + t.Parallel() + + iocp := createTestIOCP(t) + + afd, err := NewAFD(iocp, "multi_test") + require.NoError(t, err) + defer 
afd.Close() + + const numConns = 3 + conns := make([]net.PacketConn, numConns) + states := make([]struct { + iosb windows.IO_STATUS_BLOCK + pollInfo AFDPollInfo + }, numConns) + pinners := make([]Pinner, numConns) + + for i := 0; i < numConns; i++ { + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + conns[i] = conn + + handle := getSocketHandle(t, conn) + baseHandle, err := GetBaseSocket(handle) + require.NoError(t, err) + + pinners[i].Pin(&states[i]) + defer pinners[i].Unpin() + + events := uint32(AFD_POLL_RECEIVE) + err = afd.Poll(baseHandle, events, &states[i].iosb, &states[i].pollInfo) + require.NoError(t, err) + } + + sender, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer sender.Close() + + for i := 0; i < numConns; i++ { + _, err = sender.WriteTo([]byte("test"), conns[i].LocalAddr()) + require.NoError(t, err) + } + + entries := make([]OverlappedEntry, 8) + var numRemoved uint32 + received := 0 + for received < numConns { + err = GetQueuedCompletionStatusEx(iocp, &entries[0], 8, &numRemoved, 5000, false) + require.NoError(t, err) + received += int(numRemoved) + } + require.Equal(t, numConns, received) +} + +func TestAFD_Poll_DisconnectEvent(t *testing.T) { + t.Parallel() + + iocp := createTestIOCP(t) + + afd, err := NewAFD(iocp, "disconnect_test") + require.NoError(t, err) + defer afd.Close() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + serverDone := make(chan struct{}) + go func() { + defer close(serverDone) + conn, err := listener.Accept() + if err != nil { + return + } + time.Sleep(100 * time.Millisecond) + conn.Close() + }() + + client, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer client.Close() + + handle := getTCPSocketHandle(t, client) + baseHandle, err := GetBaseSocket(handle) + require.NoError(t, err) + + var state struct { + iosb windows.IO_STATUS_BLOCK + pollInfo AFDPollInfo + } + + var pinner Pinner + pinner.Pin(&state) + defer pinner.Unpin() + + events := uint32(AFD_POLL_RECEIVE | AFD_POLL_DISCONNECT | AFD_POLL_ABORT) + err = afd.Poll(baseHandle, events, &state.iosb, &state.pollInfo) + require.NoError(t, err) + + entries := make([]OverlappedEntry, 1) + var numRemoved uint32 + err = GetQueuedCompletionStatusEx(iocp, &entries[0], 1, &numRemoved, 5000, false) + require.NoError(t, err) + require.Equal(t, uint32(1), numRemoved) + + <-serverDone +} diff --git a/common/wepoll/zsyscall_windows.go b/common/wepoll/zsyscall_windows.go new file mode 100644 index 00000000..dac75d17 --- /dev/null +++ b/common/wepoll/zsyscall_windows.go @@ -0,0 +1,84 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package wepoll + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) 
+ return e +} + +var ( + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + modntdll = windows.NewLazySystemDLL("ntdll.dll") + + procGetQueuedCompletionStatusEx = modkernel32.NewProc("GetQueuedCompletionStatusEx") + procNtCancelIoFileEx = modntdll.NewProc("NtCancelIoFileEx") + procNtCreateFile = modntdll.NewProc("NtCreateFile") + procNtDeviceIoControlFile = modntdll.NewProc("NtDeviceIoControlFile") +) + +func GetQueuedCompletionStatusEx(cphandle windows.Handle, entries *OverlappedEntry, count uint32, numRemoved *uint32, timeout uint32, alertable bool) (err error) { + var _p0 uint32 + if alertable { + _p0 = 1 + } + r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatusEx.Addr(), 6, uintptr(cphandle), uintptr(unsafe.Pointer(entries)), uintptr(count), uintptr(unsafe.Pointer(numRemoved)), uintptr(timeout), uintptr(_p0)) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + +func NtCancelIoFileEx(handle windows.Handle, ioRequestToCancel *windows.IO_STATUS_BLOCK, ioStatusBlock *windows.IO_STATUS_BLOCK) (ntstatus error) { + r0, _, _ := syscall.Syscall(procNtCancelIoFileEx.Addr(), 3, uintptr(handle), uintptr(unsafe.Pointer(ioRequestToCancel)), uintptr(unsafe.Pointer(ioStatusBlock))) + if r0 != 0 { + ntstatus = windows.NTStatus(r0) + } + return +} + +func NtCreateFile(handle *windows.Handle, access uint32, oa *ObjectAttributes, iosb *windows.IO_STATUS_BLOCK, allocationSize *int64, attributes uint32, share uint32, disposition uint32, options uint32, eaBuffer uintptr, eaLength uint32) (ntstatus error) { + r0, _, _ := syscall.Syscall12(procNtCreateFile.Addr(), 11, uintptr(unsafe.Pointer(handle)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(unsafe.Pointer(allocationSize)), uintptr(attributes), uintptr(share), uintptr(disposition), uintptr(options), uintptr(eaBuffer), uintptr(eaLength), 0) + if r0 != 0 { + ntstatus = windows.NTStatus(r0) + } + return +} + +func NtDeviceIoControlFile(handle windows.Handle, event windows.Handle, apcRoutine uintptr, apcContext uintptr, ioStatusBlock *windows.IO_STATUS_BLOCK, ioControlCode uint32, inputBuffer unsafe.Pointer, inputBufferLength uint32, outputBuffer unsafe.Pointer, outputBufferLength uint32) (ntstatus error) { + r0, _, _ := syscall.Syscall12(procNtDeviceIoControlFile.Addr(), 10, uintptr(handle), uintptr(event), uintptr(apcRoutine), uintptr(apcContext), uintptr(unsafe.Pointer(ioStatusBlock)), uintptr(ioControlCode), uintptr(inputBuffer), uintptr(inputBufferLength), uintptr(outputBuffer), uintptr(outputBufferLength), 0, 0) + if r0 != 0 { + ntstatus = windows.NTStatus(r0) + } + return +} From 3bacbe632b0c05ff31fb19ec812c66ea7476902a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 27 Dec 2025 15:53:09 +0800 Subject: [PATCH 02/13] Add stream reactor --- .../{channel_demux.go => channel_poller.go} | 94 ++-- common/bufio/fd_demux_stub.go | 25 - common/bufio/fd_handler.go | 9 + ...fd_demux_darwin.go => fd_poller_darwin.go} | 134 ++--- .../{fd_demux_linux.go => fd_poller_linux.go} | 142 +++--- common/bufio/fd_poller_stub.go | 25 + ..._demux_windows.go => fd_poller_windows.go} | 134 ++--- common/bufio/packet_reactor.go | 129 ++--- common/bufio/packet_reactor_test.go | 12 +- common/bufio/stream_reactor.go | 412 ++++++++++++++++ common/bufio/stream_reactor_test.go | 463 ++++++++++++++++++ common/network/packet_pollable.go | 20 + common/network/read_notifier.go | 21 - common/network/stream_pollable.go | 17 + common/udpnat2/conn.go | 12 +- 15 files changed, 1283 insertions(+), 366 deletions(-) 
rename common/bufio/{channel_demux.go => channel_poller.go} (51%) delete mode 100644 common/bufio/fd_demux_stub.go create mode 100644 common/bufio/fd_handler.go rename common/bufio/{fd_demux_darwin.go => fd_poller_darwin.go} (54%) rename common/bufio/{fd_demux_linux.go => fd_poller_linux.go} (53%) create mode 100644 common/bufio/fd_poller_stub.go rename common/bufio/{fd_demux_windows.go => fd_poller_windows.go} (59%) create mode 100644 common/bufio/stream_reactor.go create mode 100644 common/bufio/stream_reactor_test.go create mode 100644 common/network/packet_pollable.go delete mode 100644 common/network/read_notifier.go create mode 100644 common/network/stream_pollable.go diff --git a/common/bufio/channel_demux.go b/common/bufio/channel_poller.go similarity index 51% rename from common/bufio/channel_demux.go rename to common/bufio/channel_poller.go index 2876083c..ceece6a2 100644 --- a/common/bufio/channel_demux.go +++ b/common/bufio/channel_poller.go @@ -14,7 +14,7 @@ type channelDemuxEntry struct { stream *reactorStream } -type ChannelDemultiplexer struct { +type ChannelPoller struct { ctx context.Context cancel context.CancelFunc mutex sync.Mutex @@ -25,22 +25,22 @@ type ChannelDemultiplexer struct { wg sync.WaitGroup } -func NewChannelDemultiplexer(ctx context.Context) *ChannelDemultiplexer { +func NewChannelPoller(ctx context.Context) *ChannelPoller { ctx, cancel := context.WithCancel(ctx) - demux := &ChannelDemultiplexer{ + poller := &ChannelPoller{ ctx: ctx, cancel: cancel, entries: make(map[<-chan *N.PacketBuffer]*channelDemuxEntry), updateChan: make(chan struct{}, 1), } - return demux + return poller } -func (d *ChannelDemultiplexer) Add(stream *reactorStream, channel <-chan *N.PacketBuffer) { - d.mutex.Lock() +func (p *ChannelPoller) Add(stream *reactorStream, channel <-chan *N.PacketBuffer) { + p.mutex.Lock() - if d.closed.Load() { - d.mutex.Unlock() + if p.closed.Load() { + p.mutex.Unlock() return } @@ -48,89 +48,89 @@ func (d *ChannelDemultiplexer) Add(stream *reactorStream, channel <-chan *N.Pack channel: channel, stream: stream, } - d.entries[channel] = entry - if !d.running { - d.running = true - d.wg.Add(1) - go d.run() + p.entries[channel] = entry + if !p.running { + p.running = true + p.wg.Add(1) + go p.run() } - d.mutex.Unlock() - d.signalUpdate() + p.mutex.Unlock() + p.signalUpdate() } -func (d *ChannelDemultiplexer) Remove(channel <-chan *N.PacketBuffer) { - d.mutex.Lock() - delete(d.entries, channel) - d.mutex.Unlock() - d.signalUpdate() +func (p *ChannelPoller) Remove(channel <-chan *N.PacketBuffer) { + p.mutex.Lock() + delete(p.entries, channel) + p.mutex.Unlock() + p.signalUpdate() } -func (d *ChannelDemultiplexer) signalUpdate() { +func (p *ChannelPoller) signalUpdate() { select { - case d.updateChan <- struct{}{}: + case p.updateChan <- struct{}{}: default: } } -func (d *ChannelDemultiplexer) Close() error { - d.mutex.Lock() - d.closed.Store(true) - d.mutex.Unlock() +func (p *ChannelPoller) Close() error { + p.mutex.Lock() + p.closed.Store(true) + p.mutex.Unlock() - d.cancel() - d.signalUpdate() - d.wg.Wait() + p.cancel() + p.signalUpdate() + p.wg.Wait() return nil } -func (d *ChannelDemultiplexer) run() { - defer d.wg.Done() +func (p *ChannelPoller) run() { + defer p.wg.Done() for { - d.mutex.Lock() - if len(d.entries) == 0 { - d.running = false - d.mutex.Unlock() + p.mutex.Lock() + if len(p.entries) == 0 { + p.running = false + p.mutex.Unlock() return } - cases := make([]reflect.SelectCase, 0, len(d.entries)+2) + cases := make([]reflect.SelectCase, 0, 
len(p.entries)+2) cases = append(cases, reflect.SelectCase{ Dir: reflect.SelectRecv, - Chan: reflect.ValueOf(d.ctx.Done()), + Chan: reflect.ValueOf(p.ctx.Done()), }) cases = append(cases, reflect.SelectCase{ Dir: reflect.SelectRecv, - Chan: reflect.ValueOf(d.updateChan), + Chan: reflect.ValueOf(p.updateChan), }) - entryList := make([]*channelDemuxEntry, 0, len(d.entries)) - for _, entry := range d.entries { + entryList := make([]*channelDemuxEntry, 0, len(p.entries)) + for _, entry := range p.entries { cases = append(cases, reflect.SelectCase{ Dir: reflect.SelectRecv, Chan: reflect.ValueOf(entry.channel), }) entryList = append(entryList, entry) } - d.mutex.Unlock() + p.mutex.Unlock() chosen, recv, recvOK := reflect.Select(cases) switch chosen { case 0: - d.mutex.Lock() - d.running = false - d.mutex.Unlock() + p.mutex.Lock() + p.running = false + p.mutex.Unlock() return case 1: continue default: entry := entryList[chosen-2] - d.mutex.Lock() - delete(d.entries, entry.channel) - d.mutex.Unlock() + p.mutex.Lock() + delete(p.entries, entry.channel) + p.mutex.Unlock() if recvOK { packet := recv.Interface().(*N.PacketBuffer) diff --git a/common/bufio/fd_demux_stub.go b/common/bufio/fd_demux_stub.go deleted file mode 100644 index d2248cf5..00000000 --- a/common/bufio/fd_demux_stub.go +++ /dev/null @@ -1,25 +0,0 @@ -//go:build !linux && !darwin && !windows - -package bufio - -import ( - "context" - - E "github.com/sagernet/sing/common/exceptions" -) - -type FDDemultiplexer struct{} - -func NewFDDemultiplexer(ctx context.Context) (*FDDemultiplexer, error) { - return nil, E.New("FDDemultiplexer not supported on this platform") -} - -func (d *FDDemultiplexer) Add(stream *reactorStream, fd int) error { - return E.New("FDDemultiplexer not supported on this platform") -} - -func (d *FDDemultiplexer) Remove(fd int) {} - -func (d *FDDemultiplexer) Close() error { - return nil -} diff --git a/common/bufio/fd_handler.go b/common/bufio/fd_handler.go new file mode 100644 index 00000000..4c41677f --- /dev/null +++ b/common/bufio/fd_handler.go @@ -0,0 +1,9 @@ +package bufio + +// FDHandler is the interface for handling FD ready events +// Implemented by both reactorStream (UDP) and streamDirection (TCP) +type FDHandler interface { + // HandleFDEvent is called when the FD has data ready to read + // The handler should start processing data in a new goroutine + HandleFDEvent() +} diff --git a/common/bufio/fd_demux_darwin.go b/common/bufio/fd_poller_darwin.go similarity index 54% rename from common/bufio/fd_demux_darwin.go rename to common/bufio/fd_poller_darwin.go index 3e1c7876..fd2422f9 100644 --- a/common/bufio/fd_demux_darwin.go +++ b/common/bufio/fd_poller_darwin.go @@ -11,11 +11,11 @@ import ( ) type fdDemuxEntry struct { - fd int - stream *reactorStream + fd int + handler FDHandler } -type FDDemultiplexer struct { +type FDPoller struct { ctx context.Context cancel context.CancelFunc kqueueFD int @@ -27,7 +27,7 @@ type FDDemultiplexer struct { pipeFDs [2]int } -func NewFDDemultiplexer(ctx context.Context) (*FDDemultiplexer, error) { +func NewFDPoller(ctx context.Context) (*FDPoller, error) { kqueueFD, err := unix.Kqueue() if err != nil { return nil, err @@ -68,25 +68,25 @@ func NewFDDemultiplexer(ctx context.Context) (*FDDemultiplexer, error) { } ctx, cancel := context.WithCancel(ctx) - demux := &FDDemultiplexer{ + poller := &FDPoller{ ctx: ctx, cancel: cancel, kqueueFD: kqueueFD, entries: make(map[int]*fdDemuxEntry), pipeFDs: pipeFDs, } - return demux, nil + return poller, nil } -func (d *FDDemultiplexer) 
Add(stream *reactorStream, fd int) error { - d.mutex.Lock() - defer d.mutex.Unlock() +func (p *FDPoller) Add(handler FDHandler, fd int) error { + p.mutex.Lock() + defer p.mutex.Unlock() - if d.closed.Load() { + if p.closed.Load() { return unix.EINVAL } - _, err := unix.Kevent(d.kqueueFD, []unix.Kevent_t{{ + _, err := unix.Kevent(p.kqueueFD, []unix.Kevent_t{{ Ident: uint64(fd), Filter: unix.EVFILT_READ, Flags: unix.EV_ADD, @@ -96,90 +96,90 @@ func (d *FDDemultiplexer) Add(stream *reactorStream, fd int) error { } entry := &fdDemuxEntry{ - fd: fd, - stream: stream, + fd: fd, + handler: handler, } - d.entries[fd] = entry + p.entries[fd] = entry - if !d.running { - d.running = true - d.wg.Add(1) - go d.run() + if !p.running { + p.running = true + p.wg.Add(1) + go p.run() } return nil } -func (d *FDDemultiplexer) Remove(fd int) { - d.mutex.Lock() - defer d.mutex.Unlock() +func (p *FDPoller) Remove(fd int) { + p.mutex.Lock() + defer p.mutex.Unlock() - _, ok := d.entries[fd] + _, ok := p.entries[fd] if !ok { return } - unix.Kevent(d.kqueueFD, []unix.Kevent_t{{ + unix.Kevent(p.kqueueFD, []unix.Kevent_t{{ Ident: uint64(fd), Filter: unix.EVFILT_READ, Flags: unix.EV_DELETE, }}, nil, nil) - delete(d.entries, fd) + delete(p.entries, fd) } -func (d *FDDemultiplexer) wakeup() { - unix.Write(d.pipeFDs[1], []byte{0}) +func (p *FDPoller) wakeup() { + unix.Write(p.pipeFDs[1], []byte{0}) } -func (d *FDDemultiplexer) Close() error { - d.mutex.Lock() - d.closed.Store(true) - d.mutex.Unlock() +func (p *FDPoller) Close() error { + p.mutex.Lock() + p.closed.Store(true) + p.mutex.Unlock() - d.cancel() - d.wakeup() - d.wg.Wait() + p.cancel() + p.wakeup() + p.wg.Wait() - d.mutex.Lock() - defer d.mutex.Unlock() + p.mutex.Lock() + defer p.mutex.Unlock() - if d.kqueueFD != -1 { - unix.Close(d.kqueueFD) - d.kqueueFD = -1 + if p.kqueueFD != -1 { + unix.Close(p.kqueueFD) + p.kqueueFD = -1 } - if d.pipeFDs[0] != -1 { - unix.Close(d.pipeFDs[0]) - unix.Close(d.pipeFDs[1]) - d.pipeFDs[0] = -1 - d.pipeFDs[1] = -1 + if p.pipeFDs[0] != -1 { + unix.Close(p.pipeFDs[0]) + unix.Close(p.pipeFDs[1]) + p.pipeFDs[0] = -1 + p.pipeFDs[1] = -1 } return nil } -func (d *FDDemultiplexer) run() { - defer d.wg.Done() +func (p *FDPoller) run() { + defer p.wg.Done() events := make([]unix.Kevent_t, 64) var buffer [1]byte for { select { - case <-d.ctx.Done(): - d.mutex.Lock() - d.running = false - d.mutex.Unlock() + case <-p.ctx.Done(): + p.mutex.Lock() + p.running = false + p.mutex.Unlock() return default: } - n, err := unix.Kevent(d.kqueueFD, nil, events, nil) + n, err := unix.Kevent(p.kqueueFD, nil, events, nil) if err != nil { if err == unix.EINTR { continue } - d.mutex.Lock() - d.running = false - d.mutex.Unlock() + p.mutex.Lock() + p.running = false + p.mutex.Unlock() return } @@ -187,8 +187,8 @@ func (d *FDDemultiplexer) run() { event := events[i] fd := int(event.Ident) - if fd == d.pipeFDs[0] { - unix.Read(d.pipeFDs[0], buffer[:]) + if fd == p.pipeFDs[0] { + unix.Read(p.pipeFDs[0], buffer[:]) continue } @@ -196,30 +196,30 @@ func (d *FDDemultiplexer) run() { continue } - d.mutex.Lock() - entry, ok := d.entries[fd] + p.mutex.Lock() + entry, ok := p.entries[fd] if !ok { - d.mutex.Unlock() + p.mutex.Unlock() continue } - unix.Kevent(d.kqueueFD, []unix.Kevent_t{{ + unix.Kevent(p.kqueueFD, []unix.Kevent_t{{ Ident: uint64(fd), Filter: unix.EVFILT_READ, Flags: unix.EV_DELETE, }}, nil, nil) - delete(d.entries, fd) - d.mutex.Unlock() + delete(p.entries, fd) + p.mutex.Unlock() - go entry.stream.runActiveLoop(nil) + go entry.handler.HandleFDEvent() } - 
d.mutex.Lock() - if len(d.entries) == 0 { - d.running = false - d.mutex.Unlock() + p.mutex.Lock() + if len(p.entries) == 0 { + p.running = false + p.mutex.Unlock() return } - d.mutex.Unlock() + p.mutex.Unlock() } } diff --git a/common/bufio/fd_demux_linux.go b/common/bufio/fd_poller_linux.go similarity index 53% rename from common/bufio/fd_demux_linux.go rename to common/bufio/fd_poller_linux.go index 2c5e0afa..a2e245b9 100644 --- a/common/bufio/fd_demux_linux.go +++ b/common/bufio/fd_poller_linux.go @@ -14,10 +14,10 @@ import ( type fdDemuxEntry struct { fd int registrationID uint64 - stream *reactorStream + handler FDHandler } -type FDDemultiplexer struct { +type FDPoller struct { ctx context.Context cancel context.CancelFunc epollFD int @@ -31,7 +31,7 @@ type FDDemultiplexer struct { pipeFDs [2]int } -func NewFDDemultiplexer(ctx context.Context) (*FDDemultiplexer, error) { +func NewFDPoller(ctx context.Context) (*FDPoller, error) { epollFD, err := unix.EpollCreate1(unix.EPOLL_CLOEXEC) if err != nil { return nil, err @@ -55,7 +55,7 @@ func NewFDDemultiplexer(ctx context.Context) (*FDDemultiplexer, error) { } ctx, cancel := context.WithCancel(ctx) - demux := &FDDemultiplexer{ + poller := &FDPoller{ ctx: ctx, cancel: cancel, epollFD: epollFD, @@ -63,24 +63,24 @@ func NewFDDemultiplexer(ctx context.Context) (*FDDemultiplexer, error) { registrationToFD: make(map[uint64]int), pipeFDs: pipeFDs, } - return demux, nil + return poller, nil } -func (d *FDDemultiplexer) Add(stream *reactorStream, fd int) error { - d.mutex.Lock() - defer d.mutex.Unlock() +func (p *FDPoller) Add(handler FDHandler, fd int) error { + p.mutex.Lock() + defer p.mutex.Unlock() - if d.closed.Load() { + if p.closed.Load() { return unix.EINVAL } - d.registrationCounter++ - registrationID := d.registrationCounter + p.registrationCounter++ + registrationID := p.registrationCounter event := &unix.EpollEvent{Events: unix.EPOLLIN | unix.EPOLLRDHUP} *(*uint64)(unsafe.Pointer(&event.Fd)) = registrationID - err := unix.EpollCtl(d.epollFD, unix.EPOLL_CTL_ADD, fd, event) + err := unix.EpollCtl(p.epollFD, unix.EPOLL_CTL_ADD, fd, event) if err != nil { return err } @@ -88,87 +88,87 @@ func (d *FDDemultiplexer) Add(stream *reactorStream, fd int) error { entry := &fdDemuxEntry{ fd: fd, registrationID: registrationID, - stream: stream, + handler: handler, } - d.entries[fd] = entry - d.registrationToFD[registrationID] = fd + p.entries[fd] = entry + p.registrationToFD[registrationID] = fd - if !d.running { - d.running = true - d.wg.Add(1) - go d.run() + if !p.running { + p.running = true + p.wg.Add(1) + go p.run() } return nil } -func (d *FDDemultiplexer) Remove(fd int) { - d.mutex.Lock() - defer d.mutex.Unlock() +func (p *FDPoller) Remove(fd int) { + p.mutex.Lock() + defer p.mutex.Unlock() - entry, ok := d.entries[fd] + entry, ok := p.entries[fd] if !ok { return } - unix.EpollCtl(d.epollFD, unix.EPOLL_CTL_DEL, fd, nil) - delete(d.registrationToFD, entry.registrationID) - delete(d.entries, fd) + unix.EpollCtl(p.epollFD, unix.EPOLL_CTL_DEL, fd, nil) + delete(p.registrationToFD, entry.registrationID) + delete(p.entries, fd) } -func (d *FDDemultiplexer) wakeup() { - unix.Write(d.pipeFDs[1], []byte{0}) +func (p *FDPoller) wakeup() { + unix.Write(p.pipeFDs[1], []byte{0}) } -func (d *FDDemultiplexer) Close() error { - d.mutex.Lock() - d.closed.Store(true) - d.mutex.Unlock() +func (p *FDPoller) Close() error { + p.mutex.Lock() + p.closed.Store(true) + p.mutex.Unlock() - d.cancel() - d.wakeup() - d.wg.Wait() + p.cancel() + p.wakeup() + p.wg.Wait() - 
d.mutex.Lock() - defer d.mutex.Unlock() + p.mutex.Lock() + defer p.mutex.Unlock() - if d.epollFD != -1 { - unix.Close(d.epollFD) - d.epollFD = -1 + if p.epollFD != -1 { + unix.Close(p.epollFD) + p.epollFD = -1 } - if d.pipeFDs[0] != -1 { - unix.Close(d.pipeFDs[0]) - unix.Close(d.pipeFDs[1]) - d.pipeFDs[0] = -1 - d.pipeFDs[1] = -1 + if p.pipeFDs[0] != -1 { + unix.Close(p.pipeFDs[0]) + unix.Close(p.pipeFDs[1]) + p.pipeFDs[0] = -1 + p.pipeFDs[1] = -1 } return nil } -func (d *FDDemultiplexer) run() { - defer d.wg.Done() +func (p *FDPoller) run() { + defer p.wg.Done() events := make([]unix.EpollEvent, 64) var buffer [1]byte for { select { - case <-d.ctx.Done(): - d.mutex.Lock() - d.running = false - d.mutex.Unlock() + case <-p.ctx.Done(): + p.mutex.Lock() + p.running = false + p.mutex.Unlock() return default: } - n, err := unix.EpollWait(d.epollFD, events, -1) + n, err := unix.EpollWait(p.epollFD, events, -1) if err != nil { if err == unix.EINTR { continue } - d.mutex.Lock() - d.running = false - d.mutex.Unlock() + p.mutex.Lock() + p.running = false + p.mutex.Unlock() return } @@ -177,7 +177,7 @@ func (d *FDDemultiplexer) run() { registrationID := *(*uint64)(unsafe.Pointer(&event.Fd)) if registrationID == 0 { - unix.Read(d.pipeFDs[0], buffer[:]) + unix.Read(p.pipeFDs[0], buffer[:]) continue } @@ -185,33 +185,33 @@ func (d *FDDemultiplexer) run() { continue } - d.mutex.Lock() - fd, ok := d.registrationToFD[registrationID] + p.mutex.Lock() + fd, ok := p.registrationToFD[registrationID] if !ok { - d.mutex.Unlock() + p.mutex.Unlock() continue } - entry := d.entries[fd] + entry := p.entries[fd] if entry == nil || entry.registrationID != registrationID { - d.mutex.Unlock() + p.mutex.Unlock() continue } - unix.EpollCtl(d.epollFD, unix.EPOLL_CTL_DEL, fd, nil) - delete(d.registrationToFD, registrationID) - delete(d.entries, fd) - d.mutex.Unlock() + unix.EpollCtl(p.epollFD, unix.EPOLL_CTL_DEL, fd, nil) + delete(p.registrationToFD, registrationID) + delete(p.entries, fd) + p.mutex.Unlock() - go entry.stream.runActiveLoop(nil) + go entry.handler.HandleFDEvent() } - d.mutex.Lock() - if len(d.entries) == 0 { - d.running = false - d.mutex.Unlock() + p.mutex.Lock() + if len(p.entries) == 0 { + p.running = false + p.mutex.Unlock() return } - d.mutex.Unlock() + p.mutex.Unlock() } } diff --git a/common/bufio/fd_poller_stub.go b/common/bufio/fd_poller_stub.go new file mode 100644 index 00000000..11921482 --- /dev/null +++ b/common/bufio/fd_poller_stub.go @@ -0,0 +1,25 @@ +//go:build !linux && !darwin && !windows + +package bufio + +import ( + "context" + + E "github.com/sagernet/sing/common/exceptions" +) + +type FDPoller struct{} + +func NewFDPoller(ctx context.Context) (*FDPoller, error) { + return nil, E.New("FDPoller not supported on this platform") +} + +func (p *FDPoller) Add(handler FDHandler, fd int) error { + return E.New("FDPoller not supported on this platform") +} + +func (p *FDPoller) Remove(fd int) {} + +func (p *FDPoller) Close() error { + return nil +} diff --git a/common/bufio/fd_demux_windows.go b/common/bufio/fd_poller_windows.go similarity index 59% rename from common/bufio/fd_demux_windows.go rename to common/bufio/fd_poller_windows.go index 06795ebe..434b38b7 100644 --- a/common/bufio/fd_demux_windows.go +++ b/common/bufio/fd_poller_windows.go @@ -16,7 +16,7 @@ import ( type fdDemuxEntry struct { ioStatusBlock windows.IO_STATUS_BLOCK pollInfo wepoll.AFDPollInfo - stream *reactorStream + handler FDHandler fd int handle windows.Handle baseHandle windows.Handle @@ -25,7 +25,7 @@ type 
fdDemuxEntry struct { pinner wepoll.Pinner } -type FDDemultiplexer struct { +type FDPoller struct { ctx context.Context cancel context.CancelFunc iocp windows.Handle @@ -38,7 +38,7 @@ type FDDemultiplexer struct { wg sync.WaitGroup } -func NewFDDemultiplexer(ctx context.Context) (*FDDemultiplexer, error) { +func NewFDPoller(ctx context.Context) (*FDPoller, error) { iocp, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) if err != nil { return nil, err @@ -51,21 +51,21 @@ func NewFDDemultiplexer(ctx context.Context) (*FDDemultiplexer, error) { } ctx, cancel := context.WithCancel(ctx) - demux := &FDDemultiplexer{ + poller := &FDPoller{ ctx: ctx, cancel: cancel, iocp: iocp, afd: afd, entries: make(map[int]*fdDemuxEntry), } - return demux, nil + return poller, nil } -func (d *FDDemultiplexer) Add(stream *reactorStream, fd int) error { - d.mutex.Lock() - defer d.mutex.Unlock() +func (p *FDPoller) Add(handler FDHandler, fd int) error { + p.mutex.Lock() + defer p.mutex.Unlock() - if d.closed.Load() { + if p.closed.Load() { return windows.ERROR_INVALID_HANDLE } @@ -75,11 +75,11 @@ func (d *FDDemultiplexer) Add(stream *reactorStream, fd int) error { return err } - d.registrationCounter++ - registrationID := d.registrationCounter + p.registrationCounter++ + registrationID := p.registrationCounter entry := &fdDemuxEntry{ - stream: stream, + handler: handler, fd: fd, handle: handle, baseHandle: baseHandle, @@ -89,91 +89,91 @@ func (d *FDDemultiplexer) Add(stream *reactorStream, fd int) error { entry.pinner.Pin(entry) events := uint32(wepoll.AFD_POLL_RECEIVE | wepoll.AFD_POLL_DISCONNECT | wepoll.AFD_POLL_ABORT | wepoll.AFD_POLL_LOCAL_CLOSE) - err = d.afd.Poll(baseHandle, events, &entry.ioStatusBlock, &entry.pollInfo) + err = p.afd.Poll(baseHandle, events, &entry.ioStatusBlock, &entry.pollInfo) if err != nil { entry.pinner.Unpin() return err } - d.entries[fd] = entry + p.entries[fd] = entry - if !d.running { - d.running = true - d.wg.Add(1) - go d.run() + if !p.running { + p.running = true + p.wg.Add(1) + go p.run() } return nil } -func (d *FDDemultiplexer) Remove(fd int) { - d.mutex.Lock() - defer d.mutex.Unlock() +func (p *FDPoller) Remove(fd int) { + p.mutex.Lock() + defer p.mutex.Unlock() - entry, ok := d.entries[fd] + entry, ok := p.entries[fd] if !ok { return } entry.cancelled = true - if d.afd != nil { - d.afd.Cancel(&entry.ioStatusBlock) + if p.afd != nil { + p.afd.Cancel(&entry.ioStatusBlock) } } -func (d *FDDemultiplexer) wakeup() { - windows.PostQueuedCompletionStatus(d.iocp, 0, 0, nil) +func (p *FDPoller) wakeup() { + windows.PostQueuedCompletionStatus(p.iocp, 0, 0, nil) } -func (d *FDDemultiplexer) Close() error { - d.mutex.Lock() - d.closed.Store(true) - d.mutex.Unlock() +func (p *FDPoller) Close() error { + p.mutex.Lock() + p.closed.Store(true) + p.mutex.Unlock() - d.cancel() - d.wakeup() - d.wg.Wait() + p.cancel() + p.wakeup() + p.wg.Wait() - d.mutex.Lock() - defer d.mutex.Unlock() + p.mutex.Lock() + defer p.mutex.Unlock() - for fd, entry := range d.entries { + for fd, entry := range p.entries { entry.pinner.Unpin() - delete(d.entries, fd) + delete(p.entries, fd) } - if d.afd != nil { - d.afd.Close() - d.afd = nil + if p.afd != nil { + p.afd.Close() + p.afd = nil } - if d.iocp != 0 { - windows.CloseHandle(d.iocp) - d.iocp = 0 + if p.iocp != 0 { + windows.CloseHandle(p.iocp) + p.iocp = 0 } return nil } -func (d *FDDemultiplexer) run() { - defer d.wg.Done() +func (p *FDPoller) run() { + defer p.wg.Done() completions := make([]wepoll.OverlappedEntry, 64) for { select { - 
case <-d.ctx.Done(): - d.mutex.Lock() - d.running = false - d.mutex.Unlock() + case <-p.ctx.Done(): + p.mutex.Lock() + p.running = false + p.mutex.Unlock() return default: } var numRemoved uint32 - err := wepoll.GetQueuedCompletionStatusEx(d.iocp, &completions[0], 64, &numRemoved, windows.INFINITE, false) + err := wepoll.GetQueuedCompletionStatusEx(p.iocp, &completions[0], 64, &numRemoved, windows.INFINITE, false) if err != nil { - d.mutex.Lock() - d.running = false - d.mutex.Unlock() + p.mutex.Lock() + p.running = false + p.mutex.Unlock() return } @@ -186,42 +186,42 @@ func (d *FDDemultiplexer) run() { entry := (*fdDemuxEntry)(unsafe.Pointer(event.Overlapped)) - d.mutex.Lock() + p.mutex.Lock() - if d.entries[entry.fd] != entry { - d.mutex.Unlock() + if p.entries[entry.fd] != entry { + p.mutex.Unlock() continue } entry.pinner.Unpin() - delete(d.entries, entry.fd) + delete(p.entries, entry.fd) if entry.cancelled { - d.mutex.Unlock() + p.mutex.Unlock() continue } if uint32(entry.ioStatusBlock.Status) == wepoll.STATUS_CANCELLED { - d.mutex.Unlock() + p.mutex.Unlock() continue } events := entry.pollInfo.Handles[0].Events if events&(wepoll.AFD_POLL_RECEIVE|wepoll.AFD_POLL_DISCONNECT|wepoll.AFD_POLL_ABORT|wepoll.AFD_POLL_LOCAL_CLOSE) == 0 { - d.mutex.Unlock() + p.mutex.Unlock() continue } - d.mutex.Unlock() - go entry.stream.runActiveLoop(nil) + p.mutex.Unlock() + go entry.handler.HandleFDEvent() } - d.mutex.Lock() - if len(d.entries) == 0 { - d.running = false - d.mutex.Unlock() + p.mutex.Lock() + if len(p.entries) == 0 { + p.running = false + p.mutex.Unlock() return } - d.mutex.Unlock() + p.mutex.Unlock() } } diff --git a/common/bufio/packet_reactor.go b/common/bufio/packet_reactor.go index 4ded4ed6..15398b9a 100644 --- a/common/bufio/packet_reactor.go +++ b/common/bufio/packet_reactor.go @@ -24,38 +24,38 @@ const ( ) type PacketReactor struct { - ctx context.Context - cancel context.CancelFunc - channelDemux *ChannelDemultiplexer - fdDemux *FDDemultiplexer - fdDemuxOnce sync.Once - fdDemuxErr error + ctx context.Context + cancel context.CancelFunc + channelPoller *ChannelPoller + fdPoller *FDPoller + fdPollerOnce sync.Once + fdPollerErr error } func NewPacketReactor(ctx context.Context) *PacketReactor { ctx, cancel := context.WithCancel(ctx) return &PacketReactor{ - ctx: ctx, - cancel: cancel, - channelDemux: NewChannelDemultiplexer(ctx), + ctx: ctx, + cancel: cancel, + channelPoller: NewChannelPoller(ctx), } } -func (r *PacketReactor) getFDDemultiplexer() (*FDDemultiplexer, error) { - r.fdDemuxOnce.Do(func() { - r.fdDemux, r.fdDemuxErr = NewFDDemultiplexer(r.ctx) +func (r *PacketReactor) getFDPoller() (*FDPoller, error) { + r.fdPollerOnce.Do(func() { + r.fdPoller, r.fdPollerErr = NewFDPoller(r.ctx) }) - return r.fdDemux, r.fdDemuxErr + return r.fdPoller, r.fdPollerErr } func (r *PacketReactor) Close() error { r.cancel() var errs []error - if r.channelDemux != nil { - errs = append(errs, r.channelDemux.Close()) + if r.channelPoller != nil { + errs = append(errs, r.channelPoller.Close()) } - if r.fdDemux != nil { - errs = append(errs, r.fdDemux.Close()) + if r.fdPoller != nil { + errs = append(errs, r.fdPoller.Close()) } return E.Errors(errs...) 
} @@ -80,7 +80,7 @@ type reactorStream struct { destination N.PacketWriter originSource N.PacketReader - notifier N.ReadNotifier + pollable N.PacketPollable options N.ReadWaitOptions readWaiter N.PacketReadWaiter readCounters []N.CountFunc @@ -159,29 +159,31 @@ func (r *PacketReactor) prepareStream(conn *reactorConnection, source N.PacketRe stream.readWaiter.InitializeReadWaiter(stream.options) } - if notifierSource, ok := source.(N.ReadNotifierSource); ok { - stream.notifier = notifierSource.CreateReadNotifier() + if pollable, ok := source.(N.PacketPollable); ok { + stream.pollable = pollable + } else if creator, ok := source.(N.PacketPollableCreator); ok { + stream.pollable, _ = creator.CreatePacketPollable() } return stream } func (r *PacketReactor) registerStream(stream *reactorStream) { - if stream.notifier == nil { + if stream.pollable == nil { go stream.runLegacyCopy() return } - switch notifier := stream.notifier.(type) { - case *N.ChannelNotifier: - r.channelDemux.Add(stream, notifier.Channel) - case *N.FileDescriptorNotifier: - fdDemux, err := r.getFDDemultiplexer() + switch stream.pollable.PollMode() { + case N.PacketPollModeChannel: + r.channelPoller.Add(stream, stream.pollable.PacketChannel()) + case N.PacketPollModeFD: + fdPoller, err := r.getFDPoller() if err != nil { go stream.runLegacyCopy() return } - err = fdDemux.Add(stream, notifier.FD) + err = fdPoller.Add(stream, stream.pollable.FD()) if err != nil { go stream.runLegacyCopy() } @@ -308,26 +310,37 @@ func (s *reactorStream) returnToPool() { return } - switch notifier := s.notifier.(type) { - case *N.ChannelNotifier: - s.connection.reactor.channelDemux.Add(s, notifier.Channel) + if s.pollable == nil { + return + } + + switch s.pollable.PollMode() { + case N.PacketPollModeChannel: + channel := s.pollable.PacketChannel() + s.connection.reactor.channelPoller.Add(s, channel) if s.state.Load() != stateIdle { - s.connection.reactor.channelDemux.Remove(notifier.Channel) + s.connection.reactor.channelPoller.Remove(channel) } - case *N.FileDescriptorNotifier: - if s.connection.reactor.fdDemux != nil { - err := s.connection.reactor.fdDemux.Add(s, notifier.FD) - if err != nil { - s.closeWithError(err) - return - } - if s.state.Load() != stateIdle { - s.connection.reactor.fdDemux.Remove(notifier.FD) - } + case N.PacketPollModeFD: + if s.connection.reactor.fdPoller == nil { + return + } + fd := s.pollable.FD() + err := s.connection.reactor.fdPoller.Add(s, fd) + if err != nil { + s.closeWithError(err) + return + } + if s.state.Load() != stateIdle { + s.connection.reactor.fdPoller.Remove(fd) } } } +func (s *reactorStream) HandleFDEvent() { + s.runActiveLoop(nil) +} + func (s *reactorStream) runLegacyCopy() { _, err := CopyPacket(s.destination, s.source) s.closeWithError(err) @@ -349,7 +362,7 @@ func (c *reactorConnection) closeWithError(err error) { c.download.state.Store(stateClosed) } - c.removeFromDemultiplexers() + c.removeFromPollers() if c.upload != nil { common.Close(c.upload.originSource) @@ -366,25 +379,21 @@ func (c *reactorConnection) closeWithError(err error) { }) } -func (c *reactorConnection) removeFromDemultiplexers() { - if c.upload != nil && c.upload.notifier != nil { - switch notifier := c.upload.notifier.(type) { - case *N.ChannelNotifier: - c.reactor.channelDemux.Remove(notifier.Channel) - case *N.FileDescriptorNotifier: - if c.reactor.fdDemux != nil { - c.reactor.fdDemux.Remove(notifier.FD) - } - } +func (c *reactorConnection) removeFromPollers() { + c.removeStreamFromPoller(c.upload) + 
c.removeStreamFromPoller(c.download) +} + +func (c *reactorConnection) removeStreamFromPoller(stream *reactorStream) { + if stream == nil || stream.pollable == nil { + return } - if c.download != nil && c.download.notifier != nil { - switch notifier := c.download.notifier.(type) { - case *N.ChannelNotifier: - c.reactor.channelDemux.Remove(notifier.Channel) - case *N.FileDescriptorNotifier: - if c.reactor.fdDemux != nil { - c.reactor.fdDemux.Remove(notifier.FD) - } + switch stream.pollable.PollMode() { + case N.PacketPollModeChannel: + c.reactor.channelPoller.Remove(stream.pollable.PacketChannel()) + case N.PacketPollModeFD: + if c.reactor.fdPoller != nil { + c.reactor.fdPoller.Remove(stream.pollable.FD()) } } } diff --git a/common/bufio/packet_reactor_test.go b/common/bufio/packet_reactor_test.go index ae155130..ea95b077 100644 --- a/common/bufio/packet_reactor_test.go +++ b/common/bufio/packet_reactor_test.go @@ -103,8 +103,8 @@ func (p *testPacketPipe) SetWriteDeadline(t time.Time) error { return nil } -func (p *testPacketPipe) CreateReadNotifier() N.ReadNotifier { - return &N.ChannelNotifier{Channel: p.inChan} +func (p *testPacketPipe) PacketChannel() <-chan *N.PacketBuffer { + return p.inChan } func (p *testPacketPipe) send(data []byte, destination M.Socksaddr) { @@ -154,8 +154,8 @@ func (c *fdPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, return c.targetAddr, nil } -func (c *fdPacketConn) CreateReadNotifier() N.ReadNotifier { - return &N.FileDescriptorNotifier{FD: c.fd} +func (c *fdPacketConn) FD() int { + return c.fd } type channelPacketConn struct { @@ -255,8 +255,8 @@ func (c *channelPacketConn) SetReadDeadline(t time.Time) error { return nil } -func (c *channelPacketConn) CreateReadNotifier() N.ReadNotifier { - return &N.ChannelNotifier{Channel: c.packetChan} +func (c *channelPacketConn) PacketChannel() <-chan *N.PacketBuffer { + return c.packetChan } func (c *channelPacketConn) Close() error { diff --git a/common/bufio/stream_reactor.go b/common/bufio/stream_reactor.go new file mode 100644 index 00000000..c5341aec --- /dev/null +++ b/common/bufio/stream_reactor.go @@ -0,0 +1,412 @@ +package bufio + +import ( + "context" + "io" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + N "github.com/sagernet/sing/common/network" +) + +const ( + streamBatchReadTimeout = 250 * time.Millisecond +) + +type StreamReactor struct { + ctx context.Context + cancel context.CancelFunc + fdPoller *FDPoller + fdPollerOnce sync.Once + fdPollerErr error +} + +func NewStreamReactor(ctx context.Context) *StreamReactor { + ctx, cancel := context.WithCancel(ctx) + return &StreamReactor{ + ctx: ctx, + cancel: cancel, + } +} + +func (r *StreamReactor) getFDPoller() (*FDPoller, error) { + r.fdPollerOnce.Do(func() { + r.fdPoller, r.fdPollerErr = NewFDPoller(r.ctx) + }) + return r.fdPoller, r.fdPollerErr +} + +func (r *StreamReactor) Close() error { + r.cancel() + if r.fdPoller != nil { + return r.fdPoller.Close() + } + return nil +} + +type streamConnection struct { + ctx context.Context + cancel context.CancelFunc + reactor *StreamReactor + onClose N.CloseHandlerFunc + upload *streamDirection + download *streamDirection + + closeOnce sync.Once + done chan struct{} + err error +} + +type streamDirection struct { + connection *streamConnection + + source io.Reader + destination io.Writer + originSource net.Conn + + pollable N.StreamPollable + options N.ReadWaitOptions + 
readWaiter N.ReadWaiter + readCounters []N.CountFunc + writeCounters []N.CountFunc + + isUpload bool + state atomic.Int32 +} + +// Copy initiates bidirectional TCP copy with reactor optimization +// It uses splice when available for zero-copy, otherwise falls back to reactor mode +func (r *StreamReactor) Copy(ctx context.Context, source net.Conn, destination net.Conn, onClose N.CloseHandlerFunc) { + // Try splice first (zero-copy optimization) + if r.trySplice(ctx, source, destination, onClose) { + return + } + + // Fall back to reactor mode + ctx, cancel := context.WithCancel(ctx) + conn := &streamConnection{ + ctx: ctx, + cancel: cancel, + reactor: r, + onClose: onClose, + done: make(chan struct{}), + } + + conn.upload = r.prepareDirection(conn, source, destination, source, true) + select { + case <-conn.done: + return + default: + } + + conn.download = r.prepareDirection(conn, destination, source, destination, false) + select { + case <-conn.done: + return + default: + } + + r.registerDirection(conn.upload) + r.registerDirection(conn.download) +} + +func (r *StreamReactor) trySplice(ctx context.Context, source net.Conn, destination net.Conn, onClose N.CloseHandlerFunc) bool { + if !N.SyscallAvailableForRead(source) || !N.SyscallAvailableForWrite(destination) { + return false + } + + // Both ends support syscall, use traditional copy with splice + go func() { + err := CopyConn(ctx, source, destination) + if onClose != nil { + onClose(err) + } + }() + return true +} + +func (r *StreamReactor) prepareDirection(conn *streamConnection, source io.Reader, destination io.Writer, originConn net.Conn, isUpload bool) *streamDirection { + direction := &streamDirection{ + connection: conn, + source: source, + destination: destination, + originSource: originConn, + isUpload: isUpload, + } + + // Unwrap counters and handle cached data + for { + source, direction.readCounters = N.UnwrapCountReader(source, direction.readCounters) + destination, direction.writeCounters = N.UnwrapCountWriter(destination, direction.writeCounters) + if cachedReader, isCached := source.(N.CachedReader); isCached { + cachedBuffer := cachedReader.ReadCached() + if cachedBuffer != nil { + dataLen := cachedBuffer.Len() + _, err := destination.Write(cachedBuffer.Bytes()) + cachedBuffer.Release() + if err != nil { + conn.closeWithError(err) + return direction + } + for _, counter := range direction.readCounters { + counter(int64(dataLen)) + } + for _, counter := range direction.writeCounters { + counter(int64(dataLen)) + } + continue + } + } + break + } + direction.source = source + direction.destination = destination + + direction.options = N.NewReadWaitOptions(source, destination) + + direction.readWaiter, _ = CreateReadWaiter(source) + if direction.readWaiter != nil { + direction.readWaiter.InitializeReadWaiter(direction.options) + } + + // Try to get stream pollable for FD-based idle detection + if pollable, ok := source.(N.StreamPollable); ok { + direction.pollable = pollable + } else if creator, ok := source.(N.StreamPollableCreator); ok { + direction.pollable, _ = creator.CreateStreamPollable() + } + + return direction +} + +func (r *StreamReactor) registerDirection(direction *streamDirection) { + // Check if there's buffered data that needs processing first + if direction.pollable != nil && direction.pollable.Buffered() > 0 { + go direction.runActiveLoop() + return + } + + // Try to register with FD poller + if direction.pollable != nil { + fdPoller, err := r.getFDPoller() + if err == nil { + err = fdPoller.Add(direction, 
direction.pollable.FD()) + if err == nil { + return + } + } + } + + // Fall back to legacy goroutine copy + go direction.runLegacyCopy() +} + +func (d *streamDirection) runActiveLoop() { + if d.source == nil { + return + } + if !d.state.CompareAndSwap(stateIdle, stateActive) { + return + } + + notFirstTime := false + + for { + if d.state.Load() == stateClosed { + return + } + + // Set batch read timeout + if setter, ok := d.originSource.(interface{ SetReadDeadline(time.Time) error }); ok { + setter.SetReadDeadline(time.Now().Add(streamBatchReadTimeout)) + } + + var ( + buffer *buf.Buffer + err error + ) + + if d.readWaiter != nil { + buffer, err = d.readWaiter.WaitReadBuffer() + } else { + buffer = d.options.NewBuffer() + err = NewExtendedReader(d.source).ReadBuffer(buffer) + if err != nil { + buffer.Release() + buffer = nil + } + } + + if err != nil { + if E.IsTimeout(err) { + // Timeout: check buffer and return to pool + if setter, ok := d.originSource.(interface{ SetReadDeadline(time.Time) error }); ok { + setter.SetReadDeadline(time.Time{}) + } + if d.state.CompareAndSwap(stateActive, stateIdle) { + d.returnToPool() + } + return + } + // EOF or error + if !notFirstTime { + err = N.ReportHandshakeFailure(d.originSource, err) + } + d.handleEOFOrError(err) + return + } + + err = d.writeBufferWithCounters(buffer) + if err != nil { + if !notFirstTime { + err = N.ReportHandshakeFailure(d.originSource, err) + } + d.closeWithError(err) + return + } + notFirstTime = true + } +} + +func (d *streamDirection) writeBufferWithCounters(buffer *buf.Buffer) error { + dataLen := buffer.Len() + d.options.PostReturn(buffer) + err := NewExtendedWriter(d.destination).WriteBuffer(buffer) + if err != nil { + buffer.Leak() + return err + } + + for _, counter := range d.readCounters { + counter(int64(dataLen)) + } + for _, counter := range d.writeCounters { + counter(int64(dataLen)) + } + return nil +} + +func (d *streamDirection) returnToPool() { + if d.state.Load() != stateIdle { + return + } + + // Critical: check if there's buffered data before returning to idle + if d.pollable != nil && d.pollable.Buffered() > 0 { + go d.runActiveLoop() + return + } + + // Safe to wait for FD events + if d.pollable != nil && d.connection.reactor.fdPoller != nil { + err := d.connection.reactor.fdPoller.Add(d, d.pollable.FD()) + if err != nil { + d.closeWithError(err) + return + } + if d.state.Load() != stateIdle { + d.connection.reactor.fdPoller.Remove(d.pollable.FD()) + } + } +} + +func (d *streamDirection) HandleFDEvent() { + d.runActiveLoop() +} + +func (d *streamDirection) runLegacyCopy() { + _, err := Copy(d.destination, d.source) + d.handleEOFOrError(err) +} + +func (d *streamDirection) handleEOFOrError(err error) { + if err == nil || err == io.EOF { + // Graceful EOF: close write direction only (half-close) + d.state.Store(stateClosed) + + // Try half-close on destination + if d.isUpload { + N.CloseWrite(d.connection.download.originSource) + } else { + N.CloseWrite(d.connection.upload.originSource) + } + + d.connection.checkBothClosed() + return + } + + // Error: close entire connection + d.closeWithError(err) +} + +func (d *streamDirection) closeWithError(err error) { + d.connection.closeWithError(err) +} + +func (c *streamConnection) checkBothClosed() { + uploadClosed := c.upload != nil && c.upload.state.Load() == stateClosed + downloadClosed := c.download != nil && c.download.state.Load() == stateClosed + + if uploadClosed && downloadClosed { + c.closeOnce.Do(func() { + c.cancel() + c.removeFromPoller() + + 
common.Close(c.upload.originSource) + common.Close(c.download.originSource) + + if c.onClose != nil { + c.onClose(nil) + } + + close(c.done) + }) + } +} + +func (c *streamConnection) closeWithError(err error) { + c.closeOnce.Do(func() { + c.err = err + c.cancel() + + if c.upload != nil { + c.upload.state.Store(stateClosed) + } + if c.download != nil { + c.download.state.Store(stateClosed) + } + + c.removeFromPoller() + + if c.upload != nil { + common.Close(c.upload.originSource) + } + if c.download != nil { + common.Close(c.download.originSource) + } + + if c.onClose != nil { + c.onClose(c.err) + } + + close(c.done) + }) +} + +func (c *streamConnection) removeFromPoller() { + if c.reactor.fdPoller == nil { + return + } + + if c.upload != nil && c.upload.pollable != nil { + c.reactor.fdPoller.Remove(c.upload.pollable.FD()) + } + if c.download != nil && c.download.pollable != nil { + c.reactor.fdPoller.Remove(c.download.pollable.FD()) + } +} diff --git a/common/bufio/stream_reactor_test.go b/common/bufio/stream_reactor_test.go new file mode 100644 index 00000000..3b3aae18 --- /dev/null +++ b/common/bufio/stream_reactor_test.go @@ -0,0 +1,463 @@ +//go:build darwin || linux || windows + +package bufio + +import ( + "context" + "crypto/rand" + "io" + "net" + "sync" + "sync/atomic" + "syscall" + "testing" + "time" + + "github.com/sagernet/sing/common/buf" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// fdConn wraps a net.Conn to implement StreamNotifier +type fdConn struct { + net.Conn + fd int +} + +func newFDConn(t *testing.T, conn net.Conn) *fdConn { + syscallConn, ok := conn.(syscall.Conn) + require.True(t, ok, "connection must implement syscall.Conn") + rawConn, err := syscallConn.SyscallConn() + require.NoError(t, err) + var fd int + err = rawConn.Control(func(f uintptr) { fd = int(f) }) + require.NoError(t, err) + return &fdConn{ + Conn: conn, + fd: fd, + } +} + +func (c *fdConn) FD() int { + return c.fd +} + +func (c *fdConn) Buffered() int { + return 0 +} + +// bufferedConn wraps a net.Conn with a buffer for testing StreamNotifier +type bufferedConn struct { + net.Conn + buffer *buf.Buffer + bufferMu sync.Mutex + fd int +} + +func newBufferedConn(t *testing.T, conn net.Conn) *bufferedConn { + bc := &bufferedConn{ + Conn: conn, + buffer: buf.New(), + } + if syscallConn, ok := conn.(syscall.Conn); ok { + rawConn, err := syscallConn.SyscallConn() + if err == nil { + rawConn.Control(func(f uintptr) { bc.fd = int(f) }) + } + } + return bc +} + +func (c *bufferedConn) Read(p []byte) (n int, err error) { + c.bufferMu.Lock() + if c.buffer.Len() > 0 { + n = copy(p, c.buffer.Bytes()) + c.buffer.Advance(n) + c.bufferMu.Unlock() + return n, nil + } + c.bufferMu.Unlock() + return c.Conn.Read(p) +} + +func (c *bufferedConn) FD() int { + return c.fd +} + +func (c *bufferedConn) Buffered() int { + c.bufferMu.Lock() + defer c.bufferMu.Unlock() + return c.buffer.Len() +} + +func (c *bufferedConn) Close() error { + c.buffer.Release() + return c.Conn.Close() +} + +func TestStreamReactor_Basic(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reactor := NewStreamReactor(ctx) + defer reactor.Close() + + // Create a pair of connected TCP connections + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + var serverConn net.Conn + var serverErr error + serverDone := make(chan struct{}) + go func() { + serverConn, serverErr = listener.Accept() + close(serverDone) + }() + + clientConn, err := 
net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer clientConn.Close() + + <-serverDone + require.NoError(t, serverErr) + defer serverConn.Close() + + // Create another pair for the destination + listener2, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener2.Close() + + var destServerConn net.Conn + var destServerErr error + destServerDone := make(chan struct{}) + go func() { + destServerConn, destServerErr = listener2.Accept() + close(destServerDone) + }() + + destClientConn, err := net.Dial("tcp", listener2.Addr().String()) + require.NoError(t, err) + defer destClientConn.Close() + + <-destServerDone + require.NoError(t, destServerErr) + defer destServerConn.Close() + + // Test data transfer + testData := make([]byte, 1024) + rand.Read(testData) + + closeDone := make(chan struct{}) + reactor.Copy(ctx, serverConn, destClientConn, func(err error) { + close(closeDone) + }) + + // Write from client to server, should pass through to dest + _, err = clientConn.Write(testData) + require.NoError(t, err) + + // Read from destServerConn + received := make([]byte, len(testData)) + _, err = io.ReadFull(destServerConn, received) + require.NoError(t, err) + assert.Equal(t, testData, received) + + // Test reverse direction + reverseData := make([]byte, 512) + rand.Read(reverseData) + + _, err = destServerConn.Write(reverseData) + require.NoError(t, err) + + reverseReceived := make([]byte, len(reverseData)) + _, err = io.ReadFull(clientConn, reverseReceived) + require.NoError(t, err) + assert.Equal(t, reverseData, reverseReceived) + + // Close and wait + clientConn.Close() + destServerConn.Close() + + select { + case <-closeDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for close callback") + } +} + +func TestStreamReactor_FDNotifier(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reactor := NewStreamReactor(ctx) + defer reactor.Close() + + // Create TCP connection pairs + server1, client1 := createTCPPair(t) + defer server1.Close() + defer client1.Close() + + server2, client2 := createTCPPair(t) + defer server2.Close() + defer client2.Close() + + // Wrap with FD notifier + fdServer1 := newFDConn(t, server1) + fdClient2 := newFDConn(t, client2) + + closeDone := make(chan struct{}) + reactor.Copy(ctx, fdServer1, fdClient2, func(err error) { + close(closeDone) + }) + + // Test data transfer + testData := make([]byte, 2048) + rand.Read(testData) + + _, err := client1.Write(testData) + require.NoError(t, err) + + received := make([]byte, len(testData)) + _, err = io.ReadFull(server2, received) + require.NoError(t, err) + assert.Equal(t, testData, received) + + client1.Close() + server2.Close() + + select { + case <-closeDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for close callback") + } +} + +func TestStreamReactor_BufferedReader(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reactor := NewStreamReactor(ctx) + defer reactor.Close() + + server1, client1 := createTCPPair(t) + defer server1.Close() + defer client1.Close() + + server2, client2 := createTCPPair(t) + defer server2.Close() + defer client2.Close() + + // Use buffered conn + bufferedServer1 := newBufferedConn(t, server1) + defer bufferedServer1.Close() + + closeDone := make(chan struct{}) + reactor.Copy(ctx, bufferedServer1, client2, func(err error) { + close(closeDone) + }) + + // Send data + testData := make([]byte, 1024) + rand.Read(testData) + + _, err := client1.Write(testData) + require.NoError(t, err) 
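+	// The payload should pass through the buffered wrapper (bufferedServer1)
+	// unchanged and arrive at server2 below.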
+ + received := make([]byte, len(testData)) + _, err = io.ReadFull(server2, received) + require.NoError(t, err) + assert.Equal(t, testData, received) + + client1.Close() + server2.Close() + + select { + case <-closeDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for close callback") + } +} + +func TestStreamReactor_HalfClose(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reactor := NewStreamReactor(ctx) + defer reactor.Close() + + server1, client1 := createTCPPair(t) + defer server1.Close() + defer client1.Close() + + server2, client2 := createTCPPair(t) + defer server2.Close() + defer client2.Close() + + closeDone := make(chan struct{}) + reactor.Copy(ctx, server1, client2, func(err error) { + close(closeDone) + }) + + // Send data in one direction + testData := make([]byte, 512) + rand.Read(testData) + + _, err := client1.Write(testData) + require.NoError(t, err) + + received := make([]byte, len(testData)) + _, err = io.ReadFull(server2, received) + require.NoError(t, err) + assert.Equal(t, testData, received) + + // Close client1's write side + if tcpConn, ok := client1.(*net.TCPConn); ok { + tcpConn.CloseWrite() + } else { + client1.Close() + } + + // The other direction should still work for a moment + reverseData := make([]byte, 256) + rand.Read(reverseData) + + _, err = server2.Write(reverseData) + require.NoError(t, err) + + // Eventually both will close + server2.Close() + client1.Close() + + select { + case <-closeDone: + // closeErr should be nil for graceful close + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for close callback") + } +} + +func TestStreamReactor_MultipleConnections(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reactor := NewStreamReactor(ctx) + defer reactor.Close() + + const numConnections = 10 + var wg sync.WaitGroup + var completedCount atomic.Int32 + + for i := 0; i < numConnections; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + + server1, client1 := createTCPPair(t) + defer server1.Close() + defer client1.Close() + + server2, client2 := createTCPPair(t) + defer server2.Close() + defer client2.Close() + + closeDone := make(chan struct{}) + reactor.Copy(ctx, server1, client2, func(err error) { + close(closeDone) + }) + + // Send unique data + testData := make([]byte, 256) + rand.Read(testData) + + _, err := client1.Write(testData) + require.NoError(t, err) + + received := make([]byte, len(testData)) + _, err = io.ReadFull(server2, received) + require.NoError(t, err) + assert.Equal(t, testData, received) + + client1.Close() + server2.Close() + + select { + case <-closeDone: + completedCount.Add(1) + case <-time.After(5 * time.Second): + t.Errorf("connection %d: timeout waiting for close callback", idx) + } + }(i) + } + + wg.Wait() + assert.Equal(t, int32(numConnections), completedCount.Load()) +} + +func TestStreamReactor_ReactorClose(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reactor := NewStreamReactor(ctx) + + server1, client1 := createTCPPair(t) + defer server1.Close() + defer client1.Close() + + server2, client2 := createTCPPair(t) + defer server2.Close() + defer client2.Close() + + closeDone := make(chan struct{}) + reactor.Copy(ctx, server1, client2, func(err error) { + close(closeDone) + }) + + // Send some data first + testData := make([]byte, 128) + rand.Read(testData) + + _, err := client1.Write(testData) + require.NoError(t, err) + + received := make([]byte, len(testData)) + _, err = io.ReadFull(server2, received) + require.NoError(t, 
err) + + // Close the reactor + reactor.Close() + + // Close connections + client1.Close() + server2.Close() + + select { + case <-closeDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for close callback after reactor close") + } +} + +// Helper function to create a connected TCP pair +func createTCPPair(t *testing.T) (net.Conn, net.Conn) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + var serverConn net.Conn + var serverErr error + serverDone := make(chan struct{}) + go func() { + serverConn, serverErr = listener.Accept() + close(serverDone) + }() + + clientConn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + + <-serverDone + require.NoError(t, serverErr) + + return serverConn, clientConn +} diff --git a/common/network/packet_pollable.go b/common/network/packet_pollable.go new file mode 100644 index 00000000..25ec095a --- /dev/null +++ b/common/network/packet_pollable.go @@ -0,0 +1,20 @@ +package network + +type PacketPollMode int + +const ( + PacketPollModeChannel PacketPollMode = iota + PacketPollModeFD +) + +// PacketPollable provides polling support for packet connections +type PacketPollable interface { + PollMode() PacketPollMode + PacketChannel() <-chan *PacketBuffer + FD() int +} + +// PacketPollableCreator creates a PacketPollable dynamically +type PacketPollableCreator interface { + CreatePacketPollable() (PacketPollable, bool) +} diff --git a/common/network/read_notifier.go b/common/network/read_notifier.go deleted file mode 100644 index 3a693b15..00000000 --- a/common/network/read_notifier.go +++ /dev/null @@ -1,21 +0,0 @@ -package network - -type ReadNotifier interface { - isReadNotifier() -} - -type ChannelNotifier struct { - Channel <-chan *PacketBuffer -} - -func (*ChannelNotifier) isReadNotifier() {} - -type FileDescriptorNotifier struct { - FD int -} - -func (*FileDescriptorNotifier) isReadNotifier() {} - -type ReadNotifierSource interface { - CreateReadNotifier() ReadNotifier -} diff --git a/common/network/stream_pollable.go b/common/network/stream_pollable.go new file mode 100644 index 00000000..6dfd58da --- /dev/null +++ b/common/network/stream_pollable.go @@ -0,0 +1,17 @@ +package network + +// StreamPollable provides reactor support for TCP stream connections +// Used by StreamReactor for idle detection via epoll/kqueue/IOCP +type StreamPollable interface { + // FD returns the file descriptor for reactor registration + FD() int + // Buffered returns the number of bytes in internal buffer + // Reactor must check this before returning to idle state + Buffered() int +} + +// StreamPollableCreator creates a StreamPollable dynamically +// Optional interface - prefer direct implementation of StreamPollable +type StreamPollableCreator interface { + CreateStreamPollable() (StreamPollable, bool) +} diff --git a/common/udpnat2/conn.go b/common/udpnat2/conn.go index 81710bfc..081cfc1d 100644 --- a/common/udpnat2/conn.go +++ b/common/udpnat2/conn.go @@ -141,6 +141,14 @@ func (c *natConn) Upstream() any { return c.writer } -func (c *natConn) CreateReadNotifier() N.ReadNotifier { - return &N.ChannelNotifier{Channel: c.packetChan} +func (c *natConn) PollMode() N.PacketPollMode { + return N.PacketPollModeChannel +} + +func (c *natConn) PacketChannel() <-chan *N.PacketBuffer { + return c.packetChan +} + +func (c *natConn) FD() int { + return -1 } From 5820f0e0e910d6ffe1486e5d01e95cfa136ad108 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 27 Dec 
2025 16:20:33 +0800 Subject: [PATCH 03/13] Fix reactor race conditions and resource leaks --- common/bufio/fd_poller_darwin.go | 70 +++++++++++++++++++------------ common/bufio/fd_poller_windows.go | 38 ++++++++++++++++- common/bufio/stream_reactor.go | 8 +++- 3 files changed, 85 insertions(+), 31 deletions(-) diff --git a/common/bufio/fd_poller_darwin.go b/common/bufio/fd_poller_darwin.go index fd2422f9..3f97e903 100644 --- a/common/bufio/fd_poller_darwin.go +++ b/common/bufio/fd_poller_darwin.go @@ -6,25 +6,29 @@ import ( "context" "sync" "sync/atomic" + "unsafe" "golang.org/x/sys/unix" ) type fdDemuxEntry struct { - fd int - handler FDHandler + fd int + registrationID uint64 + handler FDHandler } type FDPoller struct { - ctx context.Context - cancel context.CancelFunc - kqueueFD int - mutex sync.Mutex - entries map[int]*fdDemuxEntry - running bool - closed atomic.Bool - wg sync.WaitGroup - pipeFDs [2]int + ctx context.Context + cancel context.CancelFunc + kqueueFD int + mutex sync.Mutex + entries map[int]*fdDemuxEntry + registrationCounter uint64 + registrationToFD map[uint64]int + running bool + closed atomic.Bool + wg sync.WaitGroup + pipeFDs [2]int } func NewFDPoller(ctx context.Context) (*FDPoller, error) { @@ -69,11 +73,12 @@ func NewFDPoller(ctx context.Context) (*FDPoller, error) { ctx, cancel := context.WithCancel(ctx) poller := &FDPoller{ - ctx: ctx, - cancel: cancel, - kqueueFD: kqueueFD, - entries: make(map[int]*fdDemuxEntry), - pipeFDs: pipeFDs, + ctx: ctx, + cancel: cancel, + kqueueFD: kqueueFD, + entries: make(map[int]*fdDemuxEntry), + registrationToFD: make(map[uint64]int), + pipeFDs: pipeFDs, } return poller, nil } @@ -86,20 +91,26 @@ func (p *FDPoller) Add(handler FDHandler, fd int) error { return unix.EINVAL } + p.registrationCounter++ + registrationID := p.registrationCounter + _, err := unix.Kevent(p.kqueueFD, []unix.Kevent_t{{ Ident: uint64(fd), Filter: unix.EVFILT_READ, - Flags: unix.EV_ADD, + Flags: unix.EV_ADD | unix.EV_ONESHOT, + Udata: (*byte)(unsafe.Pointer(uintptr(registrationID))), }}, nil, nil) if err != nil { return err } entry := &fdDemuxEntry{ - fd: fd, - handler: handler, + fd: fd, + registrationID: registrationID, + handler: handler, } p.entries[fd] = entry + p.registrationToFD[registrationID] = fd if !p.running { p.running = true @@ -114,7 +125,7 @@ func (p *FDPoller) Remove(fd int) { p.mutex.Lock() defer p.mutex.Unlock() - _, ok := p.entries[fd] + entry, ok := p.entries[fd] if !ok { return } @@ -124,6 +135,7 @@ func (p *FDPoller) Remove(fd int) { Filter: unix.EVFILT_READ, Flags: unix.EV_DELETE, }}, nil, nil) + delete(p.registrationToFD, entry.registrationID) delete(p.entries, fd) } @@ -196,18 +208,22 @@ func (p *FDPoller) run() { continue } + registrationID := uint64(uintptr(unsafe.Pointer(event.Udata))) + p.mutex.Lock() - entry, ok := p.entries[fd] - if !ok { + mappedFD, ok := p.registrationToFD[registrationID] + if !ok || mappedFD != fd { + p.mutex.Unlock() + continue + } + + entry := p.entries[fd] + if entry == nil || entry.registrationID != registrationID { p.mutex.Unlock() continue } - unix.Kevent(p.kqueueFD, []unix.Kevent_t{{ - Ident: uint64(fd), - Filter: unix.EVFILT_READ, - Flags: unix.EV_DELETE, - }}, nil, nil) + delete(p.registrationToFD, registrationID) delete(p.entries, fd) p.mutex.Unlock() diff --git a/common/bufio/fd_poller_windows.go b/common/bufio/fd_poller_windows.go index 434b38b7..3dedeec5 100644 --- a/common/bufio/fd_poller_windows.go +++ b/common/bufio/fd_poller_windows.go @@ -22,6 +22,7 @@ type fdDemuxEntry struct { baseHandle 
windows.Handle registrationID uint64 cancelled bool + unpinned bool pinner wepoll.Pinner } @@ -138,7 +139,10 @@ func (p *FDPoller) Close() error { defer p.mutex.Unlock() for fd, entry := range p.entries { - entry.pinner.Unpin() + if !entry.unpinned { + entry.unpinned = true + entry.pinner.Unpin() + } delete(p.entries, fd) } @@ -153,6 +157,32 @@ func (p *FDPoller) Close() error { return nil } +func (p *FDPoller) drainCompletions(completions []wepoll.OverlappedEntry) { + for { + var numRemoved uint32 + err := wepoll.GetQueuedCompletionStatusEx(p.iocp, &completions[0], uint32(len(completions)), &numRemoved, 0, false) + if err != nil || numRemoved == 0 { + break + } + + for i := uint32(0); i < numRemoved; i++ { + event := completions[i] + if event.Overlapped == nil { + continue + } + + entry := (*fdDemuxEntry)(unsafe.Pointer(event.Overlapped)) + p.mutex.Lock() + if p.entries[entry.fd] == entry && !entry.unpinned { + entry.unpinned = true + entry.pinner.Unpin() + } + delete(p.entries, entry.fd) + p.mutex.Unlock() + } + } +} + func (p *FDPoller) run() { defer p.wg.Done() @@ -161,6 +191,7 @@ func (p *FDPoller) run() { for { select { case <-p.ctx.Done(): + p.drainCompletions(completions) p.mutex.Lock() p.running = false p.mutex.Unlock() @@ -193,7 +224,10 @@ func (p *FDPoller) run() { continue } - entry.pinner.Unpin() + if !entry.unpinned { + entry.unpinned = true + entry.pinner.Unpin() + } delete(p.entries, entry.fd) if entry.cancelled { diff --git a/common/bufio/stream_reactor.go b/common/bufio/stream_reactor.go index c5341aec..aaebc87a 100644 --- a/common/bufio/stream_reactor.go +++ b/common/bufio/stream_reactor.go @@ -331,9 +331,13 @@ func (d *streamDirection) handleEOFOrError(err error) { // Try half-close on destination if d.isUpload { - N.CloseWrite(d.connection.download.originSource) + if d.connection.download != nil { + N.CloseWrite(d.connection.download.originSource) + } } else { - N.CloseWrite(d.connection.upload.originSource) + if d.connection.upload != nil { + N.CloseWrite(d.connection.upload.originSource) + } } d.connection.checkBothClosed() From 91d05d26c3566eac817bd1a10afaa11e0088c882 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 27 Dec 2025 17:28:35 +0800 Subject: [PATCH 04/13] Improve packet reactor --- common/bufio/channel_poller.go | 143 ---------------------------- common/bufio/packet_reactor.go | 118 +++++++++++------------ common/bufio/packet_reactor_test.go | 10 +- common/network/packet_pollable.go | 19 ++-- common/udpnat2/conn.go | 121 +++++++++++++++-------- common/udpnat2/service.go | 18 +--- 6 files changed, 150 insertions(+), 279 deletions(-) delete mode 100644 common/bufio/channel_poller.go diff --git a/common/bufio/channel_poller.go b/common/bufio/channel_poller.go deleted file mode 100644 index ceece6a2..00000000 --- a/common/bufio/channel_poller.go +++ /dev/null @@ -1,143 +0,0 @@ -package bufio - -import ( - "context" - "reflect" - "sync" - "sync/atomic" - - N "github.com/sagernet/sing/common/network" -) - -type channelDemuxEntry struct { - channel <-chan *N.PacketBuffer - stream *reactorStream -} - -type ChannelPoller struct { - ctx context.Context - cancel context.CancelFunc - mutex sync.Mutex - entries map[<-chan *N.PacketBuffer]*channelDemuxEntry - updateChan chan struct{} - running bool - closed atomic.Bool - wg sync.WaitGroup -} - -func NewChannelPoller(ctx context.Context) *ChannelPoller { - ctx, cancel := context.WithCancel(ctx) - poller := &ChannelPoller{ - ctx: ctx, - cancel: cancel, - entries: make(map[<-chan 
*N.PacketBuffer]*channelDemuxEntry), - updateChan: make(chan struct{}, 1), - } - return poller -} - -func (p *ChannelPoller) Add(stream *reactorStream, channel <-chan *N.PacketBuffer) { - p.mutex.Lock() - - if p.closed.Load() { - p.mutex.Unlock() - return - } - - entry := &channelDemuxEntry{ - channel: channel, - stream: stream, - } - p.entries[channel] = entry - if !p.running { - p.running = true - p.wg.Add(1) - go p.run() - } - p.mutex.Unlock() - p.signalUpdate() -} - -func (p *ChannelPoller) Remove(channel <-chan *N.PacketBuffer) { - p.mutex.Lock() - delete(p.entries, channel) - p.mutex.Unlock() - p.signalUpdate() -} - -func (p *ChannelPoller) signalUpdate() { - select { - case p.updateChan <- struct{}{}: - default: - } -} - -func (p *ChannelPoller) Close() error { - p.mutex.Lock() - p.closed.Store(true) - p.mutex.Unlock() - - p.cancel() - p.signalUpdate() - p.wg.Wait() - return nil -} - -func (p *ChannelPoller) run() { - defer p.wg.Done() - - for { - p.mutex.Lock() - if len(p.entries) == 0 { - p.running = false - p.mutex.Unlock() - return - } - - cases := make([]reflect.SelectCase, 0, len(p.entries)+2) - - cases = append(cases, reflect.SelectCase{ - Dir: reflect.SelectRecv, - Chan: reflect.ValueOf(p.ctx.Done()), - }) - - cases = append(cases, reflect.SelectCase{ - Dir: reflect.SelectRecv, - Chan: reflect.ValueOf(p.updateChan), - }) - - entryList := make([]*channelDemuxEntry, 0, len(p.entries)) - for _, entry := range p.entries { - cases = append(cases, reflect.SelectCase{ - Dir: reflect.SelectRecv, - Chan: reflect.ValueOf(entry.channel), - }) - entryList = append(entryList, entry) - } - p.mutex.Unlock() - - chosen, recv, recvOK := reflect.Select(cases) - - switch chosen { - case 0: - p.mutex.Lock() - p.running = false - p.mutex.Unlock() - return - case 1: - continue - default: - entry := entryList[chosen-2] - p.mutex.Lock() - delete(p.entries, entry.channel) - p.mutex.Unlock() - - if recvOK { - packet := recv.Interface().(*N.PacketBuffer) - go entry.stream.runActiveLoop(packet) - } else { - go entry.stream.closeWithError(nil) - } - } - } -} diff --git a/common/bufio/packet_reactor.go b/common/bufio/packet_reactor.go index 15398b9a..8c8a6f0a 100644 --- a/common/bufio/packet_reactor.go +++ b/common/bufio/packet_reactor.go @@ -24,20 +24,18 @@ const ( ) type PacketReactor struct { - ctx context.Context - cancel context.CancelFunc - channelPoller *ChannelPoller - fdPoller *FDPoller - fdPollerOnce sync.Once - fdPollerErr error + ctx context.Context + cancel context.CancelFunc + fdPoller *FDPoller + fdPollerOnce sync.Once + fdPollerErr error } func NewPacketReactor(ctx context.Context) *PacketReactor { ctx, cancel := context.WithCancel(ctx) return &PacketReactor{ - ctx: ctx, - cancel: cancel, - channelPoller: NewChannelPoller(ctx), + ctx: ctx, + cancel: cancel, } } @@ -50,14 +48,10 @@ func (r *PacketReactor) getFDPoller() (*FDPoller, error) { func (r *PacketReactor) Close() error { r.cancel() - var errs []error - if r.channelPoller != nil { - errs = append(errs, r.channelPoller.Close()) - } if r.fdPoller != nil { - errs = append(errs, r.fdPoller.Close()) + return r.fdPoller.Close() } - return E.Errors(errs...) 
+ return nil } type reactorConnection struct { @@ -80,6 +74,7 @@ type reactorStream struct { destination N.PacketWriter originSource N.PacketReader + pushable N.PacketPushable pollable N.PacketPollable options N.ReadWaitOptions readWaiter N.PacketReadWaiter @@ -159,7 +154,9 @@ func (r *PacketReactor) prepareStream(conn *reactorConnection, source N.PacketRe stream.readWaiter.InitializeReadWaiter(stream.options) } - if pollable, ok := source.(N.PacketPollable); ok { + if pushable, ok := source.(N.PacketPushable); ok { + stream.pushable = pushable + } else if pollable, ok := source.(N.PacketPollable); ok { stream.pollable = pollable } else if creator, ok := source.(N.PacketPollableCreator); ok { stream.pollable, _ = creator.CreatePacketPollable() @@ -169,25 +166,32 @@ func (r *PacketReactor) prepareStream(conn *reactorConnection, source N.PacketRe } func (r *PacketReactor) registerStream(stream *reactorStream) { + if stream.pushable != nil { + stream.pushable.SetOnDataReady(func() { + if stream.state.CompareAndSwap(stateIdle, stateActive) { + go stream.runActiveLoop(nil) + } + }) + if stream.pushable.HasPendingData() { + if stream.state.CompareAndSwap(stateIdle, stateActive) { + go stream.runActiveLoop(nil) + } + } + return + } + if stream.pollable == nil { go stream.runLegacyCopy() return } - switch stream.pollable.PollMode() { - case N.PacketPollModeChannel: - r.channelPoller.Add(stream, stream.pollable.PacketChannel()) - case N.PacketPollModeFD: - fdPoller, err := r.getFDPoller() - if err != nil { - go stream.runLegacyCopy() - return - } - err = fdPoller.Add(stream, stream.pollable.FD()) - if err != nil { - go stream.runLegacyCopy() - } - default: + fdPoller, err := r.getFDPoller() + if err != nil { + go stream.runLegacyCopy() + return + } + err = fdPoller.Add(stream, stream.pollable.FD()) + if err != nil { go stream.runLegacyCopy() } } @@ -259,9 +263,18 @@ func (s *reactorStream) runActiveLoop(firstPacket *N.PacketBuffer) { if setter, ok := s.source.(interface{ SetReadDeadline(time.Time) error }); ok { setter.SetReadDeadline(time.Time{}) } - if s.state.CompareAndSwap(stateActive, stateIdle) { - s.returnToPool() + if !s.state.CompareAndSwap(stateActive, stateIdle) { + return + } + if s.pushable != nil { + if s.pushable.HasPendingData() { + if s.state.CompareAndSwap(stateIdle, stateActive) { + continue + } + } + return } + s.returnToPool() return } if !notFirstTime { @@ -310,30 +323,18 @@ func (s *reactorStream) returnToPool() { return } - if s.pollable == nil { + if s.pollable == nil || s.connection.reactor.fdPoller == nil { return } - switch s.pollable.PollMode() { - case N.PacketPollModeChannel: - channel := s.pollable.PacketChannel() - s.connection.reactor.channelPoller.Add(s, channel) - if s.state.Load() != stateIdle { - s.connection.reactor.channelPoller.Remove(channel) - } - case N.PacketPollModeFD: - if s.connection.reactor.fdPoller == nil { - return - } - fd := s.pollable.FD() - err := s.connection.reactor.fdPoller.Add(s, fd) - if err != nil { - s.closeWithError(err) - return - } - if s.state.Load() != stateIdle { - s.connection.reactor.fdPoller.Remove(fd) - } + fd := s.pollable.FD() + err := s.connection.reactor.fdPoller.Add(s, fd) + if err != nil { + s.closeWithError(err) + return + } + if s.state.Load() != stateIdle { + s.connection.reactor.fdPoller.Remove(fd) } } @@ -385,15 +386,8 @@ func (c *reactorConnection) removeFromPollers() { } func (c *reactorConnection) removeStreamFromPoller(stream *reactorStream) { - if stream == nil || stream.pollable == nil { + if stream == nil || 
stream.pollable == nil || c.reactor.fdPoller == nil { return } - switch stream.pollable.PollMode() { - case N.PacketPollModeChannel: - c.reactor.channelPoller.Remove(stream.pollable.PacketChannel()) - case N.PacketPollModeFD: - if c.reactor.fdPoller != nil { - c.reactor.fdPoller.Remove(stream.pollable.FD()) - } - } + c.reactor.fdPoller.Remove(stream.pollable.FD()) } diff --git a/common/bufio/packet_reactor_test.go b/common/bufio/packet_reactor_test.go index ea95b077..064865da 100644 --- a/common/bufio/packet_reactor_test.go +++ b/common/bufio/packet_reactor_test.go @@ -103,10 +103,6 @@ func (p *testPacketPipe) SetWriteDeadline(t time.Time) error { return nil } -func (p *testPacketPipe) PacketChannel() <-chan *N.PacketBuffer { - return p.inChan -} - func (p *testPacketPipe) send(data []byte, destination M.Socksaddr) { packet := N.NewPacketBuffer() newBuf := buf.NewSize(len(data)) @@ -255,10 +251,6 @@ func (c *channelPacketConn) SetReadDeadline(t time.Time) error { return nil } -func (c *channelPacketConn) PacketChannel() <-chan *N.PacketBuffer { - return c.packetChan -} - func (c *channelPacketConn) Close() error { c.closeOnce.Do(func() { close(c.done) @@ -551,7 +543,7 @@ func TestBatchCopy_FDPoller_DataIntegrity(t *testing.T) { assert.Equal(t, serverPair.sendHash, clientPair.recvHash, "server->client mismatch") } -func TestBatchCopy_ChannelPoller_DataIntegrity(t *testing.T) { +func TestBatchCopy_LegacyChannel_DataIntegrity(t *testing.T) { t.Parallel() clientConn, err := net.ListenPacket("udp", "127.0.0.1:0") diff --git a/common/network/packet_pollable.go b/common/network/packet_pollable.go index 25ec095a..80fb9f65 100644 --- a/common/network/packet_pollable.go +++ b/common/network/packet_pollable.go @@ -1,20 +1,19 @@ package network -type PacketPollMode int - -const ( - PacketPollModeChannel PacketPollMode = iota - PacketPollModeFD -) +// PacketPushable represents a packet source that receives pushed data +// from external code and notifies reactor via callback. +type PacketPushable interface { + SetOnDataReady(callback func()) + HasPendingData() bool +} -// PacketPollable provides polling support for packet connections +// PacketPollable provides FD-based polling for packet connections. +// Mirrors StreamPollable for consistency. type PacketPollable interface { - PollMode() PacketPollMode - PacketChannel() <-chan *PacketBuffer FD() int } -// PacketPollableCreator creates a PacketPollable dynamically +// PacketPollableCreator creates a PacketPollable dynamically. 
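+// The boolean result reports whether a pollable could be provided.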
type PacketPollableCreator interface { CreatePacketPollable() (PacketPollable, bool) } diff --git a/common/udpnat2/conn.go b/common/udpnat2/conn.go index 081cfc1d..300e4f35 100644 --- a/common/udpnat2/conn.go +++ b/common/udpnat2/conn.go @@ -13,7 +13,6 @@ import ( "github.com/sagernet/sing/common/canceler" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/pipe" "github.com/sagernet/sing/contrab/freelru" ) @@ -23,7 +22,11 @@ type Conn interface { canceler.PacketConn } -var _ Conn = (*natConn)(nil) +var ( + _ Conn = (*natConn)(nil) + _ N.PacketPushable = (*natConn)(nil) + _ N.PacketReadWaiter = (*natConn)(nil) +) type natConn struct { cache freelru.Cache[netip.AddrPort, *natConn] @@ -31,26 +34,37 @@ type natConn struct { localAddr M.Socksaddr handlerAccess sync.RWMutex handler N.UDPHandlerEx - packetChan chan *N.PacketBuffer - closeOnce sync.Once - doneChan chan struct{} - readDeadline pipe.Deadline readWaitOptions N.ReadWaitOptions + + dataQueue []*N.PacketBuffer + queueMutex sync.Mutex + onDataReady func() + + closeOnce sync.Once + doneChan chan struct{} } func (c *natConn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) { select { - case p := <-c.packetChan: - _, err = buffer.ReadOnceFrom(p.Buffer) - destination := p.Destination - p.Buffer.Release() - N.PutPacketBuffer(p) - return destination, err case <-c.doneChan: return M.Socksaddr{}, io.ErrClosedPipe - case <-c.readDeadline.Wait(): + default: + } + + c.queueMutex.Lock() + if len(c.dataQueue) == 0 { + c.queueMutex.Unlock() return M.Socksaddr{}, os.ErrDeadlineExceeded } + packet := c.dataQueue[0] + c.dataQueue = c.dataQueue[1:] + c.queueMutex.Unlock() + + _, err = buffer.ReadOnceFrom(packet.Buffer) + destination := packet.Destination + packet.Buffer.Release() + N.PutPacketBuffer(packet) + return destination, err } func (c *natConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { @@ -66,16 +80,24 @@ func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { select { - case packet := <-c.packetChan: - buffer = c.readWaitOptions.Copy(packet.Buffer) - destination = packet.Destination - N.PutPacketBuffer(packet) - return case <-c.doneChan: return nil, M.Socksaddr{}, io.ErrClosedPipe - case <-c.readDeadline.Wait(): + default: + } + + c.queueMutex.Lock() + if len(c.dataQueue) == 0 { + c.queueMutex.Unlock() return nil, M.Socksaddr{}, os.ErrDeadlineExceeded } + packet := c.dataQueue[0] + c.dataQueue = c.dataQueue[1:] + c.queueMutex.Unlock() + + buffer = c.readWaitOptions.Copy(packet.Buffer) + destination = packet.Destination + N.PutPacketBuffer(packet) + return } func (c *natConn) SetHandler(handler N.UDPHandlerEx) { @@ -83,16 +105,44 @@ func (c *natConn) SetHandler(handler N.UDPHandlerEx) { c.handler = handler c.readWaitOptions = N.NewReadWaitOptions(c.writer, handler) c.handlerAccess.Unlock() -fetch: - for { - select { - case packet := <-c.packetChan: - c.handler.NewPacketEx(packet.Buffer, packet.Destination) - N.PutPacketBuffer(packet) - continue fetch - default: - break fetch - } + + c.queueMutex.Lock() + pending := c.dataQueue + c.dataQueue = nil + c.queueMutex.Unlock() + + for _, packet := range pending { + handler.NewPacketEx(packet.Buffer, packet.Destination) + N.PutPacketBuffer(packet) + } +} + +func (c *natConn) SetOnDataReady(callback func()) { + c.queueMutex.Lock() + c.onDataReady = callback + c.queueMutex.Unlock() +} 
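+
+// Push/notify contract (illustrative sketch, not part of the exported API):
+// the reactor registers a callback via SetOnDataReady, the service queues
+// packets with PushPacket (which fires the callback), and the reactor then
+// drains the queue, for example through WaitReadPacket. drainLoop below is a
+// hypothetical stand-in for the reactor's read loop.
+//
+//	conn.SetOnDataReady(func() { go drainLoop(conn) })
+//	conn.PushPacket(packet) // queues the packet and invokes the callback
+//	buffer, destination, err := conn.WaitReadPacket()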
+ +func (c *natConn) HasPendingData() bool { + c.queueMutex.Lock() + defer c.queueMutex.Unlock() + return len(c.dataQueue) > 0 +} + +func (c *natConn) PushPacket(packet *N.PacketBuffer) { + c.queueMutex.Lock() + if len(c.dataQueue) >= 64 { + c.queueMutex.Unlock() + packet.Buffer.Release() + N.PutPacketBuffer(packet) + return + } + c.dataQueue = append(c.dataQueue, packet) + callback := c.onDataReady + c.queueMutex.Unlock() + + if callback != nil { + callback() } } @@ -129,7 +179,6 @@ func (c *natConn) SetDeadline(t time.Time) error { } func (c *natConn) SetReadDeadline(t time.Time) error { - c.readDeadline.Set(t) return nil } @@ -140,15 +189,3 @@ func (c *natConn) SetWriteDeadline(t time.Time) error { func (c *natConn) Upstream() any { return c.writer } - -func (c *natConn) PollMode() N.PacketPollMode { - return N.PacketPollModeChannel -} - -func (c *natConn) PacketChannel() <-chan *N.PacketBuffer { - return c.packetChan -} - -func (c *natConn) FD() int { - return -1 -} diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go index 3e3ce7d1..3cf7392d 100644 --- a/common/udpnat2/service.go +++ b/common/udpnat2/service.go @@ -8,7 +8,6 @@ import ( "github.com/sagernet/sing/common" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/pipe" "github.com/sagernet/sing/contrab/freelru" "github.com/sagernet/sing/contrab/maphash" ) @@ -57,12 +56,10 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati return nil, false } newConn := &natConn{ - cache: s.cache, - writer: writer, - localAddr: source, - packetChan: make(chan *N.PacketBuffer, 64), - doneChan: make(chan struct{}), - readDeadline: pipe.MakeDeadline(), + cache: s.cache, + writer: writer, + localAddr: source, + doneChan: make(chan struct{}), } go s.handler.NewPacketConnectionEx(ctx, newConn, source, destination, onClose) return newConn, true @@ -87,12 +84,7 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati Buffer: buffer, Destination: destination, } - select { - case conn.packetChan <- packet: - default: - packet.Buffer.Release() - N.PutPacketBuffer(packet) - } + conn.PushPacket(packet) } func (s *Service) Purge() { From d2f7a9c8e00479212474a2a46ecda6cf6f1dcdc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 27 Dec 2025 17:59:07 +0800 Subject: [PATCH 05/13] Fix reactor issues and add context propagation --- common/bufio/fd_poller_windows.go | 6 +++++ common/bufio/packet_reactor.go | 37 ++++++++++++++++++--------- common/bufio/stream_reactor.go | 39 +++++++++++++++++++--------- common/context_afterfunc.go | 11 ++++++++ common/context_afterfunc_compat.go | 41 ++++++++++++++++++++++++++++++ common/network/handshake.go | 2 +- common/udpnat2/conn.go | 8 ++++++ 7 files changed, 119 insertions(+), 25 deletions(-) create mode 100644 common/context_afterfunc.go create mode 100644 common/context_afterfunc_compat.go diff --git a/common/bufio/fd_poller_windows.go b/common/bufio/fd_poller_windows.go index 3dedeec5..a1cb368a 100644 --- a/common/bufio/fd_poller_windows.go +++ b/common/bufio/fd_poller_windows.go @@ -120,6 +120,12 @@ func (p *FDPoller) Remove(fd int) { if p.afd != nil { p.afd.Cancel(&entry.ioStatusBlock) } + + if !entry.unpinned { + entry.unpinned = true + entry.pinner.Unpin() + } + delete(p.entries, fd) } func (p *FDPoller) wakeup() { diff --git a/common/bufio/packet_reactor.go b/common/bufio/packet_reactor.go index 8c8a6f0a..86044e4b 100644 --- 
a/common/bufio/packet_reactor.go +++ b/common/bufio/packet_reactor.go @@ -55,12 +55,13 @@ func (r *PacketReactor) Close() error { } type reactorConnection struct { - ctx context.Context - cancel context.CancelFunc - reactor *PacketReactor - onClose N.CloseHandlerFunc - upload *reactorStream - download *reactorStream + ctx context.Context + cancel context.CancelFunc + reactor *PacketReactor + onClose N.CloseHandlerFunc + upload *reactorStream + download *reactorStream + stopReactorWatch func() bool closeOnce sync.Once done chan struct{} @@ -93,6 +94,9 @@ func (r *PacketReactor) Copy(ctx context.Context, source N.PacketConn, destinati onClose: onClose, done: make(chan struct{}), } + conn.stopReactorWatch = common.ContextAfterFunc(r.ctx, func() { + conn.closeWithError(r.ctx.Err()) + }) conn.upload = r.prepareStream(conn, source, destination) select { @@ -126,10 +130,12 @@ func (r *PacketReactor) prepareStream(conn *reactorConnection, source N.PacketRe if cachedReader, isCached := source.(N.CachedPacketReader); isCached { packet := cachedReader.ReadCachedPacket() if packet != nil { - dataLen := packet.Buffer.Len() - err := destination.WritePacket(packet.Buffer, packet.Destination) + buffer := packet.Buffer + dataLen := buffer.Len() + err := destination.WritePacket(buffer, packet.Destination) N.PutPacketBuffer(packet) if err != nil { + buffer.Leak() conn.closeWithError(err) return stream } @@ -151,7 +157,10 @@ func (r *PacketReactor) prepareStream(conn *reactorConnection, source N.PacketRe stream.readWaiter, _ = CreatePacketReadWaiter(source) if stream.readWaiter != nil { - stream.readWaiter.InitializeReadWaiter(stream.options) + needCopy := stream.readWaiter.InitializeReadWaiter(stream.options) + if needCopy { + stream.readWaiter = nil + } } if pushable, ok := source.(N.PacketPushable); ok { @@ -343,7 +352,7 @@ func (s *reactorStream) HandleFDEvent() { } func (s *reactorStream) runLegacyCopy() { - _, err := CopyPacket(s.destination, s.source) + _, err := CopyPacketWithCounters(s.destination, s.source, s.originSource, s.readCounters, s.writeCounters) s.closeWithError(err) } @@ -353,6 +362,12 @@ func (s *reactorStream) closeWithError(err error) { func (c *reactorConnection) closeWithError(err error) { c.closeOnce.Do(func() { + defer close(c.done) + + if c.stopReactorWatch != nil { + c.stopReactorWatch() + } + c.err = err c.cancel() @@ -375,8 +390,6 @@ func (c *reactorConnection) closeWithError(err error) { if c.onClose != nil { c.onClose(c.err) } - - close(c.done) }) } diff --git a/common/bufio/stream_reactor.go b/common/bufio/stream_reactor.go index aaebc87a..aceeb4aa 100644 --- a/common/bufio/stream_reactor.go +++ b/common/bufio/stream_reactor.go @@ -50,12 +50,13 @@ func (r *StreamReactor) Close() error { } type streamConnection struct { - ctx context.Context - cancel context.CancelFunc - reactor *StreamReactor - onClose N.CloseHandlerFunc - upload *streamDirection - download *streamDirection + ctx context.Context + cancel context.CancelFunc + reactor *StreamReactor + onClose N.CloseHandlerFunc + upload *streamDirection + download *streamDirection + stopReactorWatch func() bool closeOnce sync.Once done chan struct{} @@ -96,6 +97,9 @@ func (r *StreamReactor) Copy(ctx context.Context, source net.Conn, destination n onClose: onClose, done: make(chan struct{}), } + conn.stopReactorWatch = common.ContextAfterFunc(r.ctx, func() { + conn.closeWithError(r.ctx.Err()) + }) conn.upload = r.prepareDirection(conn, source, destination, source, true) select { @@ -171,7 +175,10 @@ func (r *StreamReactor) 
prepareDirection(conn *streamConnection, source io.Reade direction.readWaiter, _ = CreateReadWaiter(source) if direction.readWaiter != nil { - direction.readWaiter.InitializeReadWaiter(direction.options) + needCopy := direction.readWaiter.InitializeReadWaiter(direction.options) + if needCopy { + direction.readWaiter = nil + } } // Try to get stream pollable for FD-based idle detection @@ -320,7 +327,7 @@ func (d *streamDirection) HandleFDEvent() { } func (d *streamDirection) runLegacyCopy() { - _, err := Copy(d.destination, d.source) + _, err := CopyWithCounters(d.destination, d.source, d.originSource, d.readCounters, d.writeCounters, DefaultIncreaseBufferAfter, DefaultBatchSize) d.handleEOFOrError(err) } @@ -358,6 +365,12 @@ func (c *streamConnection) checkBothClosed() { if uploadClosed && downloadClosed { c.closeOnce.Do(func() { + defer close(c.done) + + if c.stopReactorWatch != nil { + c.stopReactorWatch() + } + c.cancel() c.removeFromPoller() @@ -367,14 +380,18 @@ func (c *streamConnection) checkBothClosed() { if c.onClose != nil { c.onClose(nil) } - - close(c.done) }) } } func (c *streamConnection) closeWithError(err error) { c.closeOnce.Do(func() { + defer close(c.done) + + if c.stopReactorWatch != nil { + c.stopReactorWatch() + } + c.err = err c.cancel() @@ -397,8 +414,6 @@ func (c *streamConnection) closeWithError(err error) { if c.onClose != nil { c.onClose(c.err) } - - close(c.done) }) } diff --git a/common/context_afterfunc.go b/common/context_afterfunc.go new file mode 100644 index 00000000..887dc148 --- /dev/null +++ b/common/context_afterfunc.go @@ -0,0 +1,11 @@ +//go:build go1.21 + +package common + +import "context" + +// ContextAfterFunc arranges to call f in its own goroutine after ctx is done. +// Returns a stop function that prevents f from being run. +func ContextAfterFunc(ctx context.Context, f func()) (stop func() bool) { + return context.AfterFunc(ctx, f) +} diff --git a/common/context_afterfunc_compat.go b/common/context_afterfunc_compat.go new file mode 100644 index 00000000..0a09ced7 --- /dev/null +++ b/common/context_afterfunc_compat.go @@ -0,0 +1,41 @@ +//go:build go1.20 && !go1.21 + +package common + +import ( + "context" + "sync" +) + +// ContextAfterFunc arranges to call f in its own goroutine after ctx is done. +// Returns a stop function that prevents f from being run. 
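+//
+// Illustrative usage (cleanup is a placeholder for the caller's teardown):
+//
+//	stop := ContextAfterFunc(ctx, func() { cleanup() })
+//	defer stop()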
+func ContextAfterFunc(ctx context.Context, f func()) (stop func() bool) { + stopCh := make(chan struct{}) + var once sync.Once + stopped := false + + go func() { + select { + case <-ctx.Done(): + once.Do(func() { + if !stopped { + f() + } + }) + case <-stopCh: + } + }() + + return func() bool { + select { + case <-ctx.Done(): + return false + default: + stopped = true + once.Do(func() { + close(stopCh) + }) + return true + } + } +} diff --git a/common/network/handshake.go b/common/network/handshake.go index c86a1b57..a032ab6f 100644 --- a/common/network/handshake.go +++ b/common/network/handshake.go @@ -30,7 +30,7 @@ func ReportHandshakeFailure(reporter any, err error) error { return E.Cause(err, "write handshake failure") }) } - return nil + return err } func CloseOnHandshakeFailure(reporter io.Closer, onClose CloseHandlerFunc, err error) error { diff --git a/common/udpnat2/conn.go b/common/udpnat2/conn.go index 300e4f35..cca511d3 100644 --- a/common/udpnat2/conn.go +++ b/common/udpnat2/conn.go @@ -161,6 +161,14 @@ func (c *natConn) SetTimeout(timeout time.Duration) bool { func (c *natConn) Close() error { c.closeOnce.Do(func() { close(c.doneChan) + + c.queueMutex.Lock() + pending := c.dataQueue + c.dataQueue = nil + c.onDataReady = nil + c.queueMutex.Unlock() + + N.ReleaseMultiPacketBuffer(pending) common.Close(c.handler) }) return nil From 4353aac964b0b49a4956216cc976d7c24f74774a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 27 Dec 2025 18:21:06 +0800 Subject: [PATCH 06/13] Fix Windows test failures --- common/bufio/fd_demux_windows_test.go | 24 ++++++++++++------------ common/wepoll/afd_windows.go | 1 + 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/common/bufio/fd_demux_windows_test.go b/common/bufio/fd_demux_windows_test.go index 030a3dfb..4f7ccce4 100644 --- a/common/bufio/fd_demux_windows_test.go +++ b/common/bufio/fd_demux_windows_test.go @@ -28,7 +28,7 @@ func getSocketFD(t *testing.T, conn net.PacketConn) int { func TestFDDemultiplexer_Create(t *testing.T) { t.Parallel() - demux, err := NewFDDemultiplexer(context.Background()) + demux, err := NewFDPoller(context.Background()) require.NoError(t, err) err = demux.Close() @@ -38,11 +38,11 @@ func TestFDDemultiplexer_Create(t *testing.T) { func TestFDDemultiplexer_CreateMultiple(t *testing.T) { t.Parallel() - demux1, err := NewFDDemultiplexer(context.Background()) + demux1, err := NewFDPoller(context.Background()) require.NoError(t, err) defer demux1.Close() - demux2, err := NewFDDemultiplexer(context.Background()) + demux2, err := NewFDPoller(context.Background()) require.NoError(t, err) defer demux2.Close() } @@ -50,7 +50,7 @@ func TestFDDemultiplexer_CreateMultiple(t *testing.T) { func TestFDDemultiplexer_AddRemove(t *testing.T) { t.Parallel() - demux, err := NewFDDemultiplexer(context.Background()) + demux, err := NewFDPoller(context.Background()) require.NoError(t, err) defer demux.Close() @@ -71,7 +71,7 @@ func TestFDDemultiplexer_AddRemove(t *testing.T) { func TestFDDemultiplexer_RapidAddRemove(t *testing.T) { t.Parallel() - demux, err := NewFDDemultiplexer(context.Background()) + demux, err := NewFDPoller(context.Background()) require.NoError(t, err) defer demux.Close() @@ -95,7 +95,7 @@ func TestFDDemultiplexer_RapidAddRemove(t *testing.T) { func TestFDDemultiplexer_ConcurrentAccess(t *testing.T) { t.Parallel() - demux, err := NewFDDemultiplexer(context.Background()) + demux, err := NewFDPoller(context.Background()) require.NoError(t, err) defer demux.Close() @@ -136,7 +136,7 @@ 
func TestFDDemultiplexer_ReceiveEvent(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - demux, err := NewFDDemultiplexer(ctx) + demux, err := NewFDPoller(ctx) require.NoError(t, err) defer demux.Close() @@ -182,7 +182,7 @@ func TestFDDemultiplexer_ReceiveEvent(t *testing.T) { func TestFDDemultiplexer_CloseWhilePolling(t *testing.T) { t.Parallel() - demux, err := NewFDDemultiplexer(context.Background()) + demux, err := NewFDPoller(context.Background()) require.NoError(t, err) conn, err := net.ListenPacket("udp", "127.0.0.1:0") @@ -213,7 +213,7 @@ func TestFDDemultiplexer_CloseWhilePolling(t *testing.T) { func TestFDDemultiplexer_RemoveNonExistent(t *testing.T) { t.Parallel() - demux, err := NewFDDemultiplexer(context.Background()) + demux, err := NewFDPoller(context.Background()) require.NoError(t, err) defer demux.Close() @@ -223,7 +223,7 @@ func TestFDDemultiplexer_RemoveNonExistent(t *testing.T) { func TestFDDemultiplexer_AddAfterClose(t *testing.T) { t.Parallel() - demux, err := NewFDDemultiplexer(context.Background()) + demux, err := NewFDPoller(context.Background()) require.NoError(t, err) err = demux.Close() @@ -243,7 +243,7 @@ func TestFDDemultiplexer_AddAfterClose(t *testing.T) { func TestFDDemultiplexer_MultipleSocketsSimultaneous(t *testing.T) { t.Parallel() - demux, err := NewFDDemultiplexer(context.Background()) + demux, err := NewFDPoller(context.Background()) require.NoError(t, err) defer demux.Close() @@ -274,7 +274,7 @@ func TestFDDemultiplexer_ContextCancellation(t *testing.T) { t.Parallel() ctx, cancel := context.WithCancel(context.Background()) - demux, err := NewFDDemultiplexer(ctx) + demux, err := NewFDPoller(ctx) require.NoError(t, err) conn, err := net.ListenPacket("udp", "127.0.0.1:0") diff --git a/common/wepoll/afd_windows.go b/common/wepoll/afd_windows.go index aad3391a..1547cc8b 100644 --- a/common/wepoll/afd_windows.go +++ b/common/wepoll/afd_windows.go @@ -75,6 +75,7 @@ func (a *AFD) Poll(baseSocket windows.Handle, events uint32, iosb *windows.IO_ST pollInfo.Handles[0].Events = events pollInfo.Handles[0].Status = 0 + iosb.Status = windows.NTStatus(STATUS_PENDING) size := uint32(unsafe.Sizeof(*pollInfo)) err := NtDeviceIoControlFile( From b79cddcdf2e5de50039c41432f3d2be194a99264 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 27 Dec 2025 19:17:32 +0800 Subject: [PATCH 07/13] Add Create API for reactor pollables --- common/bufio/packet_reactor.go | 52 ++++++++++++++++++++++++++++++---- common/bufio/stream_reactor.go | 25 ++++++++++++---- 2 files changed, 66 insertions(+), 11 deletions(-) diff --git a/common/bufio/packet_reactor.go b/common/bufio/packet_reactor.go index 86044e4b..f9508289 100644 --- a/common/bufio/packet_reactor.go +++ b/common/bufio/packet_reactor.go @@ -23,6 +23,49 @@ const ( stateClosed int32 = 2 ) +func CreatePacketPushable(reader N.PacketReader) (N.PacketPushable, bool) { + if pushable, ok := reader.(N.PacketPushable); ok { + return pushable, true + } + if u, ok := reader.(N.ReaderWithUpstream); !ok || !u.ReaderReplaceable() { + return nil, false + } + if u, ok := reader.(N.WithUpstreamReader); ok { + if upstream, ok := u.UpstreamReader().(N.PacketReader); ok { + return CreatePacketPushable(upstream) + } + } + if u, ok := reader.(common.WithUpstream); ok { + if upstream, ok := u.Upstream().(N.PacketReader); ok { + return CreatePacketPushable(upstream) + } + } + return nil, false +} + +func CreatePacketPollable(reader N.PacketReader) (N.PacketPollable, bool) { + if pollable, ok := 
reader.(N.PacketPollable); ok { + return pollable, true + } + if creator, ok := reader.(N.PacketPollableCreator); ok { + return creator.CreatePacketPollable() + } + if u, ok := reader.(N.ReaderWithUpstream); !ok || !u.ReaderReplaceable() { + return nil, false + } + if u, ok := reader.(N.WithUpstreamReader); ok { + if upstream, ok := u.UpstreamReader().(N.PacketReader); ok { + return CreatePacketPollable(upstream) + } + } + if u, ok := reader.(common.WithUpstream); ok { + if upstream, ok := u.Upstream().(N.PacketReader); ok { + return CreatePacketPollable(upstream) + } + } + return nil, false +} + type PacketReactor struct { ctx context.Context cancel context.CancelFunc @@ -163,12 +206,9 @@ func (r *PacketReactor) prepareStream(conn *reactorConnection, source N.PacketRe } } - if pushable, ok := source.(N.PacketPushable); ok { - stream.pushable = pushable - } else if pollable, ok := source.(N.PacketPollable); ok { - stream.pollable = pollable - } else if creator, ok := source.(N.PacketPollableCreator); ok { - stream.pollable, _ = creator.CreatePacketPollable() + stream.pushable, _ = CreatePacketPushable(source) + if stream.pushable == nil { + stream.pollable, _ = CreatePacketPollable(source) } return stream diff --git a/common/bufio/stream_reactor.go b/common/bufio/stream_reactor.go index aceeb4aa..d30f7f38 100644 --- a/common/bufio/stream_reactor.go +++ b/common/bufio/stream_reactor.go @@ -18,6 +18,25 @@ const ( streamBatchReadTimeout = 250 * time.Millisecond ) +func CreateStreamPollable(reader io.Reader) (N.StreamPollable, bool) { + if pollable, ok := reader.(N.StreamPollable); ok { + return pollable, true + } + if creator, ok := reader.(N.StreamPollableCreator); ok { + return creator.CreateStreamPollable() + } + if u, ok := reader.(N.ReaderWithUpstream); !ok || !u.ReaderReplaceable() { + return nil, false + } + if u, ok := reader.(N.WithUpstreamReader); ok { + return CreateStreamPollable(u.UpstreamReader().(io.Reader)) + } + if u, ok := reader.(common.WithUpstream); ok { + return CreateStreamPollable(u.Upstream().(io.Reader)) + } + return nil, false +} + type StreamReactor struct { ctx context.Context cancel context.CancelFunc @@ -182,11 +201,7 @@ func (r *StreamReactor) prepareDirection(conn *streamConnection, source io.Reade } // Try to get stream pollable for FD-based idle detection - if pollable, ok := source.(N.StreamPollable); ok { - direction.pollable = pollable - } else if creator, ok := source.(N.StreamPollableCreator); ok { - direction.pollable, _ = creator.CreateStreamPollable() - } + direction.pollable, _ = CreateStreamPollable(source) return direction } From 99b7f22820bbfca999215270c6842c6c4b269ea3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 27 Dec 2025 19:58:08 +0800 Subject: [PATCH 08/13] Improve reactor tests and fix PacketPushable bug --- common/bufio/packet_reactor.go | 8 +- common/bufio/packet_reactor_test.go | 671 ++++++++++++++++++++++++++++ common/bufio/stream_reactor_test.go | 439 ++++++++++++++++++ 3 files changed, 1112 insertions(+), 6 deletions(-) diff --git a/common/bufio/packet_reactor.go b/common/bufio/packet_reactor.go index f9508289..3bf8180a 100644 --- a/common/bufio/packet_reactor.go +++ b/common/bufio/packet_reactor.go @@ -217,14 +217,10 @@ func (r *PacketReactor) prepareStream(conn *reactorConnection, source N.PacketRe func (r *PacketReactor) registerStream(stream *reactorStream) { if stream.pushable != nil { stream.pushable.SetOnDataReady(func() { - if stream.state.CompareAndSwap(stateIdle, stateActive) { - go 
stream.runActiveLoop(nil) - } + go stream.runActiveLoop(nil) }) if stream.pushable.HasPendingData() { - if stream.state.CompareAndSwap(stateIdle, stateActive) { - go stream.runActiveLoop(nil) - } + go stream.runActiveLoop(nil) } return } diff --git a/common/bufio/packet_reactor_test.go b/common/bufio/packet_reactor_test.go index 064865da..0fd8d5ed 100644 --- a/common/bufio/packet_reactor_test.go +++ b/common/bufio/packet_reactor_test.go @@ -1475,3 +1475,674 @@ func (c *legacyPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksad } return c.targetAddr, nil } + +// pushablePacketPipe implements PacketPushable for testing with deadline support +type pushablePacketPipe struct { + inChan chan *N.PacketBuffer + outChan chan *N.PacketBuffer + localAddr M.Socksaddr + closed atomic.Bool + closeOnce sync.Once + done chan struct{} + onDataReady func() + mutex sync.Mutex + deadlineLock sync.Mutex + deadline time.Time + deadlineChan chan struct{} +} + +func newPushablePacketPipe(localAddr M.Socksaddr) *pushablePacketPipe { + return &pushablePacketPipe{ + inChan: make(chan *N.PacketBuffer, 256), + outChan: make(chan *N.PacketBuffer, 256), + localAddr: localAddr, + done: make(chan struct{}), + deadlineChan: make(chan struct{}), + } +} + +func (p *pushablePacketPipe) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + p.deadlineLock.Lock() + deadline := p.deadline + deadlineChan := p.deadlineChan + p.deadlineLock.Unlock() + + var timer <-chan time.Time + if !deadline.IsZero() { + d := time.Until(deadline) + if d <= 0 { + return M.Socksaddr{}, os.ErrDeadlineExceeded + } + t := time.NewTimer(d) + defer t.Stop() + timer = t.C + } + + select { + case packet, ok := <-p.inChan: + if !ok { + return M.Socksaddr{}, io.EOF + } + _, err = buffer.ReadOnceFrom(packet.Buffer) + destination = packet.Destination + packet.Buffer.Release() + N.PutPacketBuffer(packet) + return destination, err + case <-p.done: + return M.Socksaddr{}, net.ErrClosed + case <-deadlineChan: + return M.Socksaddr{}, os.ErrDeadlineExceeded + case <-timer: + return M.Socksaddr{}, os.ErrDeadlineExceeded + } +} + +func (p *pushablePacketPipe) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + if p.closed.Load() { + buffer.Release() + return net.ErrClosed + } + packet := N.NewPacketBuffer() + newBuf := buf.NewSize(buffer.Len()) + newBuf.Write(buffer.Bytes()) + packet.Buffer = newBuf + packet.Destination = destination + buffer.Release() + select { + case p.outChan <- packet: + return nil + case <-p.done: + packet.Buffer.Release() + N.PutPacketBuffer(packet) + return net.ErrClosed + } +} + +func (p *pushablePacketPipe) Close() error { + p.closeOnce.Do(func() { + p.closed.Store(true) + close(p.done) + }) + return nil +} + +func (p *pushablePacketPipe) LocalAddr() net.Addr { + return p.localAddr.UDPAddr() +} + +func (p *pushablePacketPipe) SetDeadline(t time.Time) error { + return p.SetReadDeadline(t) +} + +func (p *pushablePacketPipe) SetReadDeadline(t time.Time) error { + p.deadlineLock.Lock() + p.deadline = t + if p.deadlineChan != nil { + close(p.deadlineChan) + } + p.deadlineChan = make(chan struct{}) + p.deadlineLock.Unlock() + return nil +} + +func (p *pushablePacketPipe) SetWriteDeadline(t time.Time) error { + return nil +} + +func (p *pushablePacketPipe) SetOnDataReady(callback func()) { + p.mutex.Lock() + p.onDataReady = callback + p.mutex.Unlock() +} + +func (p *pushablePacketPipe) HasPendingData() bool { + return len(p.inChan) > 0 +} + +func (p *pushablePacketPipe) triggerDataReady() { + 
p.mutex.Lock() + callback := p.onDataReady + p.mutex.Unlock() + if callback != nil { + callback() + } +} + +func (p *pushablePacketPipe) send(data []byte, destination M.Socksaddr) { + packet := N.NewPacketBuffer() + newBuf := buf.NewSize(len(data)) + newBuf.Write(data) + packet.Buffer = newBuf + packet.Destination = destination + p.inChan <- packet +} + +func (p *pushablePacketPipe) sendWithNotify(data []byte, destination M.Socksaddr) { + p.send(data, destination) + p.triggerDataReady() +} + +func (p *pushablePacketPipe) receive() (*N.PacketBuffer, bool) { + select { + case packet, ok := <-p.outChan: + return packet, ok + case <-p.done: + return nil, false + } +} + +// failingPacketWriter fails after N writes +type failingPacketWriter struct { + N.PacketConn + failAfter int + writeCount atomic.Int32 +} + +func (w *failingPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + if w.writeCount.Add(1) > int32(w.failAfter) { + buffer.Release() + return errors.New("simulated packet write error") + } + return w.PacketConn.WritePacket(buffer, destination) +} + +// errorPacketReader returns error on ReadPacket +type errorPacketReader struct { + N.PacketConn + readError error +} + +func (r *errorPacketReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + return M.Socksaddr{}, r.readError +} + +func TestPacketReactor_Pushable_Basic(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 61001) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 61002) + + pipeA := newPushablePacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + copyDone := make(chan struct{}) + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + close(copyDone) + }() + + time.Sleep(50 * time.Millisecond) + + const times = 20 + const chunkSize = 1000 + + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + + recvDone := make(chan struct{}) + go func() { + defer close(recvDone) + for i := 0; i < times; i++ { + packet, ok := pipeB.receive() + if !ok { + return + } + hash := md5.Sum(packet.Buffer.Bytes()) + recvHash[int(packet.Buffer.Byte(0))] = hash[:] + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + }() + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + pipeA.sendWithNotify(data, addr2) + time.Sleep(5 * time.Millisecond) + } + + select { + case <-recvDone: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for receive") + } + + assert.Equal(t, sendHash, recvHash, "data mismatch") + + pipeA.Close() + pipeB.Close() + + select { + case <-copyDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for copy to complete") + } +} + +func TestPacketReactor_Pushable_HasPendingData(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 61011) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 61012) + + pipeA := newPushablePacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + // Pre-load data before starting the reactor + const preloadCount = 5 + preloadHashes := make([][]byte, preloadCount) + for i := 0; i < preloadCount; i++ { + data := make([]byte, 100) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + preloadHashes[i] = hash[:] + pipeA.send(data, addr2) + } + + copier := 
NewPacketReactor(context.Background()) + defer copier.Close() + + copyDone := make(chan struct{}) + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + close(copyDone) + }() + + // Wait for the reactor to detect pending data and process it + receivedHashes := make([][]byte, 0, preloadCount) + timeout := time.After(5 * time.Second) + + for i := 0; i < preloadCount; i++ { + select { + case packet := <-pipeB.outChan: + hash := md5.Sum(packet.Buffer.Bytes()) + receivedHashes = append(receivedHashes, hash[:]) + packet.Buffer.Release() + N.PutPacketBuffer(packet) + case <-timeout: + t.Fatalf("timeout: only received %d/%d packets", len(receivedHashes), preloadCount) + } + } + + // Verify all preloaded packets were received + assert.Equal(t, len(preloadHashes), len(receivedHashes), "should receive all preloaded packets") + + pipeA.Close() + pipeB.Close() + + select { + case <-copyDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for copy to complete") + } +} + +func TestPacketReactor_Pushable_TimeoutResume(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 61021) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 61022) + + pipeA := newPushablePacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + copyDone := make(chan struct{}) + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + close(copyDone) + }() + + time.Sleep(50 * time.Millisecond) + + // Send first batch + const batchSize = 5 + for i := 0; i < batchSize; i++ { + data := make([]byte, 100) + rand.Read(data) + pipeA.sendWithNotify(data, addr2) + } + + // Receive first batch + for i := 0; i < batchSize; i++ { + select { + case packet := <-pipeB.outChan: + packet.Buffer.Release() + N.PutPacketBuffer(packet) + case <-time.After(5 * time.Second): + t.Fatalf("timeout receiving first batch packet %d", i) + } + } + + // Wait longer than the batch timeout (250ms) to trigger return to idle + time.Sleep(400 * time.Millisecond) + + // Send second batch - should trigger data ready callback + for i := 0; i < batchSize; i++ { + data := make([]byte, 100) + rand.Read(data) + pipeA.sendWithNotify(data, addr2) + } + + // Receive second batch + for i := 0; i < batchSize; i++ { + select { + case packet := <-pipeB.outChan: + packet.Buffer.Release() + N.PutPacketBuffer(packet) + case <-time.After(5 * time.Second): + t.Fatalf("timeout receiving second batch packet %d", i) + } + } + + pipeA.Close() + pipeB.Close() + + select { + case <-copyDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for copy to complete") + } +} + +func TestPacketReactor_WriteError(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 61031) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 61032) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + // Wrap destination with failing writer that fails after 3 packets + failingDest := &failingPacketWriter{PacketConn: pipeB, failAfter: 3} + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + var capturedErr error + var errMu sync.Mutex + closeDone := make(chan struct{}) + + go func() { + copier.Copy(context.Background(), pipeA, failingDest, func(err error) { + errMu.Lock() + capturedErr = err + errMu.Unlock() + close(closeDone) + }) + }() + + time.Sleep(50 * time.Millisecond) + + // Send packets until 
failure + for i := 0; i < 10; i++ { + data := make([]byte, 100) + rand.Read(data) + pipeA.send(data, addr2) + time.Sleep(10 * time.Millisecond) + } + + // Wait for close callback + select { + case <-closeDone: + errMu.Lock() + err := capturedErr + errMu.Unlock() + assert.NotNil(t, err, "expected error to be propagated") + assert.Contains(t, err.Error(), "simulated packet write error") + case <-time.After(5 * time.Second): + pipeA.Close() + pipeB.Close() + t.Fatal("timeout waiting for close callback") + } +} + +func TestPacketReactor_ReadError(t *testing.T) { + t.Parallel() + + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 61042) + + pipeA := newTestPacketPipe(M.ParseSocksaddrHostPort("127.0.0.1", 61041)) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + // Wrap source with error-returning reader + readErr := errors.New("simulated packet read error") + errorSrc := &errorPacketReader{PacketConn: pipeA, readError: readErr} + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + var capturedErr error + var errMu sync.Mutex + closeDone := make(chan struct{}) + + go func() { + copier.Copy(context.Background(), errorSrc, pipeB, func(err error) { + errMu.Lock() + capturedErr = err + errMu.Unlock() + close(closeDone) + }) + }() + + select { + case <-closeDone: + errMu.Lock() + err := capturedErr + errMu.Unlock() + assert.NotNil(t, err, "expected error to be propagated") + assert.Contains(t, err.Error(), "simulated packet read error") + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for close callback") + } +} + +func TestPacketReactor_StateMachine_ConcurrentWakeup(t *testing.T) { + t.Parallel() + + clientConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer clientConn.Close() + + proxyAConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + proxyBConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + serverConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer serverConn.Close() + + serverAddr := M.SocksaddrFromNet(serverConn.LocalAddr()) + clientAddr := M.SocksaddrFromNet(clientConn.LocalAddr()) + proxyAAddr := M.SocksaddrFromNet(proxyAConn.LocalAddr()) + + proxyA := newFDPacketConn(t, proxyAConn, serverAddr) + proxyB := newFDPacketConn(t, proxyBConn, clientAddr) + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + copyDone := make(chan struct{}) + go func() { + copier.Copy(context.Background(), proxyA, proxyB, nil) + close(copyDone) + }() + + time.Sleep(50 * time.Millisecond) + + // Send many packets rapidly to stress the state machine + const numPackets = 100 + var wg sync.WaitGroup + + // Multiple goroutines sending packets concurrently + for g := 0; g < 5; g++ { + wg.Add(1) + go func(goroutineID int) { + defer wg.Done() + for i := 0; i < numPackets/5; i++ { + data := make([]byte, 100) + rand.Read(data) + clientConn.WriteTo(data, proxyAAddr.UDPAddr()) + } + }(g) + } + + wg.Wait() + + // Give time for packets to be processed + time.Sleep(500 * time.Millisecond) + + // Cleanup + proxyAConn.Close() + proxyBConn.Close() + + select { + case <-copyDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for copy to complete") + } +} + +func TestPacketReactor_StateMachine_CloseWhileActive(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 61051) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 61052) + + pipeA := newTestPacketPipe(addr1) + 
pipeB := newTestPacketPipe(addr2) + + copier := NewPacketReactor(context.Background()) + + copyDone := make(chan struct{}) + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + close(copyDone) + }() + + time.Sleep(50 * time.Millisecond) + + // Start sending packets continuously + stopSend := make(chan struct{}) + go func() { + for { + select { + case <-stopSend: + return + default: + data := make([]byte, 100) + rand.Read(data) + pipeA.send(data, addr2) + time.Sleep(1 * time.Millisecond) + } + } + }() + + // Give time for active processing + time.Sleep(100 * time.Millisecond) + + // Close while actively processing + pipeA.Close() + pipeB.Close() + copier.Close() + close(stopSend) + + // Verify no deadlock + select { + case <-copyDone: + case <-time.After(5 * time.Second): + t.Fatal("deadlock detected: Copy did not return after close") + } +} + +func TestPacketReactor_Counters(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 61061) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 61062) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + copier := NewPacketReactor(context.Background()) + defer copier.Close() + + copyDone := make(chan struct{}) + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + close(copyDone) + }() + + time.Sleep(50 * time.Millisecond) + + // Send packets with known sizes + const numPackets = 10 + packetSizes := []int{100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} + totalBytes := 0 + for _, size := range packetSizes { + totalBytes += size + } + + recvDone := make(chan struct{}) + receivedBytes := 0 + go func() { + defer close(recvDone) + for i := 0; i < numPackets; i++ { + packet, ok := pipeB.receive() + if !ok { + return + } + receivedBytes += packet.Buffer.Len() + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + }() + + for i := 0; i < numPackets; i++ { + data := make([]byte, packetSizes[i]) + rand.Read(data) + pipeA.send(data, addr2) + time.Sleep(5 * time.Millisecond) + } + + select { + case <-recvDone: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for receive") + } + + assert.Equal(t, totalBytes, receivedBytes, "total bytes received should match sent") + + pipeA.Close() + pipeB.Close() + + select { + case <-copyDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for copy to complete") + } +} diff --git a/common/bufio/stream_reactor_test.go b/common/bufio/stream_reactor_test.go index 3b3aae18..cfe481c9 100644 --- a/common/bufio/stream_reactor_test.go +++ b/common/bufio/stream_reactor_test.go @@ -4,7 +4,9 @@ package bufio import ( "context" + "crypto/md5" "crypto/rand" + "errors" "io" "net" "sync" @@ -14,6 +16,7 @@ import ( "time" "github.com/sagernet/sing/common/buf" + N "github.com/sagernet/sing/common/network" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -461,3 +464,439 @@ func createTCPPair(t *testing.T) (net.Conn, net.Conn) { return serverConn, clientConn } + +// failingConn wraps net.Conn and fails writes after N calls +type failingConn struct { + net.Conn + failAfter int + writeCount atomic.Int32 +} + +func (c *failingConn) Write(p []byte) (int, error) { + if c.writeCount.Add(1) > int32(c.failAfter) { + return 0, errors.New("simulated write error") + } + return c.Conn.Write(p) +} + +// errorConn returns error on Read +type errorConn struct { + net.Conn + readError error +} + +func (c *errorConn) Read(p []byte) (int, error) { + return 0, 
c.readError +} + +// countingConn wraps net.Conn with read/write counters +type countingConn struct { + net.Conn + readCount atomic.Int64 + writeCount atomic.Int64 +} + +func (c *countingConn) Read(p []byte) (int, error) { + n, err := c.Conn.Read(p) + c.readCount.Add(int64(n)) + return n, err +} + +func (c *countingConn) Write(p []byte) (int, error) { + n, err := c.Conn.Write(p) + c.writeCount.Add(int64(n)) + return n, err +} + +func (c *countingConn) UnwrapReader() (io.Reader, []N.CountFunc) { + return c.Conn, []N.CountFunc{func(n int64) { c.readCount.Add(n) }} +} + +func (c *countingConn) UnwrapWriter() (io.Writer, []N.CountFunc) { + return c.Conn, []N.CountFunc{func(n int64) { c.writeCount.Add(n) }} +} + +func TestStreamReactor_WriteError(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reactor := NewStreamReactor(ctx) + defer reactor.Close() + + server1, client1 := createTCPPair(t) + defer server1.Close() + defer client1.Close() + + server2, client2 := createTCPPair(t) + defer server2.Close() + defer client2.Close() + + // Wrap destination with failing writer that fails after 3 writes + failingDest := &failingConn{Conn: client2, failAfter: 3} + + var capturedErr error + closeDone := make(chan struct{}) + reactor.Copy(ctx, server1, failingDest, func(err error) { + capturedErr = err + close(closeDone) + }) + + // Send multiple chunks of data to trigger the write failure + testData := make([]byte, 1024) + rand.Read(testData) + + for i := 0; i < 10; i++ { + _, err := client1.Write(testData) + if err != nil { + break + } + time.Sleep(10 * time.Millisecond) + } + + // Close source to trigger cleanup + client1.Close() + server2.Close() + + select { + case <-closeDone: + // Verify error was propagated + assert.NotNil(t, capturedErr, "expected error to be propagated") + assert.Contains(t, capturedErr.Error(), "simulated write error") + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for close callback") + } +} + +func TestStreamReactor_ReadError(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reactor := NewStreamReactor(ctx) + defer reactor.Close() + + server1, client1 := createTCPPair(t) + defer server1.Close() + defer client1.Close() + + server2, client2 := createTCPPair(t) + defer server2.Close() + defer client2.Close() + + // Wrap source with error-returning reader + readErr := errors.New("simulated read error") + errorSrc := &errorConn{Conn: server1, readError: readErr} + + var capturedErr error + closeDone := make(chan struct{}) + reactor.Copy(ctx, errorSrc, client2, func(err error) { + capturedErr = err + close(closeDone) + }) + + select { + case <-closeDone: + // Verify error was propagated + assert.NotNil(t, capturedErr, "expected error to be propagated") + assert.Contains(t, capturedErr.Error(), "simulated read error") + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for close callback") + } +} + +func TestStreamReactor_Counters(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reactor := NewStreamReactor(ctx) + defer reactor.Close() + + server1, client1 := createTCPPair(t) + defer server1.Close() + defer client1.Close() + + server2, client2 := createTCPPair(t) + defer server2.Close() + defer client2.Close() + + // Wrap with counting connections + countingSrc := &countingConn{Conn: server1} + countingDst := &countingConn{Conn: client2} + + closeDone := make(chan struct{}) + reactor.Copy(ctx, countingSrc, countingDst, func(err error) { + close(closeDone) + }) + + // Send data in both directions + const 
dataSize = 4096 + uploadData := make([]byte, dataSize) + downloadData := make([]byte, dataSize) + rand.Read(uploadData) + rand.Read(downloadData) + + // Upload: client1 -> server1 -> client2 -> server2 + _, err := client1.Write(uploadData) + require.NoError(t, err) + + received := make([]byte, dataSize) + _, err = io.ReadFull(server2, received) + require.NoError(t, err) + assert.Equal(t, uploadData, received) + + // Download: server2 -> client2 -> server1 -> client1 + _, err = server2.Write(downloadData) + require.NoError(t, err) + + received2 := make([]byte, dataSize) + _, err = io.ReadFull(client1, received2) + require.NoError(t, err) + assert.Equal(t, downloadData, received2) + + // Close connections + client1.Close() + server2.Close() + + select { + case <-closeDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for close callback") + } + + // Verify counters (read from source, write to destination) + // Note: The countingConn tracks actual reads/writes at the connection level + assert.True(t, countingSrc.readCount.Load() >= dataSize, "source should have read at least %d bytes, got %d", dataSize, countingSrc.readCount.Load()) + assert.True(t, countingDst.writeCount.Load() >= dataSize, "destination should have written at least %d bytes, got %d", dataSize, countingDst.writeCount.Load()) +} + +func TestStreamReactor_CachedReader(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reactor := NewStreamReactor(ctx) + defer reactor.Close() + + server1, client1 := createTCPPair(t) + defer server1.Close() + defer client1.Close() + + server2, client2 := createTCPPair(t) + defer server2.Close() + defer client2.Close() + + // Create cached data + cachedData := make([]byte, 512) + rand.Read(cachedData) + cachedBuffer := buf.As(cachedData) + + // Wrap source with cached conn + cachedSrc := NewCachedConn(server1, cachedBuffer) + defer cachedSrc.Close() + + closeDone := make(chan struct{}) + reactor.Copy(ctx, cachedSrc, client2, func(err error) { + close(closeDone) + }) + + // The cached data should be sent first before any new data + // Read cached data from destination + receivedCached := make([]byte, len(cachedData)) + _, err := io.ReadFull(server2, receivedCached) + require.NoError(t, err) + assert.Equal(t, cachedData, receivedCached, "cached data should be received first") + + // Now send new data through the connection + newData := make([]byte, 256) + rand.Read(newData) + + _, err = client1.Write(newData) + require.NoError(t, err) + + receivedNew := make([]byte, len(newData)) + _, err = io.ReadFull(server2, receivedNew) + require.NoError(t, err) + assert.Equal(t, newData, receivedNew, "new data should be received after cached data") + + // Cleanup + client1.Close() + server2.Close() + + select { + case <-closeDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for close callback") + } +} + +func TestStreamReactor_LargeData(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reactor := NewStreamReactor(ctx) + defer reactor.Close() + + server1, client1 := createTCPPair(t) + defer server1.Close() + defer client1.Close() + + server2, client2 := createTCPPair(t) + defer server2.Close() + defer client2.Close() + + fdServer1 := newFDConn(t, server1) + fdClient2 := newFDConn(t, client2) + + closeDone := make(chan struct{}) + reactor.Copy(ctx, fdServer1, fdClient2, func(err error) { + close(closeDone) + }) + + // Test with 10MB of data + const dataSize = 10 * 1024 * 1024 + uploadData := make([]byte, dataSize) + rand.Read(uploadData) + 
uploadHash := md5.Sum(uploadData) + + downloadData := make([]byte, dataSize) + rand.Read(downloadData) + downloadHash := md5.Sum(downloadData) + + var wg sync.WaitGroup + errChan := make(chan error, 4) + + // Upload goroutine + wg.Add(1) + go func() { + defer wg.Done() + _, err := client1.Write(uploadData) + if err != nil { + errChan <- err + } + }() + + // Upload receiver goroutine + wg.Add(1) + go func() { + defer wg.Done() + received := make([]byte, dataSize) + _, err := io.ReadFull(server2, received) + if err != nil { + errChan <- err + return + } + receivedHash := md5.Sum(received) + if receivedHash != uploadHash { + errChan <- errors.New("upload data mismatch") + } + }() + + // Download goroutine + wg.Add(1) + go func() { + defer wg.Done() + _, err := server2.Write(downloadData) + if err != nil { + errChan <- err + } + }() + + // Download receiver goroutine + wg.Add(1) + go func() { + defer wg.Done() + received := make([]byte, dataSize) + _, err := io.ReadFull(client1, received) + if err != nil { + errChan <- err + return + } + receivedHash := md5.Sum(received) + if receivedHash != downloadHash { + errChan <- errors.New("download data mismatch") + } + }() + + // Wait for completion + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case err := <-errChan: + t.Fatalf("transfer error: %v", err) + case <-done: + // Success + case <-time.After(60 * time.Second): + t.Fatal("timeout during large data transfer") + } + + // Cleanup + client1.Close() + server2.Close() + + select { + case <-closeDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for close callback") + } +} + +func TestStreamReactor_BufferedAtRegistration(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reactor := NewStreamReactor(ctx) + defer reactor.Close() + + server1, client1 := createTCPPair(t) + defer server1.Close() + defer client1.Close() + + server2, client2 := createTCPPair(t) + defer server2.Close() + defer client2.Close() + + // Create buffered conn with pre-populated buffer + bufferedServer1 := newBufferedConn(t, server1) + defer bufferedServer1.Close() + + // Pre-populate the buffer with data + preBufferedData := []byte("pre-buffered data that should be sent immediately") + bufferedServer1.bufferMu.Lock() + bufferedServer1.buffer.Write(preBufferedData) + bufferedServer1.bufferMu.Unlock() + + closeDone := make(chan struct{}) + reactor.Copy(ctx, bufferedServer1, client2, func(err error) { + close(closeDone) + }) + + // The pre-buffered data should be processed immediately + received := make([]byte, len(preBufferedData)) + server2.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, err := io.ReadFull(server2, received) + require.NoError(t, err) + assert.Equal(t, preBufferedData, received, "pre-buffered data should be received") + + // Now send additional data + additionalData := []byte("additional data") + _, err = client1.Write(additionalData) + require.NoError(t, err) + + additionalReceived := make([]byte, len(additionalData)) + server2.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, err = io.ReadFull(server2, additionalReceived) + require.NoError(t, err) + assert.Equal(t, additionalData, additionalReceived) + + // Cleanup + client1.Close() + server2.Close() + + select { + case <-closeDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for close callback") + } +} From c84d634f44eb59b49c872a50394837d5bfc69868 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 28 Dec 2025 04:08:40 +0800 Subject: 
[PATCH 09/13] Add logger parameter to stream and packet reactors --- common/bufio/packet_reactor.go | 18 ++++++++++- common/bufio/packet_reactor_test.go | 46 ++++++++++++++--------------- common/bufio/stream_reactor.go | 18 ++++++++++- common/bufio/stream_reactor_test.go | 24 +++++++-------- 4 files changed, 69 insertions(+), 37 deletions(-) diff --git a/common/bufio/packet_reactor.go b/common/bufio/packet_reactor.go index 3bf8180a..eba2a032 100644 --- a/common/bufio/packet_reactor.go +++ b/common/bufio/packet_reactor.go @@ -9,6 +9,7 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) @@ -69,16 +70,21 @@ func CreatePacketPollable(reader N.PacketReader) (N.PacketPollable, bool) { type PacketReactor struct { ctx context.Context cancel context.CancelFunc + logger logger.Logger fdPoller *FDPoller fdPollerOnce sync.Once fdPollerErr error } -func NewPacketReactor(ctx context.Context) *PacketReactor { +func NewPacketReactor(ctx context.Context, l logger.Logger) *PacketReactor { ctx, cancel := context.WithCancel(ctx) + if l == nil { + l = logger.NOP() + } return &PacketReactor{ ctx: ctx, cancel: cancel, + logger: l, } } @@ -129,6 +135,7 @@ type reactorStream struct { } func (r *PacketReactor) Copy(ctx context.Context, source N.PacketConn, destination N.PacketConn, onClose N.CloseHandlerFunc) { + r.logger.Trace("packet copy: starting") ctx, cancel := context.WithCancel(ctx) conn := &reactorConnection{ ctx: ctx, @@ -216,6 +223,7 @@ func (r *PacketReactor) prepareStream(conn *reactorConnection, source N.PacketRe func (r *PacketReactor) registerStream(stream *reactorStream) { if stream.pushable != nil { + r.logger.Trace("packet stream: using pushable mode") stream.pushable.SetOnDataReady(func() { go stream.runActiveLoop(nil) }) @@ -226,18 +234,23 @@ func (r *PacketReactor) registerStream(stream *reactorStream) { } if stream.pollable == nil { + r.logger.Trace("packet stream: using legacy copy") go stream.runLegacyCopy() return } fdPoller, err := r.getFDPoller() if err != nil { + r.logger.Trace("packet stream: FD poller unavailable, using legacy copy") go stream.runLegacyCopy() return } err = fdPoller.Add(stream, stream.pollable.FD()) if err != nil { + r.logger.Trace("packet stream: failed to add to FD poller, using legacy copy") go stream.runLegacyCopy() + } else { + r.logger.Trace("packet stream: registered with FD poller") } } @@ -311,6 +324,7 @@ func (s *reactorStream) runActiveLoop(firstPacket *N.PacketBuffer) { if !s.state.CompareAndSwap(stateActive, stateIdle) { return } + s.connection.reactor.logger.Trace("packet stream: timeout, returning to idle pool") if s.pushable != nil { if s.pushable.HasPendingData() { if s.state.CompareAndSwap(stateIdle, stateActive) { @@ -325,6 +339,7 @@ func (s *reactorStream) runActiveLoop(firstPacket *N.PacketBuffer) { if !notFirstTime { err = N.ReportHandshakeFailure(s.originSource, err) } + s.connection.reactor.logger.Trace("packet stream: error occurred: ", err) s.closeWithError(err) return } @@ -399,6 +414,7 @@ func (s *reactorStream) closeWithError(err error) { func (c *reactorConnection) closeWithError(err error) { c.closeOnce.Do(func() { defer close(c.done) + c.reactor.logger.Trace("packet connection: closing with error: ", err) if c.stopReactorWatch != nil { c.stopReactorWatch() diff --git a/common/bufio/packet_reactor_test.go 
b/common/bufio/packet_reactor_test.go index 0fd8d5ed..5152ed43 100644 --- a/common/bufio/packet_reactor_test.go +++ b/common/bufio/packet_reactor_test.go @@ -274,7 +274,7 @@ func TestBatchCopy_Pipe_DataIntegrity(t *testing.T) { defer pipeA.Close() defer pipeB.Close() - copier := NewPacketReactor(context.Background()) + copier := NewPacketReactor(context.Background(), nil) defer copier.Close() go func() { @@ -335,7 +335,7 @@ func TestBatchCopy_Pipe_Bidirectional(t *testing.T) { defer pipeA.Close() defer pipeB.Close() - copier := NewPacketReactor(context.Background()) + copier := NewPacketReactor(context.Background(), nil) defer copier.Close() go func() { @@ -447,7 +447,7 @@ func TestBatchCopy_FDPoller_DataIntegrity(t *testing.T) { proxyA := newFDPacketConn(t, proxyAConn, serverAddr) proxyB := newFDPacketConn(t, proxyBConn, clientAddr) - copier := NewPacketReactor(context.Background()) + copier := NewPacketReactor(context.Background(), nil) defer copier.Close() go func() { @@ -568,7 +568,7 @@ func TestBatchCopy_LegacyChannel_DataIntegrity(t *testing.T) { proxyA := newChannelPacketConn(proxyAConn, serverAddr) proxyB := newChannelPacketConn(proxyBConn, clientAddr) - copier := NewPacketReactor(context.Background()) + copier := NewPacketReactor(context.Background(), nil) defer copier.Close() go func() { @@ -689,7 +689,7 @@ func TestBatchCopy_MixedMode_DataIntegrity(t *testing.T) { proxyA := newFDPacketConn(t, proxyAConn, serverAddr) proxyB := newChannelPacketConn(proxyBConn, clientAddr) - copier := NewPacketReactor(context.Background()) + copier := NewPacketReactor(context.Background(), nil) defer copier.Close() go func() { @@ -790,7 +790,7 @@ func TestBatchCopy_MultipleConnections_DataIntegrity(t *testing.T) { const numConnections = 5 - copier := NewPacketReactor(context.Background()) + copier := NewPacketReactor(context.Background(), nil) defer copier.Close() var wg sync.WaitGroup @@ -881,7 +881,7 @@ func TestBatchCopy_TimeoutAndResume_DataIntegrity(t *testing.T) { defer pipeA.Close() defer pipeB.Close() - copier := NewPacketReactor(context.Background()) + copier := NewPacketReactor(context.Background(), nil) defer copier.Close() go func() { @@ -945,7 +945,7 @@ func TestBatchCopy_CloseWhileTransferring(t *testing.T) { pipeA := newTestPacketPipe(addr1) pipeB := newTestPacketPipe(addr2) - copier := NewPacketReactor(context.Background()) + copier := NewPacketReactor(context.Background(), nil) copyDone := make(chan struct{}) go func() { @@ -995,7 +995,7 @@ func TestBatchCopy_HighThroughput(t *testing.T) { defer pipeA.Close() defer pipeB.Close() - copier := NewPacketReactor(context.Background()) + copier := NewPacketReactor(context.Background(), nil) defer copier.Close() go func() { @@ -1080,7 +1080,7 @@ func TestBatchCopy_LegacyFallback_DataIntegrity(t *testing.T) { proxyA := &legacyPacketConn{NetPacketConn: NewPacketConn(proxyAConn), targetAddr: serverAddr} proxyB := &legacyPacketConn{NetPacketConn: NewPacketConn(proxyBConn), targetAddr: clientAddr} - copier := NewPacketReactor(context.Background()) + copier := NewPacketReactor(context.Background(), nil) defer copier.Close() go func() { @@ -1195,7 +1195,7 @@ func TestBatchCopy_ReactorClose(t *testing.T) { defer pipeA.Close() defer pipeB.Close() - copier := NewPacketReactor(context.Background()) + copier := NewPacketReactor(context.Background(), nil) copyDone := make(chan struct{}) go func() { @@ -1243,7 +1243,7 @@ func TestBatchCopy_SmallPackets(t *testing.T) { defer pipeA.Close() defer pipeB.Close() - copier := 
NewPacketReactor(context.Background()) + copier := NewPacketReactor(context.Background(), nil) defer copier.Close() go func() { @@ -1297,7 +1297,7 @@ func TestBatchCopy_VaryingPacketSizes(t *testing.T) { defer pipeA.Close() defer pipeB.Close() - copier := NewPacketReactor(context.Background()) + copier := NewPacketReactor(context.Background(), nil) defer copier.Close() go func() { @@ -1365,7 +1365,7 @@ func TestBatchCopy_OnCloseCallback(t *testing.T) { pipeA := newTestPacketPipe(addr1) pipeB := newTestPacketPipe(addr2) - copier := NewPacketReactor(context.Background()) + copier := NewPacketReactor(context.Background(), nil) defer copier.Close() callbackCalled := make(chan error, 1) @@ -1409,7 +1409,7 @@ func TestBatchCopy_SourceClose(t *testing.T) { pipeA := newTestPacketPipe(addr1) pipeB := newTestPacketPipe(addr2) - copier := NewPacketReactor(context.Background()) + copier := NewPacketReactor(context.Background(), nil) defer copier.Close() var capturedErr error @@ -1667,7 +1667,7 @@ func TestPacketReactor_Pushable_Basic(t *testing.T) { defer pipeA.Close() defer pipeB.Close() - copier := NewPacketReactor(context.Background()) + copier := NewPacketReactor(context.Background(), nil) defer copier.Close() copyDone := make(chan struct{}) @@ -1750,7 +1750,7 @@ func TestPacketReactor_Pushable_HasPendingData(t *testing.T) { pipeA.send(data, addr2) } - copier := NewPacketReactor(context.Background()) + copier := NewPacketReactor(context.Background(), nil) defer copier.Close() copyDone := make(chan struct{}) @@ -1799,7 +1799,7 @@ func TestPacketReactor_Pushable_TimeoutResume(t *testing.T) { defer pipeA.Close() defer pipeB.Close() - copier := NewPacketReactor(context.Background()) + copier := NewPacketReactor(context.Background(), nil) defer copier.Close() copyDone := make(chan struct{}) @@ -1874,7 +1874,7 @@ func TestPacketReactor_WriteError(t *testing.T) { // Wrap destination with failing writer that fails after 3 packets failingDest := &failingPacketWriter{PacketConn: pipeB, failAfter: 3} - copier := NewPacketReactor(context.Background()) + copier := NewPacketReactor(context.Background(), nil) defer copier.Close() var capturedErr error @@ -1929,7 +1929,7 @@ func TestPacketReactor_ReadError(t *testing.T) { readErr := errors.New("simulated packet read error") errorSrc := &errorPacketReader{PacketConn: pipeA, readError: readErr} - copier := NewPacketReactor(context.Background()) + copier := NewPacketReactor(context.Background(), nil) defer copier.Close() var capturedErr error @@ -1981,7 +1981,7 @@ func TestPacketReactor_StateMachine_ConcurrentWakeup(t *testing.T) { proxyA := newFDPacketConn(t, proxyAConn, serverAddr) proxyB := newFDPacketConn(t, proxyBConn, clientAddr) - copier := NewPacketReactor(context.Background()) + copier := NewPacketReactor(context.Background(), nil) defer copier.Close() copyDone := make(chan struct{}) @@ -2034,7 +2034,7 @@ func TestPacketReactor_StateMachine_CloseWhileActive(t *testing.T) { pipeA := newTestPacketPipe(addr1) pipeB := newTestPacketPipe(addr2) - copier := NewPacketReactor(context.Background()) + copier := NewPacketReactor(context.Background(), nil) copyDone := make(chan struct{}) go func() { @@ -2088,7 +2088,7 @@ func TestPacketReactor_Counters(t *testing.T) { defer pipeA.Close() defer pipeB.Close() - copier := NewPacketReactor(context.Background()) + copier := NewPacketReactor(context.Background(), nil) defer copier.Close() copyDone := make(chan struct{}) diff --git a/common/bufio/stream_reactor.go b/common/bufio/stream_reactor.go index d30f7f38..0924a3f3 
100644 --- a/common/bufio/stream_reactor.go +++ b/common/bufio/stream_reactor.go @@ -11,6 +11,7 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" N "github.com/sagernet/sing/common/network" ) @@ -40,16 +41,21 @@ func CreateStreamPollable(reader io.Reader) (N.StreamPollable, bool) { type StreamReactor struct { ctx context.Context cancel context.CancelFunc + logger logger.Logger fdPoller *FDPoller fdPollerOnce sync.Once fdPollerErr error } -func NewStreamReactor(ctx context.Context) *StreamReactor { +func NewStreamReactor(ctx context.Context, l logger.Logger) *StreamReactor { ctx, cancel := context.WithCancel(ctx) + if l == nil { + l = logger.NOP() + } return &StreamReactor{ ctx: ctx, cancel: cancel, + logger: l, } } @@ -104,10 +110,12 @@ type streamDirection struct { func (r *StreamReactor) Copy(ctx context.Context, source net.Conn, destination net.Conn, onClose N.CloseHandlerFunc) { // Try splice first (zero-copy optimization) if r.trySplice(ctx, source, destination, onClose) { + r.logger.Trace("stream copy: using splice for zero-copy") return } // Fall back to reactor mode + r.logger.Trace("stream copy: using reactor mode") ctx, cancel := context.WithCancel(ctx) conn := &streamConnection{ ctx: ctx, @@ -209,6 +217,7 @@ func (r *StreamReactor) prepareDirection(conn *streamConnection, source io.Reade func (r *StreamReactor) registerDirection(direction *streamDirection) { // Check if there's buffered data that needs processing first if direction.pollable != nil && direction.pollable.Buffered() > 0 { + r.logger.Trace("stream direction: has buffered data, starting active loop") go direction.runActiveLoop() return } @@ -219,12 +228,14 @@ func (r *StreamReactor) registerDirection(direction *streamDirection) { if err == nil { err = fdPoller.Add(direction, direction.pollable.FD()) if err == nil { + r.logger.Trace("stream direction: registered with FD poller") return } } } // Fall back to legacy goroutine copy + r.logger.Trace("stream direction: using legacy copy") go direction.runLegacyCopy() } @@ -271,6 +282,7 @@ func (d *streamDirection) runActiveLoop() { setter.SetReadDeadline(time.Time{}) } if d.state.CompareAndSwap(stateActive, stateIdle) { + d.connection.reactor.logger.Trace("stream direction: timeout, returning to idle pool") d.returnToPool() } return @@ -349,6 +361,7 @@ func (d *streamDirection) runLegacyCopy() { func (d *streamDirection) handleEOFOrError(err error) { if err == nil || err == io.EOF { // Graceful EOF: close write direction only (half-close) + d.connection.reactor.logger.Trace("stream direction: graceful EOF, half-closing") d.state.Store(stateClosed) // Try half-close on destination @@ -367,6 +380,7 @@ func (d *streamDirection) handleEOFOrError(err error) { } // Error: close entire connection + d.connection.reactor.logger.Trace("stream direction: error occurred: ", err) d.closeWithError(err) } @@ -381,6 +395,7 @@ func (c *streamConnection) checkBothClosed() { if uploadClosed && downloadClosed { c.closeOnce.Do(func() { defer close(c.done) + c.reactor.logger.Trace("stream connection: both directions closed gracefully") if c.stopReactorWatch != nil { c.stopReactorWatch() @@ -402,6 +417,7 @@ func (c *streamConnection) checkBothClosed() { func (c *streamConnection) closeWithError(err error) { c.closeOnce.Do(func() { defer close(c.done) + c.reactor.logger.Trace("stream connection: closing with error: ", err) if c.stopReactorWatch != nil { c.stopReactorWatch() 
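A minimal caller-side sketch of the constructor signature introduced by this patch, mirroring how the tests below invoke it. Only NewStreamReactor, Copy, Close, and logger.NOP() come from the patch itself; the package name, the Relay helper, and the connection handling around them are assumptions for illustration. Passing nil for the logger falls back to logger.NOP(), and NewPacketReactor takes the same (ctx, logger) pair.

package relayexample

import (
	"context"
	"net"

	"github.com/sagernet/sing/common/bufio"
	"github.com/sagernet/sing/common/logger"
)

// Relay bridges two established connections through a StreamReactor and
// blocks until the copy finishes or ctx is cancelled. A reactor is normally
// shared across many connections; it is created and closed per call here
// only to keep the sketch self-contained.
func Relay(ctx context.Context, upstream, downstream net.Conn) error {
	reactor := bufio.NewStreamReactor(ctx, logger.NOP()) // nil would also fall back to logger.NOP()
	defer reactor.Close()

	done := make(chan error, 1)
	reactor.Copy(ctx, upstream, downstream, func(err error) {
		done <- err // onClose delivers the final error once the relay ends
	})

	select {
	case err := <-done:
		return err
	case <-ctx.Done():
		return ctx.Err()
	}
}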
diff --git a/common/bufio/stream_reactor_test.go b/common/bufio/stream_reactor_test.go index cfe481c9..f19b652c 100644 --- a/common/bufio/stream_reactor_test.go +++ b/common/bufio/stream_reactor_test.go @@ -103,7 +103,7 @@ func TestStreamReactor_Basic(t *testing.T) { t.Parallel() ctx := context.Background() - reactor := NewStreamReactor(ctx) + reactor := NewStreamReactor(ctx, nil) defer reactor.Close() // Create a pair of connected TCP connections @@ -194,7 +194,7 @@ func TestStreamReactor_FDNotifier(t *testing.T) { t.Parallel() ctx := context.Background() - reactor := NewStreamReactor(ctx) + reactor := NewStreamReactor(ctx, nil) defer reactor.Close() // Create TCP connection pairs @@ -241,7 +241,7 @@ func TestStreamReactor_BufferedReader(t *testing.T) { t.Parallel() ctx := context.Background() - reactor := NewStreamReactor(ctx) + reactor := NewStreamReactor(ctx, nil) defer reactor.Close() server1, client1 := createTCPPair(t) @@ -287,7 +287,7 @@ func TestStreamReactor_HalfClose(t *testing.T) { t.Parallel() ctx := context.Background() - reactor := NewStreamReactor(ctx) + reactor := NewStreamReactor(ctx, nil) defer reactor.Close() server1, client1 := createTCPPair(t) @@ -345,7 +345,7 @@ func TestStreamReactor_MultipleConnections(t *testing.T) { t.Parallel() ctx := context.Background() - reactor := NewStreamReactor(ctx) + reactor := NewStreamReactor(ctx, nil) defer reactor.Close() const numConnections = 10 @@ -402,7 +402,7 @@ func TestStreamReactor_ReactorClose(t *testing.T) { t.Parallel() ctx := context.Background() - reactor := NewStreamReactor(ctx) + reactor := NewStreamReactor(ctx, nil) server1, client1 := createTCPPair(t) defer server1.Close() @@ -520,7 +520,7 @@ func TestStreamReactor_WriteError(t *testing.T) { t.Parallel() ctx := context.Background() - reactor := NewStreamReactor(ctx) + reactor := NewStreamReactor(ctx, nil) defer reactor.Close() server1, client1 := createTCPPair(t) @@ -571,7 +571,7 @@ func TestStreamReactor_ReadError(t *testing.T) { t.Parallel() ctx := context.Background() - reactor := NewStreamReactor(ctx) + reactor := NewStreamReactor(ctx, nil) defer reactor.Close() server1, client1 := createTCPPair(t) @@ -607,7 +607,7 @@ func TestStreamReactor_Counters(t *testing.T) { t.Parallel() ctx := context.Background() - reactor := NewStreamReactor(ctx) + reactor := NewStreamReactor(ctx, nil) defer reactor.Close() server1, client1 := createTCPPair(t) @@ -672,7 +672,7 @@ func TestStreamReactor_CachedReader(t *testing.T) { t.Parallel() ctx := context.Background() - reactor := NewStreamReactor(ctx) + reactor := NewStreamReactor(ctx, nil) defer reactor.Close() server1, client1 := createTCPPair(t) @@ -731,7 +731,7 @@ func TestStreamReactor_LargeData(t *testing.T) { t.Parallel() ctx := context.Background() - reactor := NewStreamReactor(ctx) + reactor := NewStreamReactor(ctx, nil) defer reactor.Close() server1, client1 := createTCPPair(t) @@ -846,7 +846,7 @@ func TestStreamReactor_BufferedAtRegistration(t *testing.T) { t.Parallel() ctx := context.Background() - reactor := NewStreamReactor(ctx) + reactor := NewStreamReactor(ctx, nil) defer reactor.Close() server1, client1 := createTCPPair(t) From 93986eff9145e7a9a6016d1cdf62325607e31320 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 28 Dec 2025 14:54:49 +0800 Subject: [PATCH 10/13] Fix kqueue FDPoller GC crash on Darwin --- common/bufio/fd_poller_darwin.go | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/common/bufio/fd_poller_darwin.go 
b/common/bufio/fd_poller_darwin.go index 3f97e903..e781dde9 100644 --- a/common/bufio/fd_poller_darwin.go +++ b/common/bufio/fd_poller_darwin.go @@ -94,21 +94,22 @@ func (p *FDPoller) Add(handler FDHandler, fd int) error { p.registrationCounter++ registrationID := p.registrationCounter + entry := &fdDemuxEntry{ + fd: fd, + registrationID: registrationID, + handler: handler, + } + _, err := unix.Kevent(p.kqueueFD, []unix.Kevent_t{{ Ident: uint64(fd), Filter: unix.EVFILT_READ, Flags: unix.EV_ADD | unix.EV_ONESHOT, - Udata: (*byte)(unsafe.Pointer(uintptr(registrationID))), + Udata: (*byte)(unsafe.Pointer(entry)), }}, nil, nil) if err != nil { return err } - entry := &fdDemuxEntry{ - fd: fd, - registrationID: registrationID, - handler: handler, - } p.entries[fd] = entry p.registrationToFD[registrationID] = fd @@ -208,26 +209,20 @@ func (p *FDPoller) run() { continue } - registrationID := uint64(uintptr(unsafe.Pointer(event.Udata))) + eventEntry := (*fdDemuxEntry)(unsafe.Pointer(event.Udata)) p.mutex.Lock() - mappedFD, ok := p.registrationToFD[registrationID] - if !ok || mappedFD != fd { - p.mutex.Unlock() - continue - } - - entry := p.entries[fd] - if entry == nil || entry.registrationID != registrationID { + currentEntry := p.entries[fd] + if currentEntry != eventEntry { p.mutex.Unlock() continue } - delete(p.registrationToFD, registrationID) + delete(p.registrationToFD, currentEntry.registrationID) delete(p.entries, fd) p.mutex.Unlock() - go entry.handler.HandleFDEvent() + go currentEntry.handler.HandleFDEvent() } p.mutex.Lock() From fddde903aa5b8df234e6ec837491efe3fa61b9cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 28 Dec 2025 14:57:06 +0800 Subject: [PATCH 11/13] Add deadline fallback for readers requiring additional handling --- common/bufio/packet_reactor.go | 23 +++++++++++++++++++++++ common/bufio/stream_reactor.go | 6 ++++++ 2 files changed, 29 insertions(+) diff --git a/common/bufio/packet_reactor.go b/common/bufio/packet_reactor.go index eba2a032..d7f708fe 100644 --- a/common/bufio/packet_reactor.go +++ b/common/bufio/packet_reactor.go @@ -24,6 +24,23 @@ const ( stateClosed int32 = 2 ) +type withoutReadDeadline interface { + NeedAdditionalReadDeadline() bool +} + +func needAdditionalReadDeadline(rawReader any) bool { + if deadlineReader, loaded := rawReader.(withoutReadDeadline); loaded { + return deadlineReader.NeedAdditionalReadDeadline() + } + if upstream, hasUpstream := rawReader.(N.WithUpstreamReader); hasUpstream { + return needAdditionalReadDeadline(upstream.UpstreamReader()) + } + if upstream, hasUpstream := rawReader.(common.WithUpstream); hasUpstream { + return needAdditionalReadDeadline(upstream.Upstream()) + } + return false +} + func CreatePacketPushable(reader N.PacketReader) (N.PacketPushable, bool) { if pushable, ok := reader.(N.PacketPushable); ok { return pushable, true @@ -222,6 +239,12 @@ func (r *PacketReactor) prepareStream(conn *reactorConnection, source N.PacketRe } func (r *PacketReactor) registerStream(stream *reactorStream) { + if needAdditionalReadDeadline(stream.source) { + r.logger.Trace("packet stream: needs additional deadline handling, using legacy copy") + go stream.runLegacyCopy() + return + } + if stream.pushable != nil { r.logger.Trace("packet stream: using pushable mode") stream.pushable.SetOnDataReady(func() { diff --git a/common/bufio/stream_reactor.go b/common/bufio/stream_reactor.go index 0924a3f3..94a6c314 100644 --- a/common/bufio/stream_reactor.go +++ b/common/bufio/stream_reactor.go @@ -215,6 +215,12 @@ func 
(r *StreamReactor) prepareDirection(conn *streamConnection, source io.Reade } func (r *StreamReactor) registerDirection(direction *streamDirection) { + if needAdditionalReadDeadline(direction.source) { + r.logger.Trace("stream direction: needs additional deadline handling, using legacy copy") + go direction.runLegacyCopy() + return + } + // Check if there's buffered data that needs processing first if direction.pollable != nil && direction.pollable.Buffered() > 0 { r.logger.Trace("stream direction: has buffered data, starting active loop") From 7bf8a918e49a7293b50ee58819c59017d9ba1d77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 29 Dec 2025 13:30:59 +0800 Subject: [PATCH 12/13] save --- common/bufio/packet_reactor.go | 54 +++++++++---- common/bufio/stream_reactor.go | 44 ++++++++--- common/udpnat2/conn.go | 139 +++++++++++++++++++++++++-------- common/udpnat2/service.go | 10 ++- 4 files changed, 184 insertions(+), 63 deletions(-) diff --git a/common/bufio/packet_reactor.go b/common/bufio/packet_reactor.go index d7f708fe..3347d2c2 100644 --- a/common/bufio/packet_reactor.go +++ b/common/bufio/packet_reactor.go @@ -15,7 +15,7 @@ import ( ) const ( - batchReadTimeout = 250 * time.Millisecond + batchReadTimeout = 5 *time.Second ) const ( @@ -141,12 +141,14 @@ type reactorStream struct { destination N.PacketWriter originSource N.PacketReader - pushable N.PacketPushable - pollable N.PacketPollable - options N.ReadWaitOptions - readWaiter N.PacketReadWaiter - readCounters []N.CountFunc - writeCounters []N.CountFunc + pushable N.PacketPushable + pushableCallback func() + pollable N.PacketPollable + deadlineSetter deadlineSetter + options N.ReadWaitOptions + readWaiter N.PacketReadWaiter + readCounters []N.CountFunc + writeCounters []N.CountFunc state atomic.Int32 } @@ -245,11 +247,27 @@ func (r *PacketReactor) registerStream(stream *reactorStream) { return } + // Check if deadline setter is available and functional + if setter, ok := stream.source.(deadlineSetter); ok { + err := setter.SetReadDeadline(time.Time{}) + if err != nil { + r.logger.Trace("packet stream: SetReadDeadline not supported, using legacy copy") + go stream.runLegacyCopy() + return + } + stream.deadlineSetter = setter + } else { + r.logger.Trace("packet stream: no deadline setter, using legacy copy") + go stream.runLegacyCopy() + return + } + if stream.pushable != nil { r.logger.Trace("packet stream: using pushable mode") - stream.pushable.SetOnDataReady(func() { + stream.pushableCallback = func() { go stream.runActiveLoop(nil) - }) + } + stream.pushable.SetOnDataReady(stream.pushableCallback) if stream.pushable.HasPendingData() { go stream.runActiveLoop(nil) } @@ -293,6 +311,10 @@ func (s *reactorStream) runActiveLoop(firstPacket *N.PacketBuffer) { return } + if s.pushable != nil { + s.pushable.SetOnDataReady(nil) + } + notFirstTime := false if firstPacket != nil { @@ -309,8 +331,12 @@ func (s *reactorStream) runActiveLoop(firstPacket *N.PacketBuffer) { return } - if setter, ok := s.source.(interface{ SetReadDeadline(time.Time) error }); ok { - setter.SetReadDeadline(time.Now().Add(batchReadTimeout)) + deadlineErr := s.deadlineSetter.SetReadDeadline(time.Now().Add(batchReadTimeout)) + if deadlineErr != nil { + s.connection.reactor.logger.Trace("packet stream: SetReadDeadline failed, switching to legacy copy") + s.state.Store(stateIdle) + go s.runLegacyCopy() + return } var ( @@ -341,16 +367,16 @@ func (s *reactorStream) runActiveLoop(firstPacket *N.PacketBuffer) { if err != nil { if E.IsTimeout(err) 
{
-				if setter, ok := s.source.(interface{ SetReadDeadline(time.Time) error }); ok {
-					setter.SetReadDeadline(time.Time{})
-				}
+				s.deadlineSetter.SetReadDeadline(time.Time{})
 				if !s.state.CompareAndSwap(stateActive, stateIdle) {
 					return
 				}
 				s.connection.reactor.logger.Trace("packet stream: timeout, returning to idle pool")
 				if s.pushable != nil {
+					s.pushable.SetOnDataReady(s.pushableCallback)
 					if s.pushable.HasPendingData() {
 						if s.state.CompareAndSwap(stateIdle, stateActive) {
+							s.pushable.SetOnDataReady(nil)
 							continue
 						}
 					}
diff --git a/common/bufio/stream_reactor.go b/common/bufio/stream_reactor.go
index 94a6c314..da61e684 100644
--- a/common/bufio/stream_reactor.go
+++ b/common/bufio/stream_reactor.go
@@ -16,7 +16,7 @@ import (
 )
 
 const (
-	streamBatchReadTimeout = 250 * time.Millisecond
+	streamBatchReadTimeout = 5 * time.Second
 )
 
 func CreateStreamPollable(reader io.Reader) (N.StreamPollable, bool) {
@@ -88,6 +88,10 @@ type streamConnection struct {
 	err error
 }
 
+type deadlineSetter interface {
+	SetReadDeadline(time.Time) error
+}
+
 type streamDirection struct {
 	connection *streamConnection
 
@@ -95,11 +99,12 @@ type streamDirection struct {
 	destination  io.Writer
 	originSource net.Conn
 
-	pollable      N.StreamPollable
-	options       N.ReadWaitOptions
-	readWaiter    N.ReadWaiter
-	readCounters  []N.CountFunc
-	writeCounters []N.CountFunc
+	pollable       N.StreamPollable
+	deadlineSetter deadlineSetter
+	options        N.ReadWaitOptions
+	readWaiter     N.ReadWaiter
+	readCounters   []N.CountFunc
+	writeCounters  []N.CountFunc
 
 	isUpload bool
 	state    atomic.Int32
@@ -221,6 +226,21 @@ func (r *StreamReactor) registerDirection(direction *streamDirection) {
 		return
 	}
 
+	// Check if deadline setter is available and functional
+	if setter, ok := direction.originSource.(deadlineSetter); ok {
+		err := setter.SetReadDeadline(time.Time{})
+		if err != nil {
+			r.logger.Trace("stream direction: SetReadDeadline not supported, using legacy copy")
+			go direction.runLegacyCopy()
+			return
+		}
+		direction.deadlineSetter = setter
+	} else {
+		r.logger.Trace("stream direction: no deadline setter, using legacy copy")
+		go direction.runLegacyCopy()
+		return
+	}
+
 	// Check if there's buffered data that needs processing first
 	if direction.pollable != nil && direction.pollable.Buffered() > 0 {
 		r.logger.Trace("stream direction: has buffered data, starting active loop")
@@ -261,8 +281,12 @@ func (d *streamDirection) runActiveLoop() {
 	}
 
 	// Set batch read timeout
-	if setter, ok := d.originSource.(interface{ SetReadDeadline(time.Time) error }); ok {
-		setter.SetReadDeadline(time.Now().Add(streamBatchReadTimeout))
+	deadlineErr := d.deadlineSetter.SetReadDeadline(time.Now().Add(streamBatchReadTimeout))
+	if deadlineErr != nil {
+		d.connection.reactor.logger.Trace("stream direction: SetReadDeadline failed, switching to legacy copy")
+		d.state.Store(stateIdle)
+		go d.runLegacyCopy()
+		return
 	}
 
 	var (
@@ -284,9 +308,7 @@ func (d *streamDirection) runActiveLoop() {
 		if err != nil {
 			if E.IsTimeout(err) {
 				// Timeout: check buffer and return to pool
-				if setter, ok := d.originSource.(interface{ SetReadDeadline(time.Time) error }); ok {
-					setter.SetReadDeadline(time.Time{})
-				}
+				d.deadlineSetter.SetReadDeadline(time.Time{})
 				if d.state.CompareAndSwap(stateActive, stateIdle) {
 					d.connection.reactor.logger.Trace("stream direction: timeout, returning to idle pool")
 					d.returnToPool()
diff --git a/common/udpnat2/conn.go b/common/udpnat2/conn.go
index cca511d3..57260f64 100644
--- a/common/udpnat2/conn.go
+++ b/common/udpnat2/conn.go
@@ -40,31 +40,45 @@ type natConn struct {
 	queueMutex  sync.Mutex
 	onDataReady func()
 
+	deadlineMutex sync.Mutex
+	deadlineTimer *time.Timer
+	deadlineChan  chan struct{}
+	dataSignal    chan struct{}
+
 	closeOnce sync.Once
 	doneChan  chan struct{}
 }
 
-func (c *natConn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) {
-	select {
-	case <-c.doneChan:
-		return M.Socksaddr{}, io.ErrClosedPipe
-	default:
-	}
+func (c *natConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
+	for {
+		select {
+		case <-c.doneChan:
+			return M.Socksaddr{}, io.ErrClosedPipe
+		default:
+		}
 
-	c.queueMutex.Lock()
-	if len(c.dataQueue) == 0 {
+		c.queueMutex.Lock()
+		if len(c.dataQueue) > 0 {
+			packet := c.dataQueue[0]
+			c.dataQueue = c.dataQueue[1:]
+			c.queueMutex.Unlock()
+			_, err = buffer.ReadOnceFrom(packet.Buffer)
+			destination = packet.Destination
+			packet.Buffer.Release()
+			N.PutPacketBuffer(packet)
+			return
+		}
 		c.queueMutex.Unlock()
-		return M.Socksaddr{}, os.ErrDeadlineExceeded
-	}
-	packet := c.dataQueue[0]
-	c.dataQueue = c.dataQueue[1:]
-	c.queueMutex.Unlock()
-	_, err = buffer.ReadOnceFrom(packet.Buffer)
-	destination := packet.Destination
-	packet.Buffer.Release()
-	N.PutPacketBuffer(packet)
-	return destination, err
+		select {
+		case <-c.doneChan:
+			return M.Socksaddr{}, io.ErrClosedPipe
+		case <-c.waitDeadline():
+			return M.Socksaddr{}, os.ErrDeadlineExceeded
+		case <-c.dataSignal:
+			continue
+		}
+	}
 }
 
 func (c *natConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
@@ -79,25 +93,34 @@ func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool
 }
 
 func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
-	select {
-	case <-c.doneChan:
-		return nil, M.Socksaddr{}, io.ErrClosedPipe
-	default:
-	}
+	for {
+		select {
+		case <-c.doneChan:
+			return nil, M.Socksaddr{}, io.ErrClosedPipe
+		default:
+		}
 
-	c.queueMutex.Lock()
-	if len(c.dataQueue) == 0 {
+		c.queueMutex.Lock()
+		if len(c.dataQueue) > 0 {
+			packet := c.dataQueue[0]
+			c.dataQueue = c.dataQueue[1:]
+			c.queueMutex.Unlock()
+			buffer = c.readWaitOptions.Copy(packet.Buffer)
+			destination = packet.Destination
+			N.PutPacketBuffer(packet)
+			return
+		}
 		c.queueMutex.Unlock()
-		return nil, M.Socksaddr{}, os.ErrDeadlineExceeded
-	}
-	packet := c.dataQueue[0]
-	c.dataQueue = c.dataQueue[1:]
-	c.queueMutex.Unlock()
-	buffer = c.readWaitOptions.Copy(packet.Buffer)
-	destination = packet.Destination
-	N.PutPacketBuffer(packet)
-	return
+		select {
+		case <-c.doneChan:
+			return nil, M.Socksaddr{}, io.ErrClosedPipe
+		case <-c.waitDeadline():
+			return nil, M.Socksaddr{}, os.ErrDeadlineExceeded
+		case <-c.dataSignal:
+			continue
+		}
+	}
 }
 
 func (c *natConn) SetHandler(handler N.UDPHandlerEx) {
@@ -141,6 +164,11 @@ func (c *natConn) PushPacket(packet *N.PacketBuffer) {
 	callback := c.onDataReady
 	c.queueMutex.Unlock()
 
+	select {
+	case c.dataSignal <- struct{}{}:
+	default:
+	}
+
 	if callback != nil {
 		callback()
 	}
@@ -187,9 +215,52 @@ func (c *natConn) SetDeadline(t time.Time) error {
 }
 
 func (c *natConn) SetReadDeadline(t time.Time) error {
+	c.deadlineMutex.Lock()
+	defer c.deadlineMutex.Unlock()
+
+	if c.deadlineTimer != nil && !c.deadlineTimer.Stop() {
+		<-c.deadlineChan
+	}
+	c.deadlineTimer = nil
+
+	if t.IsZero() {
+		if isClosedChan(c.deadlineChan) {
+			c.deadlineChan = make(chan struct{})
+		}
+		return nil
+	}
+
+	if duration := time.Until(t); duration > 0 {
+		if isClosedChan(c.deadlineChan) {
+			c.deadlineChan = make(chan struct{})
+		}
+		c.deadlineTimer = time.AfterFunc(duration, func() {
+			close(c.deadlineChan)
+		})
+		return nil
+	}
+
+	if !isClosedChan(c.deadlineChan) {
+		close(c.deadlineChan)
+	}
 	return nil
 }
 
+func (c *natConn) waitDeadline() chan struct{} {
+	c.deadlineMutex.Lock()
+	defer c.deadlineMutex.Unlock()
+	return c.deadlineChan
+}
+
+func isClosedChan(channel <-chan struct{}) bool {
+	select {
+	case <-channel:
+		return true
+	default:
+		return false
+	}
+}
+
 func (c *natConn) SetWriteDeadline(t time.Time) error {
 	return os.ErrInvalid
 }
diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go
index 3cf7392d..4370651f 100644
--- a/common/udpnat2/service.go
+++ b/common/udpnat2/service.go
@@ -56,10 +56,12 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati
 		return nil, false
 	}
 	newConn := &natConn{
-		cache:     s.cache,
-		writer:    writer,
-		localAddr: source,
-		doneChan:  make(chan struct{}),
+		cache:        s.cache,
+		writer:       writer,
+		localAddr:    source,
+		deadlineChan: make(chan struct{}),
+		dataSignal:   make(chan struct{}, 1),
+		doneChan:     make(chan struct{}),
 	}
 	go s.handler.NewPacketConnectionEx(ctx, newConn, source, destination, onClose)
 	return newConn, true

From 59ed1c76437e2ccad7e0462a65ea79450d0e93c7 Mon Sep 17 00:00:00 2001
From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com>
Date: Wed, 31 Dec 2025 15:05:19 +0000
Subject: [PATCH 13/13] [dependencies] Update github-actions

---
 .github/workflows/lint.yml |  8 ++++----
 .github/workflows/test.yml | 24 ++++++++++++------------
 2 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 08a8b8d4..06a7cadc 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -20,21 +20,21 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           fetch-depth: 0
       - name: Setup Go
-        uses: actions/setup-go@v5
+        uses: actions/setup-go@v6
         with:
           go-version: ^1.25
       - name: Cache go module
-        uses: actions/cache@v4
+        uses: actions/cache@v5
         with:
           path: |
             ~/go/pkg/mod
           key: go-${{ hashFiles('**/go.sum') }}
       - name: golangci-lint
-        uses: golangci/golangci-lint-action@v8
+        uses: golangci/golangci-lint-action@v9
         with:
           version: v2.4.0
           args: --timeout=30m
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 3afe9684..faf80c28 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -20,11 +20,11 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
        with:
          fetch-depth: 0
      - name: Setup Go
-        uses: actions/setup-go@v5
+        uses: actions/setup-go@v6
        with:
          go-version: ^1.23
      - name: Build
@@ -35,11 +35,11 @@ jobs:
     runs-on: ubuntu-latest
     steps:
      - name: Checkout
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
        with:
          fetch-depth: 0
      - name: Setup Go
-        uses: actions/setup-go@v5
+        uses: actions/setup-go@v6
        with:
          go-version: ~1.20
        continue-on-error: true
@@ -51,11 +51,11 @@ jobs:
     runs-on: ubuntu-latest
     steps:
      - name: Checkout
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
        with:
          fetch-depth: 0
      - name: Setup Go
-        uses: actions/setup-go@v5
+        uses: actions/setup-go@v6
        with:
          go-version: ~1.21
        continue-on-error: true
@@ -67,11 +67,11 @@ jobs:
     runs-on: ubuntu-latest
     steps:
      - name: Checkout
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
        with:
          fetch-depth: 0
      - name: Setup Go
-        uses: actions/setup-go@v5
+        uses: actions/setup-go@v6
        with:
          go-version: ~1.22
        continue-on-error: true
@@ -83,11 +83,11 @@ jobs:
     runs-on: windows-latest
     steps:
      - name: Checkout
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
        with:
          fetch-depth: 0
      - name: Setup Go
-        uses: actions/setup-go@v5
+        uses: actions/setup-go@v6
        with:
          go-version: ^1.23
        continue-on-error: true
@@ -99,11 +99,11 @@ jobs:
     runs-on: macos-latest
     steps:
      - name: Checkout
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
        with:
          fetch-depth: 0
      - name: Setup Go
-        uses: actions/setup-go@v5
+        uses: actions/setup-go@v6
        with:
          go-version: ^1.23
        continue-on-error: true