diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 08a8b8d4..06a7cadc 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -20,21 +20,21 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 - name: Setup Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version: ^1.25 - name: Cache go module - uses: actions/cache@v4 + uses: actions/cache@v5 with: path: | ~/go/pkg/mod key: go-${{ hashFiles('**/go.sum') }} - name: golangci-lint - uses: golangci/golangci-lint-action@v8 + uses: golangci/golangci-lint-action@v9 with: version: v2.4.0 args: --timeout=30m diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3afe9684..faf80c28 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,11 +20,11 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 - name: Setup Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version: ^1.23 - name: Build @@ -35,11 +35,11 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 - name: Setup Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version: ~1.20 continue-on-error: true @@ -51,11 +51,11 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 - name: Setup Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version: ~1.21 continue-on-error: true @@ -67,11 +67,11 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 - name: Setup Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version: ~1.22 continue-on-error: true @@ -83,11 +83,11 @@ jobs: runs-on: windows-latest steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 - name: Setup Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version: ^1.23 continue-on-error: true @@ -99,11 +99,11 @@ jobs: runs-on: macos-latest steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 - name: Setup Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version: ^1.23 continue-on-error: true diff --git a/common/bufio/fd_demux_windows_test.go b/common/bufio/fd_demux_windows_test.go new file mode 100644 index 00000000..4f7ccce4 --- /dev/null +++ b/common/bufio/fd_demux_windows_test.go @@ -0,0 +1,305 @@ +//go:build windows + +package bufio + +import ( + "context" + "net" + "sync" + "sync/atomic" + "syscall" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func getSocketFD(t *testing.T, conn net.PacketConn) int { + syscallConn, ok := conn.(syscall.Conn) + require.True(t, ok) + rawConn, err := syscallConn.SyscallConn() + require.NoError(t, err) + var fd int + err = rawConn.Control(func(f uintptr) { fd = int(f) }) + require.NoError(t, err) + return fd +} + +func TestFDDemultiplexer_Create(t *testing.T) { + t.Parallel() + + demux, err := NewFDPoller(context.Background()) + require.NoError(t, err) + + err = demux.Close() + require.NoError(t, err) +} + +func TestFDDemultiplexer_CreateMultiple(t *testing.T) { + t.Parallel() + + demux1, err := NewFDPoller(context.Background()) + require.NoError(t, err) + defer demux1.Close() + + 
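+	// A second poller created while the first is still open must work independently of it.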
demux2, err := NewFDPoller(context.Background()) + require.NoError(t, err) + defer demux2.Close() +} + +func TestFDDemultiplexer_AddRemove(t *testing.T) { + t.Parallel() + + demux, err := NewFDPoller(context.Background()) + require.NoError(t, err) + defer demux.Close() + + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + + fd := getSocketFD(t, conn) + + stream := &reactorStream{} + + err = demux.Add(stream, fd) + require.NoError(t, err) + + demux.Remove(fd) +} + +func TestFDDemultiplexer_RapidAddRemove(t *testing.T) { + t.Parallel() + + demux, err := NewFDPoller(context.Background()) + require.NoError(t, err) + defer demux.Close() + + const iterations = 50 + + for i := 0; i < iterations; i++ { + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + fd := getSocketFD(t, conn) + stream := &reactorStream{} + + err = demux.Add(stream, fd) + require.NoError(t, err) + + demux.Remove(fd) + conn.Close() + } +} + +func TestFDDemultiplexer_ConcurrentAccess(t *testing.T) { + t.Parallel() + + demux, err := NewFDPoller(context.Background()) + require.NoError(t, err) + defer demux.Close() + + const numGoroutines = 10 + const iterations = 20 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for g := 0; g < numGoroutines; g++ { + go func() { + defer wg.Done() + + for i := 0; i < iterations; i++ { + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + continue + } + + fd := getSocketFD(t, conn) + stream := &reactorStream{} + + err = demux.Add(stream, fd) + if err == nil { + demux.Remove(fd) + } + conn.Close() + } + }() + } + + wg.Wait() +} + +func TestFDDemultiplexer_ReceiveEvent(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + demux, err := NewFDPoller(ctx) + require.NoError(t, err) + defer demux.Close() + + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + + fd := getSocketFD(t, conn) + + triggered := make(chan struct{}, 1) + stream := &reactorStream{ + state: atomic.Int32{}, + } + stream.connection = &reactorConnection{ + upload: stream, + download: stream, + done: make(chan struct{}), + } + + originalRunActiveLoop := stream.runActiveLoop + _ = originalRunActiveLoop + + err = demux.Add(stream, fd) + require.NoError(t, err) + + sender, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer sender.Close() + + _, err = sender.WriteTo([]byte("test data"), conn.LocalAddr()) + require.NoError(t, err) + + time.Sleep(200 * time.Millisecond) + + select { + case <-triggered: + default: + } + + demux.Remove(fd) +} + +func TestFDDemultiplexer_CloseWhilePolling(t *testing.T) { + t.Parallel() + + demux, err := NewFDPoller(context.Background()) + require.NoError(t, err) + + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + + fd := getSocketFD(t, conn) + stream := &reactorStream{} + + err = demux.Add(stream, fd) + require.NoError(t, err) + + time.Sleep(50 * time.Millisecond) + + done := make(chan struct{}) + go func() { + demux.Close() + close(done) + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Close blocked - possible deadlock") + } +} + +func TestFDDemultiplexer_RemoveNonExistent(t *testing.T) { + t.Parallel() + + demux, err := NewFDPoller(context.Background()) + require.NoError(t, err) + defer demux.Close() + + demux.Remove(99999) +} + +func TestFDDemultiplexer_AddAfterClose(t *testing.T) { + 
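+	// Registering a socket after the poller has been closed must return an error.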
t.Parallel() + + demux, err := NewFDPoller(context.Background()) + require.NoError(t, err) + + err = demux.Close() + require.NoError(t, err) + + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + + fd := getSocketFD(t, conn) + stream := &reactorStream{} + + err = demux.Add(stream, fd) + require.Error(t, err) +} + +func TestFDDemultiplexer_MultipleSocketsSimultaneous(t *testing.T) { + t.Parallel() + + demux, err := NewFDPoller(context.Background()) + require.NoError(t, err) + defer demux.Close() + + const numSockets = 5 + conns := make([]net.PacketConn, numSockets) + fds := make([]int, numSockets) + + for i := 0; i < numSockets; i++ { + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + conns[i] = conn + + fd := getSocketFD(t, conn) + fds[i] = fd + + stream := &reactorStream{} + err = demux.Add(stream, fd) + require.NoError(t, err) + } + + for i := 0; i < numSockets; i++ { + demux.Remove(fds[i]) + } +} + +func TestFDDemultiplexer_ContextCancellation(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + demux, err := NewFDPoller(ctx) + require.NoError(t, err) + + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + + fd := getSocketFD(t, conn) + stream := &reactorStream{} + + err = demux.Add(stream, fd) + require.NoError(t, err) + + cancel() + + time.Sleep(100 * time.Millisecond) + + done := make(chan struct{}) + go func() { + demux.Close() + close(done) + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Close blocked after context cancellation") + } +} diff --git a/common/bufio/fd_handler.go b/common/bufio/fd_handler.go new file mode 100644 index 00000000..4c41677f --- /dev/null +++ b/common/bufio/fd_handler.go @@ -0,0 +1,9 @@ +package bufio + +// FDHandler is the interface for handling FD ready events +// Implemented by both reactorStream (UDP) and streamDirection (TCP) +type FDHandler interface { + // HandleFDEvent is called when the FD has data ready to read + // The handler should start processing data in a new goroutine + HandleFDEvent() +} diff --git a/common/bufio/fd_poller_darwin.go b/common/bufio/fd_poller_darwin.go new file mode 100644 index 00000000..e781dde9 --- /dev/null +++ b/common/bufio/fd_poller_darwin.go @@ -0,0 +1,236 @@ +//go:build darwin + +package bufio + +import ( + "context" + "sync" + "sync/atomic" + "unsafe" + + "golang.org/x/sys/unix" +) + +type fdDemuxEntry struct { + fd int + registrationID uint64 + handler FDHandler +} + +type FDPoller struct { + ctx context.Context + cancel context.CancelFunc + kqueueFD int + mutex sync.Mutex + entries map[int]*fdDemuxEntry + registrationCounter uint64 + registrationToFD map[uint64]int + running bool + closed atomic.Bool + wg sync.WaitGroup + pipeFDs [2]int +} + +func NewFDPoller(ctx context.Context) (*FDPoller, error) { + kqueueFD, err := unix.Kqueue() + if err != nil { + return nil, err + } + + var pipeFDs [2]int + err = unix.Pipe(pipeFDs[:]) + if err != nil { + unix.Close(kqueueFD) + return nil, err + } + + err = unix.SetNonblock(pipeFDs[0], true) + if err != nil { + unix.Close(pipeFDs[0]) + unix.Close(pipeFDs[1]) + unix.Close(kqueueFD) + return nil, err + } + err = unix.SetNonblock(pipeFDs[1], true) + if err != nil { + unix.Close(pipeFDs[0]) + unix.Close(pipeFDs[1]) + unix.Close(kqueueFD) + return nil, err + } + + _, err = unix.Kevent(kqueueFD, []unix.Kevent_t{{ + Ident: uint64(pipeFDs[0]), + Filter: 
unix.EVFILT_READ, + Flags: unix.EV_ADD, + }}, nil, nil) + if err != nil { + unix.Close(pipeFDs[0]) + unix.Close(pipeFDs[1]) + unix.Close(kqueueFD) + return nil, err + } + + ctx, cancel := context.WithCancel(ctx) + poller := &FDPoller{ + ctx: ctx, + cancel: cancel, + kqueueFD: kqueueFD, + entries: make(map[int]*fdDemuxEntry), + registrationToFD: make(map[uint64]int), + pipeFDs: pipeFDs, + } + return poller, nil +} + +func (p *FDPoller) Add(handler FDHandler, fd int) error { + p.mutex.Lock() + defer p.mutex.Unlock() + + if p.closed.Load() { + return unix.EINVAL + } + + p.registrationCounter++ + registrationID := p.registrationCounter + + entry := &fdDemuxEntry{ + fd: fd, + registrationID: registrationID, + handler: handler, + } + + _, err := unix.Kevent(p.kqueueFD, []unix.Kevent_t{{ + Ident: uint64(fd), + Filter: unix.EVFILT_READ, + Flags: unix.EV_ADD | unix.EV_ONESHOT, + Udata: (*byte)(unsafe.Pointer(entry)), + }}, nil, nil) + if err != nil { + return err + } + + p.entries[fd] = entry + p.registrationToFD[registrationID] = fd + + if !p.running { + p.running = true + p.wg.Add(1) + go p.run() + } + + return nil +} + +func (p *FDPoller) Remove(fd int) { + p.mutex.Lock() + defer p.mutex.Unlock() + + entry, ok := p.entries[fd] + if !ok { + return + } + + unix.Kevent(p.kqueueFD, []unix.Kevent_t{{ + Ident: uint64(fd), + Filter: unix.EVFILT_READ, + Flags: unix.EV_DELETE, + }}, nil, nil) + delete(p.registrationToFD, entry.registrationID) + delete(p.entries, fd) +} + +func (p *FDPoller) wakeup() { + unix.Write(p.pipeFDs[1], []byte{0}) +} + +func (p *FDPoller) Close() error { + p.mutex.Lock() + p.closed.Store(true) + p.mutex.Unlock() + + p.cancel() + p.wakeup() + p.wg.Wait() + + p.mutex.Lock() + defer p.mutex.Unlock() + + if p.kqueueFD != -1 { + unix.Close(p.kqueueFD) + p.kqueueFD = -1 + } + if p.pipeFDs[0] != -1 { + unix.Close(p.pipeFDs[0]) + unix.Close(p.pipeFDs[1]) + p.pipeFDs[0] = -1 + p.pipeFDs[1] = -1 + } + return nil +} + +func (p *FDPoller) run() { + defer p.wg.Done() + + events := make([]unix.Kevent_t, 64) + var buffer [1]byte + + for { + select { + case <-p.ctx.Done(): + p.mutex.Lock() + p.running = false + p.mutex.Unlock() + return + default: + } + + n, err := unix.Kevent(p.kqueueFD, nil, events, nil) + if err != nil { + if err == unix.EINTR { + continue + } + p.mutex.Lock() + p.running = false + p.mutex.Unlock() + return + } + + for i := 0; i < n; i++ { + event := events[i] + fd := int(event.Ident) + + if fd == p.pipeFDs[0] { + unix.Read(p.pipeFDs[0], buffer[:]) + continue + } + + if event.Flags&unix.EV_ERROR != 0 { + continue + } + + eventEntry := (*fdDemuxEntry)(unsafe.Pointer(event.Udata)) + + p.mutex.Lock() + currentEntry := p.entries[fd] + if currentEntry != eventEntry { + p.mutex.Unlock() + continue + } + + delete(p.registrationToFD, currentEntry.registrationID) + delete(p.entries, fd) + p.mutex.Unlock() + + go currentEntry.handler.HandleFDEvent() + } + + p.mutex.Lock() + if len(p.entries) == 0 { + p.running = false + p.mutex.Unlock() + return + } + p.mutex.Unlock() + } +} diff --git a/common/bufio/fd_poller_linux.go b/common/bufio/fd_poller_linux.go new file mode 100644 index 00000000..a2e245b9 --- /dev/null +++ b/common/bufio/fd_poller_linux.go @@ -0,0 +1,217 @@ +//go:build linux + +package bufio + +import ( + "context" + "sync" + "sync/atomic" + "unsafe" + + "golang.org/x/sys/unix" +) + +type fdDemuxEntry struct { + fd int + registrationID uint64 + handler FDHandler +} + +type FDPoller struct { + ctx context.Context + cancel context.CancelFunc + epollFD int + mutex sync.Mutex + 
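+	// entries maps a registered fd to its current registration; guarded by mutex.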
entries map[int]*fdDemuxEntry + registrationCounter uint64 + registrationToFD map[uint64]int + running bool + closed atomic.Bool + wg sync.WaitGroup + pipeFDs [2]int +} + +func NewFDPoller(ctx context.Context) (*FDPoller, error) { + epollFD, err := unix.EpollCreate1(unix.EPOLL_CLOEXEC) + if err != nil { + return nil, err + } + + var pipeFDs [2]int + err = unix.Pipe2(pipeFDs[:], unix.O_NONBLOCK|unix.O_CLOEXEC) + if err != nil { + unix.Close(epollFD) + return nil, err + } + + pipeEvent := &unix.EpollEvent{Events: unix.EPOLLIN} + *(*uint64)(unsafe.Pointer(&pipeEvent.Fd)) = 0 + err = unix.EpollCtl(epollFD, unix.EPOLL_CTL_ADD, pipeFDs[0], pipeEvent) + if err != nil { + unix.Close(pipeFDs[0]) + unix.Close(pipeFDs[1]) + unix.Close(epollFD) + return nil, err + } + + ctx, cancel := context.WithCancel(ctx) + poller := &FDPoller{ + ctx: ctx, + cancel: cancel, + epollFD: epollFD, + entries: make(map[int]*fdDemuxEntry), + registrationToFD: make(map[uint64]int), + pipeFDs: pipeFDs, + } + return poller, nil +} + +func (p *FDPoller) Add(handler FDHandler, fd int) error { + p.mutex.Lock() + defer p.mutex.Unlock() + + if p.closed.Load() { + return unix.EINVAL + } + + p.registrationCounter++ + registrationID := p.registrationCounter + + event := &unix.EpollEvent{Events: unix.EPOLLIN | unix.EPOLLRDHUP} + *(*uint64)(unsafe.Pointer(&event.Fd)) = registrationID + + err := unix.EpollCtl(p.epollFD, unix.EPOLL_CTL_ADD, fd, event) + if err != nil { + return err + } + + entry := &fdDemuxEntry{ + fd: fd, + registrationID: registrationID, + handler: handler, + } + p.entries[fd] = entry + p.registrationToFD[registrationID] = fd + + if !p.running { + p.running = true + p.wg.Add(1) + go p.run() + } + + return nil +} + +func (p *FDPoller) Remove(fd int) { + p.mutex.Lock() + defer p.mutex.Unlock() + + entry, ok := p.entries[fd] + if !ok { + return + } + + unix.EpollCtl(p.epollFD, unix.EPOLL_CTL_DEL, fd, nil) + delete(p.registrationToFD, entry.registrationID) + delete(p.entries, fd) +} + +func (p *FDPoller) wakeup() { + unix.Write(p.pipeFDs[1], []byte{0}) +} + +func (p *FDPoller) Close() error { + p.mutex.Lock() + p.closed.Store(true) + p.mutex.Unlock() + + p.cancel() + p.wakeup() + p.wg.Wait() + + p.mutex.Lock() + defer p.mutex.Unlock() + + if p.epollFD != -1 { + unix.Close(p.epollFD) + p.epollFD = -1 + } + if p.pipeFDs[0] != -1 { + unix.Close(p.pipeFDs[0]) + unix.Close(p.pipeFDs[1]) + p.pipeFDs[0] = -1 + p.pipeFDs[1] = -1 + } + return nil +} + +func (p *FDPoller) run() { + defer p.wg.Done() + + events := make([]unix.EpollEvent, 64) + var buffer [1]byte + + for { + select { + case <-p.ctx.Done(): + p.mutex.Lock() + p.running = false + p.mutex.Unlock() + return + default: + } + + n, err := unix.EpollWait(p.epollFD, events, -1) + if err != nil { + if err == unix.EINTR { + continue + } + p.mutex.Lock() + p.running = false + p.mutex.Unlock() + return + } + + for i := 0; i < n; i++ { + event := events[i] + registrationID := *(*uint64)(unsafe.Pointer(&event.Fd)) + + if registrationID == 0 { + unix.Read(p.pipeFDs[0], buffer[:]) + continue + } + + if event.Events&(unix.EPOLLIN|unix.EPOLLRDHUP|unix.EPOLLHUP|unix.EPOLLERR) == 0 { + continue + } + + p.mutex.Lock() + fd, ok := p.registrationToFD[registrationID] + if !ok { + p.mutex.Unlock() + continue + } + + entry := p.entries[fd] + if entry == nil || entry.registrationID != registrationID { + p.mutex.Unlock() + continue + } + + unix.EpollCtl(p.epollFD, unix.EPOLL_CTL_DEL, fd, nil) + delete(p.registrationToFD, registrationID) + delete(p.entries, fd) + p.mutex.Unlock() + + go 
entry.handler.HandleFDEvent() + } + + p.mutex.Lock() + if len(p.entries) == 0 { + p.running = false + p.mutex.Unlock() + return + } + p.mutex.Unlock() + } +} diff --git a/common/bufio/fd_poller_stub.go b/common/bufio/fd_poller_stub.go new file mode 100644 index 00000000..11921482 --- /dev/null +++ b/common/bufio/fd_poller_stub.go @@ -0,0 +1,25 @@ +//go:build !linux && !darwin && !windows + +package bufio + +import ( + "context" + + E "github.com/sagernet/sing/common/exceptions" +) + +type FDPoller struct{} + +func NewFDPoller(ctx context.Context) (*FDPoller, error) { + return nil, E.New("FDPoller not supported on this platform") +} + +func (p *FDPoller) Add(handler FDHandler, fd int) error { + return E.New("FDPoller not supported on this platform") +} + +func (p *FDPoller) Remove(fd int) {} + +func (p *FDPoller) Close() error { + return nil +} diff --git a/common/bufio/fd_poller_windows.go b/common/bufio/fd_poller_windows.go new file mode 100644 index 00000000..a1cb368a --- /dev/null +++ b/common/bufio/fd_poller_windows.go @@ -0,0 +1,267 @@ +//go:build windows + +package bufio + +import ( + "context" + "sync" + "sync/atomic" + "unsafe" + + "github.com/sagernet/sing/common/wepoll" + + "golang.org/x/sys/windows" +) + +type fdDemuxEntry struct { + ioStatusBlock windows.IO_STATUS_BLOCK + pollInfo wepoll.AFDPollInfo + handler FDHandler + fd int + handle windows.Handle + baseHandle windows.Handle + registrationID uint64 + cancelled bool + unpinned bool + pinner wepoll.Pinner +} + +type FDPoller struct { + ctx context.Context + cancel context.CancelFunc + iocp windows.Handle + afd *wepoll.AFD + mutex sync.Mutex + entries map[int]*fdDemuxEntry + registrationCounter uint64 + running bool + closed atomic.Bool + wg sync.WaitGroup +} + +func NewFDPoller(ctx context.Context) (*FDPoller, error) { + iocp, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) + if err != nil { + return nil, err + } + + afd, err := wepoll.NewAFD(iocp, "Go") + if err != nil { + windows.CloseHandle(iocp) + return nil, err + } + + ctx, cancel := context.WithCancel(ctx) + poller := &FDPoller{ + ctx: ctx, + cancel: cancel, + iocp: iocp, + afd: afd, + entries: make(map[int]*fdDemuxEntry), + } + return poller, nil +} + +func (p *FDPoller) Add(handler FDHandler, fd int) error { + p.mutex.Lock() + defer p.mutex.Unlock() + + if p.closed.Load() { + return windows.ERROR_INVALID_HANDLE + } + + handle := windows.Handle(fd) + baseHandle, err := wepoll.GetBaseSocket(handle) + if err != nil { + return err + } + + p.registrationCounter++ + registrationID := p.registrationCounter + + entry := &fdDemuxEntry{ + handler: handler, + fd: fd, + handle: handle, + baseHandle: baseHandle, + registrationID: registrationID, + } + + entry.pinner.Pin(entry) + + events := uint32(wepoll.AFD_POLL_RECEIVE | wepoll.AFD_POLL_DISCONNECT | wepoll.AFD_POLL_ABORT | wepoll.AFD_POLL_LOCAL_CLOSE) + err = p.afd.Poll(baseHandle, events, &entry.ioStatusBlock, &entry.pollInfo) + if err != nil { + entry.pinner.Unpin() + return err + } + + p.entries[fd] = entry + + if !p.running { + p.running = true + p.wg.Add(1) + go p.run() + } + + return nil +} + +func (p *FDPoller) Remove(fd int) { + p.mutex.Lock() + defer p.mutex.Unlock() + + entry, ok := p.entries[fd] + if !ok { + return + } + + entry.cancelled = true + if p.afd != nil { + p.afd.Cancel(&entry.ioStatusBlock) + } + + if !entry.unpinned { + entry.unpinned = true + entry.pinner.Unpin() + } + delete(p.entries, fd) +} + +func (p *FDPoller) wakeup() { + windows.PostQueuedCompletionStatus(p.iocp, 0, 0, nil) 
+} + +func (p *FDPoller) Close() error { + p.mutex.Lock() + p.closed.Store(true) + p.mutex.Unlock() + + p.cancel() + p.wakeup() + p.wg.Wait() + + p.mutex.Lock() + defer p.mutex.Unlock() + + for fd, entry := range p.entries { + if !entry.unpinned { + entry.unpinned = true + entry.pinner.Unpin() + } + delete(p.entries, fd) + } + + if p.afd != nil { + p.afd.Close() + p.afd = nil + } + if p.iocp != 0 { + windows.CloseHandle(p.iocp) + p.iocp = 0 + } + return nil +} + +func (p *FDPoller) drainCompletions(completions []wepoll.OverlappedEntry) { + for { + var numRemoved uint32 + err := wepoll.GetQueuedCompletionStatusEx(p.iocp, &completions[0], uint32(len(completions)), &numRemoved, 0, false) + if err != nil || numRemoved == 0 { + break + } + + for i := uint32(0); i < numRemoved; i++ { + event := completions[i] + if event.Overlapped == nil { + continue + } + + entry := (*fdDemuxEntry)(unsafe.Pointer(event.Overlapped)) + p.mutex.Lock() + if p.entries[entry.fd] == entry && !entry.unpinned { + entry.unpinned = true + entry.pinner.Unpin() + } + delete(p.entries, entry.fd) + p.mutex.Unlock() + } + } +} + +func (p *FDPoller) run() { + defer p.wg.Done() + + completions := make([]wepoll.OverlappedEntry, 64) + + for { + select { + case <-p.ctx.Done(): + p.drainCompletions(completions) + p.mutex.Lock() + p.running = false + p.mutex.Unlock() + return + default: + } + + var numRemoved uint32 + err := wepoll.GetQueuedCompletionStatusEx(p.iocp, &completions[0], 64, &numRemoved, windows.INFINITE, false) + if err != nil { + p.mutex.Lock() + p.running = false + p.mutex.Unlock() + return + } + + for i := uint32(0); i < numRemoved; i++ { + event := completions[i] + + if event.Overlapped == nil { + continue + } + + entry := (*fdDemuxEntry)(unsafe.Pointer(event.Overlapped)) + + p.mutex.Lock() + + if p.entries[entry.fd] != entry { + p.mutex.Unlock() + continue + } + + if !entry.unpinned { + entry.unpinned = true + entry.pinner.Unpin() + } + delete(p.entries, entry.fd) + + if entry.cancelled { + p.mutex.Unlock() + continue + } + + if uint32(entry.ioStatusBlock.Status) == wepoll.STATUS_CANCELLED { + p.mutex.Unlock() + continue + } + + events := entry.pollInfo.Handles[0].Events + if events&(wepoll.AFD_POLL_RECEIVE|wepoll.AFD_POLL_DISCONNECT|wepoll.AFD_POLL_ABORT|wepoll.AFD_POLL_LOCAL_CLOSE) == 0 { + p.mutex.Unlock() + continue + } + + p.mutex.Unlock() + go entry.handler.HandleFDEvent() + } + + p.mutex.Lock() + if len(p.entries) == 0 { + p.running = false + p.mutex.Unlock() + return + } + p.mutex.Unlock() + } +} diff --git a/common/bufio/packet_reactor.go b/common/bufio/packet_reactor.go new file mode 100644 index 00000000..3347d2c2 --- /dev/null +++ b/common/bufio/packet_reactor.go @@ -0,0 +1,507 @@ +package bufio + +import ( + "context" + "sync" + "sync/atomic" + "time" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +const ( + batchReadTimeout = 5 *time.Second +) + +const ( + stateIdle int32 = 0 + stateActive int32 = 1 + stateClosed int32 = 2 +) + +type withoutReadDeadline interface { + NeedAdditionalReadDeadline() bool +} + +func needAdditionalReadDeadline(rawReader any) bool { + if deadlineReader, loaded := rawReader.(withoutReadDeadline); loaded { + return deadlineReader.NeedAdditionalReadDeadline() + } + if upstream, hasUpstream := rawReader.(N.WithUpstreamReader); hasUpstream { + return 
needAdditionalReadDeadline(upstream.UpstreamReader()) + } + if upstream, hasUpstream := rawReader.(common.WithUpstream); hasUpstream { + return needAdditionalReadDeadline(upstream.Upstream()) + } + return false +} + +func CreatePacketPushable(reader N.PacketReader) (N.PacketPushable, bool) { + if pushable, ok := reader.(N.PacketPushable); ok { + return pushable, true + } + if u, ok := reader.(N.ReaderWithUpstream); !ok || !u.ReaderReplaceable() { + return nil, false + } + if u, ok := reader.(N.WithUpstreamReader); ok { + if upstream, ok := u.UpstreamReader().(N.PacketReader); ok { + return CreatePacketPushable(upstream) + } + } + if u, ok := reader.(common.WithUpstream); ok { + if upstream, ok := u.Upstream().(N.PacketReader); ok { + return CreatePacketPushable(upstream) + } + } + return nil, false +} + +func CreatePacketPollable(reader N.PacketReader) (N.PacketPollable, bool) { + if pollable, ok := reader.(N.PacketPollable); ok { + return pollable, true + } + if creator, ok := reader.(N.PacketPollableCreator); ok { + return creator.CreatePacketPollable() + } + if u, ok := reader.(N.ReaderWithUpstream); !ok || !u.ReaderReplaceable() { + return nil, false + } + if u, ok := reader.(N.WithUpstreamReader); ok { + if upstream, ok := u.UpstreamReader().(N.PacketReader); ok { + return CreatePacketPollable(upstream) + } + } + if u, ok := reader.(common.WithUpstream); ok { + if upstream, ok := u.Upstream().(N.PacketReader); ok { + return CreatePacketPollable(upstream) + } + } + return nil, false +} + +type PacketReactor struct { + ctx context.Context + cancel context.CancelFunc + logger logger.Logger + fdPoller *FDPoller + fdPollerOnce sync.Once + fdPollerErr error +} + +func NewPacketReactor(ctx context.Context, l logger.Logger) *PacketReactor { + ctx, cancel := context.WithCancel(ctx) + if l == nil { + l = logger.NOP() + } + return &PacketReactor{ + ctx: ctx, + cancel: cancel, + logger: l, + } +} + +func (r *PacketReactor) getFDPoller() (*FDPoller, error) { + r.fdPollerOnce.Do(func() { + r.fdPoller, r.fdPollerErr = NewFDPoller(r.ctx) + }) + return r.fdPoller, r.fdPollerErr +} + +func (r *PacketReactor) Close() error { + r.cancel() + if r.fdPoller != nil { + return r.fdPoller.Close() + } + return nil +} + +type reactorConnection struct { + ctx context.Context + cancel context.CancelFunc + reactor *PacketReactor + onClose N.CloseHandlerFunc + upload *reactorStream + download *reactorStream + stopReactorWatch func() bool + + closeOnce sync.Once + done chan struct{} + err error +} + +type reactorStream struct { + connection *reactorConnection + + source N.PacketReader + destination N.PacketWriter + originSource N.PacketReader + + pushable N.PacketPushable + pushableCallback func() + pollable N.PacketPollable + deadlineSetter deadlineSetter + options N.ReadWaitOptions + readWaiter N.PacketReadWaiter + readCounters []N.CountFunc + writeCounters []N.CountFunc + + state atomic.Int32 +} + +func (r *PacketReactor) Copy(ctx context.Context, source N.PacketConn, destination N.PacketConn, onClose N.CloseHandlerFunc) { + r.logger.Trace("packet copy: starting") + ctx, cancel := context.WithCancel(ctx) + conn := &reactorConnection{ + ctx: ctx, + cancel: cancel, + reactor: r, + onClose: onClose, + done: make(chan struct{}), + } + conn.stopReactorWatch = common.ContextAfterFunc(r.ctx, func() { + conn.closeWithError(r.ctx.Err()) + }) + + conn.upload = r.prepareStream(conn, source, destination) + select { + case <-conn.done: + return + default: + } + + conn.download = r.prepareStream(conn, destination, source) + 
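+	// prepareStream may flush cached packets and close the connection on error; stop before registering the streams.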
select { + case <-conn.done: + return + default: + } + + r.registerStream(conn.upload) + r.registerStream(conn.download) +} + +func (r *PacketReactor) prepareStream(conn *reactorConnection, source N.PacketReader, destination N.PacketWriter) *reactorStream { + stream := &reactorStream{ + connection: conn, + source: source, + destination: destination, + originSource: source, + } + + for { + source, stream.readCounters = N.UnwrapCountPacketReader(source, stream.readCounters) + destination, stream.writeCounters = N.UnwrapCountPacketWriter(destination, stream.writeCounters) + if cachedReader, isCached := source.(N.CachedPacketReader); isCached { + packet := cachedReader.ReadCachedPacket() + if packet != nil { + buffer := packet.Buffer + dataLen := buffer.Len() + err := destination.WritePacket(buffer, packet.Destination) + N.PutPacketBuffer(packet) + if err != nil { + buffer.Leak() + conn.closeWithError(err) + return stream + } + for _, counter := range stream.readCounters { + counter(int64(dataLen)) + } + for _, counter := range stream.writeCounters { + counter(int64(dataLen)) + } + continue + } + } + break + } + stream.source = source + stream.destination = destination + + stream.options = N.NewReadWaitOptions(source, destination) + + stream.readWaiter, _ = CreatePacketReadWaiter(source) + if stream.readWaiter != nil { + needCopy := stream.readWaiter.InitializeReadWaiter(stream.options) + if needCopy { + stream.readWaiter = nil + } + } + + stream.pushable, _ = CreatePacketPushable(source) + if stream.pushable == nil { + stream.pollable, _ = CreatePacketPollable(source) + } + + return stream +} + +func (r *PacketReactor) registerStream(stream *reactorStream) { + if needAdditionalReadDeadline(stream.source) { + r.logger.Trace("packet stream: needs additional deadline handling, using legacy copy") + go stream.runLegacyCopy() + return + } + + // Check if deadline setter is available and functional + if setter, ok := stream.source.(deadlineSetter); ok { + err := setter.SetReadDeadline(time.Time{}) + if err != nil { + r.logger.Trace("packet stream: SetReadDeadline not supported, using legacy copy") + go stream.runLegacyCopy() + return + } + stream.deadlineSetter = setter + } else { + r.logger.Trace("packet stream: no deadline setter, using legacy copy") + go stream.runLegacyCopy() + return + } + + if stream.pushable != nil { + r.logger.Trace("packet stream: using pushable mode") + stream.pushableCallback = func() { + go stream.runActiveLoop(nil) + } + stream.pushable.SetOnDataReady(stream.pushableCallback) + if stream.pushable.HasPendingData() { + go stream.runActiveLoop(nil) + } + return + } + + if stream.pollable == nil { + r.logger.Trace("packet stream: using legacy copy") + go stream.runLegacyCopy() + return + } + + fdPoller, err := r.getFDPoller() + if err != nil { + r.logger.Trace("packet stream: FD poller unavailable, using legacy copy") + go stream.runLegacyCopy() + return + } + err = fdPoller.Add(stream, stream.pollable.FD()) + if err != nil { + r.logger.Trace("packet stream: failed to add to FD poller, using legacy copy") + go stream.runLegacyCopy() + } else { + r.logger.Trace("packet stream: registered with FD poller") + } +} + +func (s *reactorStream) runActiveLoop(firstPacket *N.PacketBuffer) { + if s.source == nil { + if firstPacket != nil { + firstPacket.Buffer.Release() + N.PutPacketBuffer(firstPacket) + } + return + } + if !s.state.CompareAndSwap(stateIdle, stateActive) { + if firstPacket != nil { + firstPacket.Buffer.Release() + N.PutPacketBuffer(firstPacket) + } + return + } + + 
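+	// This goroutine won the idle->active transition; disable the data-ready callback while it drains the source.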
if s.pushable != nil { + s.pushable.SetOnDataReady(nil) + } + + notFirstTime := false + + if firstPacket != nil { + err := s.writePacketWithCounters(firstPacket) + if err != nil { + s.closeWithError(err) + return + } + notFirstTime = true + } + + for { + if s.state.Load() == stateClosed { + return + } + + deadlineErr := s.deadlineSetter.SetReadDeadline(time.Now().Add(batchReadTimeout)) + if deadlineErr != nil { + s.connection.reactor.logger.Trace("packet stream: SetReadDeadline failed, switching to legacy copy") + s.state.Store(stateIdle) + go s.runLegacyCopy() + return + } + + var ( + buffer *N.PacketBuffer + destination M.Socksaddr + err error + ) + + if s.readWaiter != nil { + var readBuffer *buf.Buffer + readBuffer, destination, err = s.readWaiter.WaitReadPacket() + if readBuffer != nil { + buffer = N.NewPacketBuffer() + buffer.Buffer = readBuffer + buffer.Destination = destination + } + } else { + readBuffer := s.options.NewPacketBuffer() + destination, err = s.source.ReadPacket(readBuffer) + if err != nil { + readBuffer.Release() + } else { + buffer = N.NewPacketBuffer() + buffer.Buffer = readBuffer + buffer.Destination = destination + } + } + + if err != nil { + if E.IsTimeout(err) { + s.deadlineSetter.SetReadDeadline(time.Time{}) + if !s.state.CompareAndSwap(stateActive, stateIdle) { + return + } + s.connection.reactor.logger.Trace("packet stream: timeout, returning to idle pool") + if s.pushable != nil { + s.pushable.SetOnDataReady(s.pushableCallback) + if s.pushable.HasPendingData() { + if s.state.CompareAndSwap(stateIdle, stateActive) { + s.pushable.SetOnDataReady(nil) + continue + } + } + return + } + s.returnToPool() + return + } + if !notFirstTime { + err = N.ReportHandshakeFailure(s.originSource, err) + } + s.connection.reactor.logger.Trace("packet stream: error occurred: ", err) + s.closeWithError(err) + return + } + + err = s.writePacketWithCounters(buffer) + if err != nil { + if !notFirstTime { + err = N.ReportHandshakeFailure(s.originSource, err) + } + s.closeWithError(err) + return + } + notFirstTime = true + } +} + +func (s *reactorStream) writePacketWithCounters(packet *N.PacketBuffer) error { + buffer := packet.Buffer + destination := packet.Destination + dataLen := buffer.Len() + + s.options.PostReturn(buffer) + err := s.destination.WritePacket(buffer, destination) + N.PutPacketBuffer(packet) + if err != nil { + buffer.Leak() + return err + } + + for _, counter := range s.readCounters { + counter(int64(dataLen)) + } + for _, counter := range s.writeCounters { + counter(int64(dataLen)) + } + return nil +} + +func (s *reactorStream) returnToPool() { + if s.state.Load() != stateIdle { + return + } + + if s.pollable == nil || s.connection.reactor.fdPoller == nil { + return + } + + fd := s.pollable.FD() + err := s.connection.reactor.fdPoller.Add(s, fd) + if err != nil { + s.closeWithError(err) + return + } + if s.state.Load() != stateIdle { + s.connection.reactor.fdPoller.Remove(fd) + } +} + +func (s *reactorStream) HandleFDEvent() { + s.runActiveLoop(nil) +} + +func (s *reactorStream) runLegacyCopy() { + _, err := CopyPacketWithCounters(s.destination, s.source, s.originSource, s.readCounters, s.writeCounters) + s.closeWithError(err) +} + +func (s *reactorStream) closeWithError(err error) { + s.connection.closeWithError(err) +} + +func (c *reactorConnection) closeWithError(err error) { + c.closeOnce.Do(func() { + defer close(c.done) + c.reactor.logger.Trace("packet connection: closing with error: ", err) + + if c.stopReactorWatch != nil { + c.stopReactorWatch() + } + + 
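+		// Record the error, cancel the context, and mark both directions closed before releasing their sources.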
c.err = err + c.cancel() + + if c.upload != nil { + c.upload.state.Store(stateClosed) + } + if c.download != nil { + c.download.state.Store(stateClosed) + } + + c.removeFromPollers() + + if c.upload != nil { + common.Close(c.upload.originSource) + } + if c.download != nil { + common.Close(c.download.originSource) + } + + if c.onClose != nil { + c.onClose(c.err) + } + }) +} + +func (c *reactorConnection) removeFromPollers() { + c.removeStreamFromPoller(c.upload) + c.removeStreamFromPoller(c.download) +} + +func (c *reactorConnection) removeStreamFromPoller(stream *reactorStream) { + if stream == nil || stream.pollable == nil || c.reactor.fdPoller == nil { + return + } + c.reactor.fdPoller.Remove(stream.pollable.FD()) +} diff --git a/common/bufio/packet_reactor_test.go b/common/bufio/packet_reactor_test.go new file mode 100644 index 00000000..5152ed43 --- /dev/null +++ b/common/bufio/packet_reactor_test.go @@ -0,0 +1,2148 @@ +//go:build darwin || linux || windows + +package bufio + +import ( + "context" + "crypto/md5" + "crypto/rand" + "errors" + "io" + "net" + "os" + "sync" + "sync/atomic" + "syscall" + "testing" + "time" + + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testPacketPipe struct { + inChan chan *N.PacketBuffer + outChan chan *N.PacketBuffer + localAddr M.Socksaddr + closed atomic.Bool + closeOnce sync.Once + done chan struct{} +} + +func newTestPacketPipe(localAddr M.Socksaddr) *testPacketPipe { + return &testPacketPipe{ + inChan: make(chan *N.PacketBuffer, 256), + outChan: make(chan *N.PacketBuffer, 256), + localAddr: localAddr, + done: make(chan struct{}), + } +} + +func (p *testPacketPipe) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + select { + case packet, ok := <-p.inChan: + if !ok { + return M.Socksaddr{}, io.EOF + } + _, err = buffer.ReadOnceFrom(packet.Buffer) + destination = packet.Destination + packet.Buffer.Release() + N.PutPacketBuffer(packet) + return destination, err + case <-p.done: + return M.Socksaddr{}, net.ErrClosed + } +} + +func (p *testPacketPipe) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + if p.closed.Load() { + buffer.Release() + return net.ErrClosed + } + packet := N.NewPacketBuffer() + newBuf := buf.NewSize(buffer.Len()) + newBuf.Write(buffer.Bytes()) + packet.Buffer = newBuf + packet.Destination = destination + buffer.Release() + select { + case p.outChan <- packet: + return nil + case <-p.done: + packet.Buffer.Release() + N.PutPacketBuffer(packet) + return net.ErrClosed + } +} + +func (p *testPacketPipe) Close() error { + p.closeOnce.Do(func() { + p.closed.Store(true) + close(p.done) + }) + return nil +} + +func (p *testPacketPipe) LocalAddr() net.Addr { + return p.localAddr.UDPAddr() +} + +func (p *testPacketPipe) SetDeadline(t time.Time) error { + return nil +} + +func (p *testPacketPipe) SetReadDeadline(t time.Time) error { + return nil +} + +func (p *testPacketPipe) SetWriteDeadline(t time.Time) error { + return nil +} + +func (p *testPacketPipe) send(data []byte, destination M.Socksaddr) { + packet := N.NewPacketBuffer() + newBuf := buf.NewSize(len(data)) + newBuf.Write(data) + packet.Buffer = newBuf + packet.Destination = destination + p.inChan <- packet +} + +func (p *testPacketPipe) receive() (*N.PacketBuffer, bool) { + select { + case packet, ok := <-p.outChan: + return packet, ok + case <-p.done: + return nil, false 
+ } +} + +type fdPacketConn struct { + N.NetPacketConn + fd int + targetAddr M.Socksaddr +} + +func newFDPacketConn(t *testing.T, conn net.PacketConn, targetAddr M.Socksaddr) *fdPacketConn { + syscallConn, ok := conn.(syscall.Conn) + require.True(t, ok, "connection must implement syscall.Conn") + rawConn, err := syscallConn.SyscallConn() + require.NoError(t, err) + var fd int + err = rawConn.Control(func(f uintptr) { fd = int(f) }) + require.NoError(t, err) + return &fdPacketConn{ + NetPacketConn: NewPacketConn(conn), + fd: fd, + targetAddr: targetAddr, + } +} + +func (c *fdPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + _, err = c.NetPacketConn.ReadPacket(buffer) + if err != nil { + return M.Socksaddr{}, err + } + return c.targetAddr, nil +} + +func (c *fdPacketConn) FD() int { + return c.fd +} + +type channelPacketConn struct { + N.NetPacketConn + packetChan chan *N.PacketBuffer + done chan struct{} + closeOnce sync.Once + targetAddr M.Socksaddr + deadlineLock sync.Mutex + deadline time.Time + deadlineChan chan struct{} +} + +func newChannelPacketConn(conn net.PacketConn, targetAddr M.Socksaddr) *channelPacketConn { + c := &channelPacketConn{ + NetPacketConn: NewPacketConn(conn), + packetChan: make(chan *N.PacketBuffer, 256), + done: make(chan struct{}), + targetAddr: targetAddr, + deadlineChan: make(chan struct{}), + } + go c.readLoop() + return c +} + +func (c *channelPacketConn) readLoop() { + for { + select { + case <-c.done: + return + default: + } + buffer := buf.NewPacket() + _, err := c.NetPacketConn.ReadPacket(buffer) + if err != nil { + buffer.Release() + close(c.packetChan) + return + } + packet := N.NewPacketBuffer() + packet.Buffer = buffer + packet.Destination = c.targetAddr + select { + case c.packetChan <- packet: + case <-c.done: + buffer.Release() + N.PutPacketBuffer(packet) + return + } + } +} + +func (c *channelPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + c.deadlineLock.Lock() + deadline := c.deadline + deadlineChan := c.deadlineChan + c.deadlineLock.Unlock() + + var timer <-chan time.Time + if !deadline.IsZero() { + d := time.Until(deadline) + if d <= 0 { + return M.Socksaddr{}, os.ErrDeadlineExceeded + } + t := time.NewTimer(d) + defer t.Stop() + timer = t.C + } + + select { + case packet, ok := <-c.packetChan: + if !ok { + return M.Socksaddr{}, net.ErrClosed + } + _, err = buffer.ReadOnceFrom(packet.Buffer) + destination = packet.Destination + packet.Buffer.Release() + N.PutPacketBuffer(packet) + return + case <-c.done: + return M.Socksaddr{}, net.ErrClosed + case <-deadlineChan: + return M.Socksaddr{}, os.ErrDeadlineExceeded + case <-timer: + return M.Socksaddr{}, os.ErrDeadlineExceeded + } +} + +func (c *channelPacketConn) SetReadDeadline(t time.Time) error { + c.deadlineLock.Lock() + c.deadline = t + if c.deadlineChan != nil { + close(c.deadlineChan) + } + c.deadlineChan = make(chan struct{}) + c.deadlineLock.Unlock() + return nil +} + +func (c *channelPacketConn) Close() error { + c.closeOnce.Do(func() { + close(c.done) + }) + return c.NetPacketConn.Close() +} + +type batchHashPair struct { + sendHash map[int][]byte + recvHash map[int][]byte +} + +func TestBatchCopy_Pipe_DataIntegrity(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 10001) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 10002) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + copier := 
NewPacketReactor(context.Background(), nil) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + const times = 50 + const chunkSize = 9000 + + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + + recvDone := make(chan struct{}) + go func() { + defer close(recvDone) + for i := 0; i < times; i++ { + packet, ok := pipeB.receive() + if !ok { + t.Logf("recv channel closed at %d", i) + return + } + hash := md5.Sum(packet.Buffer.Bytes()) + recvHash[int(packet.Buffer.Byte(0))] = hash[:] + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + }() + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + pipeA.send(data, addr2) + time.Sleep(5 * time.Millisecond) + } + + select { + case <-recvDone: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for receive") + } + + assert.Equal(t, sendHash, recvHash, "data mismatch") +} + +func TestBatchCopy_Pipe_Bidirectional(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 10001) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 10002) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + copier := NewPacketReactor(context.Background(), nil) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + const times = 50 + const chunkSize = 9000 + + pingCh := make(chan batchHashPair, 1) + pongCh := make(chan batchHashPair, 1) + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + pipeA.send(data, addr2) + time.Sleep(5 * time.Millisecond) + } + + for i := 0; i < times; i++ { + packet, ok := pipeA.receive() + if !ok { + return + } + hash := md5.Sum(packet.Buffer.Bytes()) + recvHash[int(packet.Buffer.Byte(0))] = hash[:] + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + + pingCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + + for i := 0; i < times; i++ { + packet, ok := pipeB.receive() + if !ok { + return + } + hash := md5.Sum(packet.Buffer.Bytes()) + recvHash[int(packet.Buffer.Byte(0))] = hash[:] + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + pipeB.send(data, addr1) + time.Sleep(5 * time.Millisecond) + } + + pongCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + var aPair, bPair batchHashPair + select { + case aPair = <-pingCh: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for A") + } + select { + case bPair = <-pongCh: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for B") + } + + assert.Equal(t, aPair.sendHash, bPair.recvHash, "A->B mismatch") + assert.Equal(t, bPair.sendHash, aPair.recvHash, "B->A mismatch") +} + +func TestBatchCopy_FDPoller_DataIntegrity(t *testing.T) { + t.Parallel() + + clientConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer clientConn.Close() + + proxyAConn, err := 
net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + proxyBConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + serverConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer serverConn.Close() + + serverAddr := M.SocksaddrFromNet(serverConn.LocalAddr()) + clientAddr := M.SocksaddrFromNet(clientConn.LocalAddr()) + proxyAAddr := M.SocksaddrFromNet(proxyAConn.LocalAddr()) + proxyBAddr := M.SocksaddrFromNet(proxyBConn.LocalAddr()) + + proxyA := newFDPacketConn(t, proxyAConn, serverAddr) + proxyB := newFDPacketConn(t, proxyBConn, clientAddr) + + copier := NewPacketReactor(context.Background(), nil) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), proxyA, proxyB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + const times = 50 + const chunkSize = 9000 + + pingCh := make(chan batchHashPair, 1) + pongCh := make(chan batchHashPair, 1) + errCh := make(chan error, 2) + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + recvBuf := make([]byte, 65536) + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + _, err := clientConn.WriteTo(data, proxyAAddr.UDPAddr()) + if err != nil { + errCh <- err + return + } + time.Sleep(5 * time.Millisecond) + } + + for i := 0; i < times; i++ { + n, _, err := clientConn.ReadFrom(recvBuf) + if err != nil { + errCh <- err + return + } + hash := md5.Sum(recvBuf[:n]) + recvHash[int(recvBuf[0])] = hash[:] + } + + pingCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + recvBuf := make([]byte, 65536) + + for i := 0; i < times; i++ { + n, _, err := serverConn.ReadFrom(recvBuf) + if err != nil { + errCh <- err + return + } + hash := md5.Sum(recvBuf[:n]) + recvHash[int(recvBuf[0])] = hash[:] + } + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + _, err := serverConn.WriteTo(data, proxyBAddr.UDPAddr()) + if err != nil { + errCh <- err + return + } + time.Sleep(5 * time.Millisecond) + } + + pongCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + var clientPair, serverPair batchHashPair + for i := 0; i < 2; i++ { + select { + case clientPair = <-pingCh: + case serverPair = <-pongCh: + case err := <-errCh: + t.Fatal(err) + case <-time.After(15 * time.Second): + t.Fatal("timeout") + } + } + + assert.Equal(t, clientPair.sendHash, serverPair.recvHash, "client->server mismatch") + assert.Equal(t, serverPair.sendHash, clientPair.recvHash, "server->client mismatch") +} + +func TestBatchCopy_LegacyChannel_DataIntegrity(t *testing.T) { + t.Parallel() + + clientConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer clientConn.Close() + + proxyAConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + proxyBConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + serverConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer serverConn.Close() + + serverAddr := M.SocksaddrFromNet(serverConn.LocalAddr()) + clientAddr := M.SocksaddrFromNet(clientConn.LocalAddr()) + proxyAAddr := M.SocksaddrFromNet(proxyAConn.LocalAddr()) + proxyBAddr := M.SocksaddrFromNet(proxyBConn.LocalAddr()) + + proxyA := newChannelPacketConn(proxyAConn, 
serverAddr) + proxyB := newChannelPacketConn(proxyBConn, clientAddr) + + copier := NewPacketReactor(context.Background(), nil) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), proxyA, proxyB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + const times = 50 + const chunkSize = 9000 + + pingCh := make(chan batchHashPair, 1) + pongCh := make(chan batchHashPair, 1) + errCh := make(chan error, 2) + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + recvBuf := make([]byte, 65536) + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + _, err := clientConn.WriteTo(data, proxyAAddr.UDPAddr()) + if err != nil { + errCh <- err + return + } + time.Sleep(5 * time.Millisecond) + } + + for i := 0; i < times; i++ { + n, _, err := clientConn.ReadFrom(recvBuf) + if err != nil { + errCh <- err + return + } + hash := md5.Sum(recvBuf[:n]) + recvHash[int(recvBuf[0])] = hash[:] + } + + pingCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + recvBuf := make([]byte, 65536) + + for i := 0; i < times; i++ { + n, _, err := serverConn.ReadFrom(recvBuf) + if err != nil { + errCh <- err + return + } + hash := md5.Sum(recvBuf[:n]) + recvHash[int(recvBuf[0])] = hash[:] + } + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + _, err := serverConn.WriteTo(data, proxyBAddr.UDPAddr()) + if err != nil { + errCh <- err + return + } + time.Sleep(5 * time.Millisecond) + } + + pongCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + var clientPair, serverPair batchHashPair + for i := 0; i < 2; i++ { + select { + case clientPair = <-pingCh: + case serverPair = <-pongCh: + case err := <-errCh: + t.Fatal(err) + case <-time.After(15 * time.Second): + t.Fatal("timeout") + } + } + + assert.Equal(t, clientPair.sendHash, serverPair.recvHash, "client->server mismatch") + assert.Equal(t, serverPair.sendHash, clientPair.recvHash, "server->client mismatch") +} + +func TestBatchCopy_MixedMode_DataIntegrity(t *testing.T) { + t.Parallel() + + clientConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer clientConn.Close() + + proxyAConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + proxyBConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + serverConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer serverConn.Close() + + serverAddr := M.SocksaddrFromNet(serverConn.LocalAddr()) + clientAddr := M.SocksaddrFromNet(clientConn.LocalAddr()) + proxyAAddr := M.SocksaddrFromNet(proxyAConn.LocalAddr()) + proxyBAddr := M.SocksaddrFromNet(proxyBConn.LocalAddr()) + + proxyA := newFDPacketConn(t, proxyAConn, serverAddr) + proxyB := newChannelPacketConn(proxyBConn, clientAddr) + + copier := NewPacketReactor(context.Background(), nil) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), proxyA, proxyB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + const times = 50 + const chunkSize = 9000 + + pingCh := make(chan batchHashPair, 1) + pongCh := make(chan batchHashPair, 1) + errCh := make(chan error, 2) + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + recvBuf := make([]byte, 65536) + + for i 
:= 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + _, err := clientConn.WriteTo(data, proxyAAddr.UDPAddr()) + if err != nil { + errCh <- err + return + } + time.Sleep(5 * time.Millisecond) + } + + for i := 0; i < times; i++ { + n, _, err := clientConn.ReadFrom(recvBuf) + if err != nil { + errCh <- err + return + } + hash := md5.Sum(recvBuf[:n]) + recvHash[int(recvBuf[0])] = hash[:] + } + + pingCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + recvBuf := make([]byte, 65536) + + for i := 0; i < times; i++ { + n, _, err := serverConn.ReadFrom(recvBuf) + if err != nil { + errCh <- err + return + } + hash := md5.Sum(recvBuf[:n]) + recvHash[int(recvBuf[0])] = hash[:] + } + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + _, err := serverConn.WriteTo(data, proxyBAddr.UDPAddr()) + if err != nil { + errCh <- err + return + } + time.Sleep(5 * time.Millisecond) + } + + pongCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + var clientPair, serverPair batchHashPair + for i := 0; i < 2; i++ { + select { + case clientPair = <-pingCh: + case serverPair = <-pongCh: + case err := <-errCh: + t.Fatal(err) + case <-time.After(15 * time.Second): + t.Fatal("timeout") + } + } + + assert.Equal(t, clientPair.sendHash, serverPair.recvHash, "client->server mismatch") + assert.Equal(t, serverPair.sendHash, clientPair.recvHash, "server->client mismatch") +} + +func TestBatchCopy_MultipleConnections_DataIntegrity(t *testing.T) { + t.Parallel() + + const numConnections = 5 + + copier := NewPacketReactor(context.Background(), nil) + defer copier.Close() + + var wg sync.WaitGroup + errCh := make(chan error, numConnections) + + for i := 0; i < numConnections; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", uint16(20000+idx*2)) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", uint16(20001+idx*2)) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + const times = 20 + const chunkSize = 1000 + + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + + recvDone := make(chan struct{}) + go func() { + defer close(recvDone) + for i := 0; i < times; i++ { + packet, ok := pipeB.receive() + if !ok { + return + } + hash := md5.Sum(packet.Buffer.Bytes()) + recvHash[int(packet.Buffer.Byte(0))] = hash[:] + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + }() + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + pipeA.send(data, addr2) + time.Sleep(5 * time.Millisecond) + } + + select { + case <-recvDone: + case <-time.After(10 * time.Second): + errCh <- errors.New("timeout") + return + } + + for k, v := range sendHash { + if rv, ok := recvHash[k]; !ok || string(v) != string(rv) { + errCh <- errors.New("data mismatch") + return + } + } + }(i) + } + + wg.Wait() + close(errCh) + + for err := range errCh { + require.NoError(t, err) + } +} + +func TestBatchCopy_TimeoutAndResume_DataIntegrity(t *testing.T) { + t.Parallel() + + addr1 := 
M.ParseSocksaddrHostPort("127.0.0.1", 30001) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 30002) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + copier := NewPacketReactor(context.Background(), nil) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + sendAndVerify := func(batchID int, count int) { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + + recvDone := make(chan struct{}) + go func() { + defer close(recvDone) + for i := 0; i < count; i++ { + packet, ok := pipeB.receive() + if !ok { + return + } + hash := md5.Sum(packet.Buffer.Bytes()) + recvHash[int(packet.Buffer.Byte(1))] = hash[:] + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + }() + + for i := 0; i < count; i++ { + data := make([]byte, 1000) + rand.Read(data[2:]) + data[0] = byte(batchID) + data[1] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + pipeA.send(data, addr2) + time.Sleep(5 * time.Millisecond) + } + + select { + case <-recvDone: + case <-time.After(5 * time.Second): + t.Fatalf("batch %d timeout", batchID) + } + + assert.Equal(t, sendHash, recvHash, "batch %d mismatch", batchID) + } + + sendAndVerify(1, 10) + + time.Sleep(350 * time.Millisecond) + + sendAndVerify(2, 10) +} + +func TestBatchCopy_CloseWhileTransferring(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 40001) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 40002) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + + copier := NewPacketReactor(context.Background(), nil) + + copyDone := make(chan struct{}) + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + close(copyDone) + }() + + time.Sleep(50 * time.Millisecond) + + stopSend := make(chan struct{}) + go func() { + for { + select { + case <-stopSend: + return + default: + data := make([]byte, 1000) + rand.Read(data) + pipeA.send(data, addr2) + time.Sleep(1 * time.Millisecond) + } + } + }() + + time.Sleep(100 * time.Millisecond) + + pipeA.Close() + pipeB.Close() + copier.Close() + close(stopSend) + + select { + case <-copyDone: + case <-time.After(5 * time.Second): + t.Fatal("copier did not close - possible deadlock") + } +} + +func TestBatchCopy_HighThroughput(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 50001) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 50002) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + copier := NewPacketReactor(context.Background(), nil) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + const times = 500 + const chunkSize = 8000 + + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + var mu sync.Mutex + + recvDone := make(chan struct{}) + go func() { + defer close(recvDone) + for i := 0; i < times; i++ { + packet, ok := pipeB.receive() + if !ok { + t.Logf("recv channel closed at %d", i) + return + } + hash := md5.Sum(packet.Buffer.Bytes()) + idx := int(packet.Buffer.Byte(0))<<8 | int(packet.Buffer.Byte(1)) + mu.Lock() + recvHash[idx] = hash[:] + mu.Unlock() + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + }() + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[2:]) + data[0] = byte(i >> 8) + data[1] = byte(i & 0xff) + hash := 
md5.Sum(data) + sendHash[i] = hash[:] + pipeA.send(data, addr2) + time.Sleep(1 * time.Millisecond) + } + + select { + case <-recvDone: + case <-time.After(30 * time.Second): + t.Fatal("high throughput test timeout") + } + + mu.Lock() + defer mu.Unlock() + assert.Equal(t, len(sendHash), len(recvHash), "packet count mismatch") + for k, v := range sendHash { + assert.Equal(t, v, recvHash[k], "packet %d mismatch", k) + } +} + +func TestBatchCopy_LegacyFallback_DataIntegrity(t *testing.T) { + t.Parallel() + + clientConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer clientConn.Close() + + proxyAConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + proxyBConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + serverConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer serverConn.Close() + + serverAddr := M.SocksaddrFromNet(serverConn.LocalAddr()) + clientAddr := M.SocksaddrFromNet(clientConn.LocalAddr()) + proxyAAddr := M.SocksaddrFromNet(proxyAConn.LocalAddr()) + proxyBAddr := M.SocksaddrFromNet(proxyBConn.LocalAddr()) + + proxyA := &legacyPacketConn{NetPacketConn: NewPacketConn(proxyAConn), targetAddr: serverAddr} + proxyB := &legacyPacketConn{NetPacketConn: NewPacketConn(proxyBConn), targetAddr: clientAddr} + + copier := NewPacketReactor(context.Background(), nil) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), proxyA, proxyB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + const times = 50 + const chunkSize = 9000 + + pingCh := make(chan batchHashPair, 1) + pongCh := make(chan batchHashPair, 1) + errCh := make(chan error, 2) + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + recvBuf := make([]byte, 65536) + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + _, err := clientConn.WriteTo(data, proxyAAddr.UDPAddr()) + if err != nil { + errCh <- err + return + } + time.Sleep(5 * time.Millisecond) + } + + clientConn.SetReadDeadline(time.Now().Add(10 * time.Second)) + for i := 0; i < times; i++ { + n, _, err := clientConn.ReadFrom(recvBuf) + if err != nil { + if os.IsTimeout(err) { + t.Logf("client read timeout after %d packets", i) + } + errCh <- err + return + } + hash := md5.Sum(recvBuf[:n]) + recvHash[int(recvBuf[0])] = hash[:] + } + + pingCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + go func() { + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + recvBuf := make([]byte, 65536) + + serverConn.SetReadDeadline(time.Now().Add(10 * time.Second)) + for i := 0; i < times; i++ { + n, _, err := serverConn.ReadFrom(recvBuf) + if err != nil { + if os.IsTimeout(err) { + t.Logf("server read timeout after %d packets", i) + } + errCh <- err + return + } + hash := md5.Sum(recvBuf[:n]) + recvHash[int(recvBuf[0])] = hash[:] + } + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + _, err := serverConn.WriteTo(data, proxyBAddr.UDPAddr()) + if err != nil { + errCh <- err + return + } + time.Sleep(5 * time.Millisecond) + } + + pongCh <- batchHashPair{sendHash: sendHash, recvHash: recvHash} + }() + + var clientPair, serverPair batchHashPair + for i := 0; i < 2; i++ { + select { + case clientPair = <-pingCh: + case serverPair = <-pongCh: + case err := <-errCh: + t.Fatal(err) + 
case <-time.After(20 * time.Second): + t.Fatal("timeout") + } + } + + assert.Equal(t, clientPair.sendHash, serverPair.recvHash, "client->server mismatch") + assert.Equal(t, serverPair.sendHash, clientPair.recvHash, "server->client mismatch") +} + +func TestBatchCopy_ReactorClose(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 60001) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 60002) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + copier := NewPacketReactor(context.Background(), nil) + + copyDone := make(chan struct{}) + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + close(copyDone) + }() + + time.Sleep(50 * time.Millisecond) + + go func() { + for { + select { + case <-copyDone: + return + default: + data := make([]byte, 100) + rand.Read(data) + pipeA.send(data, addr2) + time.Sleep(10 * time.Millisecond) + } + } + }() + + time.Sleep(100 * time.Millisecond) + + pipeA.Close() + pipeB.Close() + copier.Close() + + select { + case <-copyDone: + case <-time.After(5 * time.Second): + t.Fatal("Copy did not return after reactor close") + } +} + +func TestBatchCopy_SmallPackets(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 60011) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 60012) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + copier := NewPacketReactor(context.Background(), nil) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + const totalPackets = 20 + receivedCount := 0 + + recvDone := make(chan struct{}) + go func() { + defer close(recvDone) + for i := 0; i < totalPackets; i++ { + packet, ok := pipeB.receive() + if !ok { + return + } + receivedCount++ + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + }() + + for i := 0; i < totalPackets; i++ { + size := (i % 10) + 1 + data := make([]byte, size) + rand.Read(data) + pipeA.send(data, addr2) + time.Sleep(5 * time.Millisecond) + } + + select { + case <-recvDone: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for packets") + } + + assert.Equal(t, totalPackets, receivedCount) +} + +func TestBatchCopy_VaryingPacketSizes(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 60041) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 60042) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + copier := NewPacketReactor(context.Background(), nil) + defer copier.Close() + + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + }() + + time.Sleep(50 * time.Millisecond) + + sizes := []int{10, 100, 500, 1000, 2000, 4000, 8000} + const times = 3 + + totalPackets := len(sizes) * times + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + + recvDone := make(chan struct{}) + go func() { + defer close(recvDone) + for i := 0; i < totalPackets; i++ { + packet, ok := pipeB.receive() + if !ok { + return + } + idx := int(packet.Buffer.Byte(0))<<8 | int(packet.Buffer.Byte(1)) + hash := md5.Sum(packet.Buffer.Bytes()) + recvHash[idx] = hash[:] + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + }() + + packetIdx := 0 + for _, size := range sizes { + for j := 0; j < times; j++ { + data := make([]byte, size) + rand.Read(data[2:]) + data[0] = byte(packetIdx >> 8) + data[1] = 
byte(packetIdx & 0xff) + hash := md5.Sum(data) + sendHash[packetIdx] = hash[:] + pipeA.send(data, addr2) + packetIdx++ + time.Sleep(5 * time.Millisecond) + } + } + + select { + case <-recvDone: + case <-time.After(10 * time.Second): + t.Fatal("timeout") + } + + assert.Equal(t, len(sendHash), len(recvHash)) + for k, v := range sendHash { + assert.Equal(t, v, recvHash[k], "packet %d mismatch", k) + } +} + +func TestBatchCopy_OnCloseCallback(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 60021) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 60022) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + + copier := NewPacketReactor(context.Background(), nil) + defer copier.Close() + + callbackCalled := make(chan error, 1) + onClose := func(err error) { + select { + case callbackCalled <- err: + default: + } + } + + go func() { + copier.Copy(context.Background(), pipeA, pipeB, onClose) + }() + + time.Sleep(50 * time.Millisecond) + + for i := 0; i < 5; i++ { + data := make([]byte, 100) + rand.Read(data) + pipeA.send(data, addr2) + } + + time.Sleep(50 * time.Millisecond) + + pipeA.Close() + pipeB.Close() + + select { + case <-callbackCalled: + case <-time.After(5 * time.Second): + t.Fatal("onClose callback was not called") + } +} + +func TestBatchCopy_SourceClose(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 60031) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 60032) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + + copier := NewPacketReactor(context.Background(), nil) + defer copier.Close() + + var capturedErr error + var errMu sync.Mutex + callbackCalled := make(chan struct{}) + onClose := func(err error) { + errMu.Lock() + capturedErr = err + errMu.Unlock() + close(callbackCalled) + } + + copyDone := make(chan struct{}) + go func() { + copier.Copy(context.Background(), pipeA, pipeB, onClose) + close(copyDone) + }() + + time.Sleep(50 * time.Millisecond) + + for i := 0; i < 5; i++ { + data := make([]byte, 100) + rand.Read(data) + pipeA.send(data, addr2) + } + + time.Sleep(50 * time.Millisecond) + + pipeA.Close() + close(pipeA.inChan) + + select { + case <-callbackCalled: + case <-time.After(5 * time.Second): + pipeB.Close() + t.Fatal("onClose callback was not called after source close") + } + + select { + case <-copyDone: + case <-time.After(5 * time.Second): + t.Fatal("Copy did not return after source close") + } + + pipeB.Close() + + errMu.Lock() + err := capturedErr + errMu.Unlock() + + require.NotNil(t, err) +} + +type legacyPacketConn struct { + N.NetPacketConn + targetAddr M.Socksaddr +} + +func (c *legacyPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + _, err = c.NetPacketConn.ReadPacket(buffer) + if err != nil { + return M.Socksaddr{}, err + } + return c.targetAddr, nil +} + +// pushablePacketPipe implements PacketPushable for testing with deadline support +type pushablePacketPipe struct { + inChan chan *N.PacketBuffer + outChan chan *N.PacketBuffer + localAddr M.Socksaddr + closed atomic.Bool + closeOnce sync.Once + done chan struct{} + onDataReady func() + mutex sync.Mutex + deadlineLock sync.Mutex + deadline time.Time + deadlineChan chan struct{} +} + +func newPushablePacketPipe(localAddr M.Socksaddr) *pushablePacketPipe { + return &pushablePacketPipe{ + inChan: make(chan *N.PacketBuffer, 256), + outChan: make(chan *N.PacketBuffer, 256), + localAddr: localAddr, + done: make(chan struct{}), + deadlineChan: make(chan struct{}), + 
} +} + +func (p *pushablePacketPipe) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + p.deadlineLock.Lock() + deadline := p.deadline + deadlineChan := p.deadlineChan + p.deadlineLock.Unlock() + + var timer <-chan time.Time + if !deadline.IsZero() { + d := time.Until(deadline) + if d <= 0 { + return M.Socksaddr{}, os.ErrDeadlineExceeded + } + t := time.NewTimer(d) + defer t.Stop() + timer = t.C + } + + select { + case packet, ok := <-p.inChan: + if !ok { + return M.Socksaddr{}, io.EOF + } + _, err = buffer.ReadOnceFrom(packet.Buffer) + destination = packet.Destination + packet.Buffer.Release() + N.PutPacketBuffer(packet) + return destination, err + case <-p.done: + return M.Socksaddr{}, net.ErrClosed + case <-deadlineChan: + return M.Socksaddr{}, os.ErrDeadlineExceeded + case <-timer: + return M.Socksaddr{}, os.ErrDeadlineExceeded + } +} + +func (p *pushablePacketPipe) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + if p.closed.Load() { + buffer.Release() + return net.ErrClosed + } + packet := N.NewPacketBuffer() + newBuf := buf.NewSize(buffer.Len()) + newBuf.Write(buffer.Bytes()) + packet.Buffer = newBuf + packet.Destination = destination + buffer.Release() + select { + case p.outChan <- packet: + return nil + case <-p.done: + packet.Buffer.Release() + N.PutPacketBuffer(packet) + return net.ErrClosed + } +} + +func (p *pushablePacketPipe) Close() error { + p.closeOnce.Do(func() { + p.closed.Store(true) + close(p.done) + }) + return nil +} + +func (p *pushablePacketPipe) LocalAddr() net.Addr { + return p.localAddr.UDPAddr() +} + +func (p *pushablePacketPipe) SetDeadline(t time.Time) error { + return p.SetReadDeadline(t) +} + +func (p *pushablePacketPipe) SetReadDeadline(t time.Time) error { + p.deadlineLock.Lock() + p.deadline = t + if p.deadlineChan != nil { + close(p.deadlineChan) + } + p.deadlineChan = make(chan struct{}) + p.deadlineLock.Unlock() + return nil +} + +func (p *pushablePacketPipe) SetWriteDeadline(t time.Time) error { + return nil +} + +func (p *pushablePacketPipe) SetOnDataReady(callback func()) { + p.mutex.Lock() + p.onDataReady = callback + p.mutex.Unlock() +} + +func (p *pushablePacketPipe) HasPendingData() bool { + return len(p.inChan) > 0 +} + +func (p *pushablePacketPipe) triggerDataReady() { + p.mutex.Lock() + callback := p.onDataReady + p.mutex.Unlock() + if callback != nil { + callback() + } +} + +func (p *pushablePacketPipe) send(data []byte, destination M.Socksaddr) { + packet := N.NewPacketBuffer() + newBuf := buf.NewSize(len(data)) + newBuf.Write(data) + packet.Buffer = newBuf + packet.Destination = destination + p.inChan <- packet +} + +func (p *pushablePacketPipe) sendWithNotify(data []byte, destination M.Socksaddr) { + p.send(data, destination) + p.triggerDataReady() +} + +func (p *pushablePacketPipe) receive() (*N.PacketBuffer, bool) { + select { + case packet, ok := <-p.outChan: + return packet, ok + case <-p.done: + return nil, false + } +} + +// failingPacketWriter fails after N writes +type failingPacketWriter struct { + N.PacketConn + failAfter int + writeCount atomic.Int32 +} + +func (w *failingPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + if w.writeCount.Add(1) > int32(w.failAfter) { + buffer.Release() + return errors.New("simulated packet write error") + } + return w.PacketConn.WritePacket(buffer, destination) +} + +// errorPacketReader returns error on ReadPacket +type errorPacketReader struct { + N.PacketConn + readError error +} + +func (r *errorPacketReader) 
ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + return M.Socksaddr{}, r.readError +} + +func TestPacketReactor_Pushable_Basic(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 61001) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 61002) + + pipeA := newPushablePacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + copier := NewPacketReactor(context.Background(), nil) + defer copier.Close() + + copyDone := make(chan struct{}) + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + close(copyDone) + }() + + time.Sleep(50 * time.Millisecond) + + const times = 20 + const chunkSize = 1000 + + sendHash := make(map[int][]byte) + recvHash := make(map[int][]byte) + + recvDone := make(chan struct{}) + go func() { + defer close(recvDone) + for i := 0; i < times; i++ { + packet, ok := pipeB.receive() + if !ok { + return + } + hash := md5.Sum(packet.Buffer.Bytes()) + recvHash[int(packet.Buffer.Byte(0))] = hash[:] + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + }() + + for i := 0; i < times; i++ { + data := make([]byte, chunkSize) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + sendHash[i] = hash[:] + pipeA.sendWithNotify(data, addr2) + time.Sleep(5 * time.Millisecond) + } + + select { + case <-recvDone: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for receive") + } + + assert.Equal(t, sendHash, recvHash, "data mismatch") + + pipeA.Close() + pipeB.Close() + + select { + case <-copyDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for copy to complete") + } +} + +func TestPacketReactor_Pushable_HasPendingData(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 61011) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 61012) + + pipeA := newPushablePacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + // Pre-load data before starting the reactor + const preloadCount = 5 + preloadHashes := make([][]byte, preloadCount) + for i := 0; i < preloadCount; i++ { + data := make([]byte, 100) + rand.Read(data[1:]) + data[0] = byte(i) + hash := md5.Sum(data) + preloadHashes[i] = hash[:] + pipeA.send(data, addr2) + } + + copier := NewPacketReactor(context.Background(), nil) + defer copier.Close() + + copyDone := make(chan struct{}) + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + close(copyDone) + }() + + // Wait for the reactor to detect pending data and process it + receivedHashes := make([][]byte, 0, preloadCount) + timeout := time.After(5 * time.Second) + + for i := 0; i < preloadCount; i++ { + select { + case packet := <-pipeB.outChan: + hash := md5.Sum(packet.Buffer.Bytes()) + receivedHashes = append(receivedHashes, hash[:]) + packet.Buffer.Release() + N.PutPacketBuffer(packet) + case <-timeout: + t.Fatalf("timeout: only received %d/%d packets", len(receivedHashes), preloadCount) + } + } + + // Verify all preloaded packets were received + assert.Equal(t, len(preloadHashes), len(receivedHashes), "should receive all preloaded packets") + + pipeA.Close() + pipeB.Close() + + select { + case <-copyDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for copy to complete") + } +} + +func TestPacketReactor_Pushable_TimeoutResume(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 61021) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 61022) + + pipeA := newPushablePacketPipe(addr1) + 
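// pipeA implements N.PacketPushable, so this test is expected to drive the reactor through the SetOnDataReady callback (sendWithNotify below) rather than a blocking read loop. +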
pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + copier := NewPacketReactor(context.Background(), nil) + defer copier.Close() + + copyDone := make(chan struct{}) + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + close(copyDone) + }() + + time.Sleep(50 * time.Millisecond) + + // Send first batch + const batchSize = 5 + for i := 0; i < batchSize; i++ { + data := make([]byte, 100) + rand.Read(data) + pipeA.sendWithNotify(data, addr2) + } + + // Receive first batch + for i := 0; i < batchSize; i++ { + select { + case packet := <-pipeB.outChan: + packet.Buffer.Release() + N.PutPacketBuffer(packet) + case <-time.After(5 * time.Second): + t.Fatalf("timeout receiving first batch packet %d", i) + } + } + + // Wait longer than the batch timeout (250ms) to trigger return to idle + time.Sleep(400 * time.Millisecond) + + // Send second batch - should trigger data ready callback + for i := 0; i < batchSize; i++ { + data := make([]byte, 100) + rand.Read(data) + pipeA.sendWithNotify(data, addr2) + } + + // Receive second batch + for i := 0; i < batchSize; i++ { + select { + case packet := <-pipeB.outChan: + packet.Buffer.Release() + N.PutPacketBuffer(packet) + case <-time.After(5 * time.Second): + t.Fatalf("timeout receiving second batch packet %d", i) + } + } + + pipeA.Close() + pipeB.Close() + + select { + case <-copyDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for copy to complete") + } +} + +func TestPacketReactor_WriteError(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 61031) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 61032) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + // Wrap destination with failing writer that fails after 3 packets + failingDest := &failingPacketWriter{PacketConn: pipeB, failAfter: 3} + + copier := NewPacketReactor(context.Background(), nil) + defer copier.Close() + + var capturedErr error + var errMu sync.Mutex + closeDone := make(chan struct{}) + + go func() { + copier.Copy(context.Background(), pipeA, failingDest, func(err error) { + errMu.Lock() + capturedErr = err + errMu.Unlock() + close(closeDone) + }) + }() + + time.Sleep(50 * time.Millisecond) + + // Send packets until failure + for i := 0; i < 10; i++ { + data := make([]byte, 100) + rand.Read(data) + pipeA.send(data, addr2) + time.Sleep(10 * time.Millisecond) + } + + // Wait for close callback + select { + case <-closeDone: + errMu.Lock() + err := capturedErr + errMu.Unlock() + assert.NotNil(t, err, "expected error to be propagated") + assert.Contains(t, err.Error(), "simulated packet write error") + case <-time.After(5 * time.Second): + pipeA.Close() + pipeB.Close() + t.Fatal("timeout waiting for close callback") + } +} + +func TestPacketReactor_ReadError(t *testing.T) { + t.Parallel() + + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 61042) + + pipeA := newTestPacketPipe(M.ParseSocksaddrHostPort("127.0.0.1", 61041)) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + // Wrap source with error-returning reader + readErr := errors.New("simulated packet read error") + errorSrc := &errorPacketReader{PacketConn: pipeA, readError: readErr} + + copier := NewPacketReactor(context.Background(), nil) + defer copier.Close() + + var capturedErr error + var errMu sync.Mutex + closeDone := make(chan struct{}) + + go func() { + copier.Copy(context.Background(), errorSrc, pipeB, func(err error) { + 
errMu.Lock() + capturedErr = err + errMu.Unlock() + close(closeDone) + }) + }() + + select { + case <-closeDone: + errMu.Lock() + err := capturedErr + errMu.Unlock() + assert.NotNil(t, err, "expected error to be propagated") + assert.Contains(t, err.Error(), "simulated packet read error") + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for close callback") + } +} + +func TestPacketReactor_StateMachine_ConcurrentWakeup(t *testing.T) { + t.Parallel() + + clientConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer clientConn.Close() + + proxyAConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + proxyBConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + + serverConn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer serverConn.Close() + + serverAddr := M.SocksaddrFromNet(serverConn.LocalAddr()) + clientAddr := M.SocksaddrFromNet(clientConn.LocalAddr()) + proxyAAddr := M.SocksaddrFromNet(proxyAConn.LocalAddr()) + + proxyA := newFDPacketConn(t, proxyAConn, serverAddr) + proxyB := newFDPacketConn(t, proxyBConn, clientAddr) + + copier := NewPacketReactor(context.Background(), nil) + defer copier.Close() + + copyDone := make(chan struct{}) + go func() { + copier.Copy(context.Background(), proxyA, proxyB, nil) + close(copyDone) + }() + + time.Sleep(50 * time.Millisecond) + + // Send many packets rapidly to stress the state machine + const numPackets = 100 + var wg sync.WaitGroup + + // Multiple goroutines sending packets concurrently + for g := 0; g < 5; g++ { + wg.Add(1) + go func(goroutineID int) { + defer wg.Done() + for i := 0; i < numPackets/5; i++ { + data := make([]byte, 100) + rand.Read(data) + clientConn.WriteTo(data, proxyAAddr.UDPAddr()) + } + }(g) + } + + wg.Wait() + + // Give time for packets to be processed + time.Sleep(500 * time.Millisecond) + + // Cleanup + proxyAConn.Close() + proxyBConn.Close() + + select { + case <-copyDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for copy to complete") + } +} + +func TestPacketReactor_StateMachine_CloseWhileActive(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 61051) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 61052) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + + copier := NewPacketReactor(context.Background(), nil) + + copyDone := make(chan struct{}) + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + close(copyDone) + }() + + time.Sleep(50 * time.Millisecond) + + // Start sending packets continuously + stopSend := make(chan struct{}) + go func() { + for { + select { + case <-stopSend: + return + default: + data := make([]byte, 100) + rand.Read(data) + pipeA.send(data, addr2) + time.Sleep(1 * time.Millisecond) + } + } + }() + + // Give time for active processing + time.Sleep(100 * time.Millisecond) + + // Close while actively processing + pipeA.Close() + pipeB.Close() + copier.Close() + close(stopSend) + + // Verify no deadlock + select { + case <-copyDone: + case <-time.After(5 * time.Second): + t.Fatal("deadlock detected: Copy did not return after close") + } +} + +func TestPacketReactor_Counters(t *testing.T) { + t.Parallel() + + addr1 := M.ParseSocksaddrHostPort("127.0.0.1", 61061) + addr2 := M.ParseSocksaddrHostPort("127.0.0.1", 61062) + + pipeA := newTestPacketPipe(addr1) + pipeB := newTestPacketPipe(addr2) + defer pipeA.Close() + defer pipeB.Close() + + copier := 
NewPacketReactor(context.Background(), nil) + defer copier.Close() + + copyDone := make(chan struct{}) + go func() { + copier.Copy(context.Background(), pipeA, pipeB, nil) + close(copyDone) + }() + + time.Sleep(50 * time.Millisecond) + + // Send packets with known sizes + const numPackets = 10 + packetSizes := []int{100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} + totalBytes := 0 + for _, size := range packetSizes { + totalBytes += size + } + + recvDone := make(chan struct{}) + receivedBytes := 0 + go func() { + defer close(recvDone) + for i := 0; i < numPackets; i++ { + packet, ok := pipeB.receive() + if !ok { + return + } + receivedBytes += packet.Buffer.Len() + packet.Buffer.Release() + N.PutPacketBuffer(packet) + } + }() + + for i := 0; i < numPackets; i++ { + data := make([]byte, packetSizes[i]) + rand.Read(data) + pipeA.send(data, addr2) + time.Sleep(5 * time.Millisecond) + } + + select { + case <-recvDone: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for receive") + } + + assert.Equal(t, totalBytes, receivedBytes, "total bytes received should match sent") + + pipeA.Close() + pipeB.Close() + + select { + case <-copyDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for copy to complete") + } +} diff --git a/common/bufio/stream_reactor.go b/common/bufio/stream_reactor.go new file mode 100644 index 00000000..da61e684 --- /dev/null +++ b/common/bufio/stream_reactor.go @@ -0,0 +1,490 @@ +package bufio + +import ( + "context" + "io" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + N "github.com/sagernet/sing/common/network" +) + +const ( + streamBatchReadTimeout = 5*time.Second +) + +func CreateStreamPollable(reader io.Reader) (N.StreamPollable, bool) { + if pollable, ok := reader.(N.StreamPollable); ok { + return pollable, true + } + if creator, ok := reader.(N.StreamPollableCreator); ok { + return creator.CreateStreamPollable() + } + if u, ok := reader.(N.ReaderWithUpstream); !ok || !u.ReaderReplaceable() { + return nil, false + } + if u, ok := reader.(N.WithUpstreamReader); ok { + return CreateStreamPollable(u.UpstreamReader().(io.Reader)) + } + if u, ok := reader.(common.WithUpstream); ok { + return CreateStreamPollable(u.Upstream().(io.Reader)) + } + return nil, false +} + +type StreamReactor struct { + ctx context.Context + cancel context.CancelFunc + logger logger.Logger + fdPoller *FDPoller + fdPollerOnce sync.Once + fdPollerErr error +} + +func NewStreamReactor(ctx context.Context, l logger.Logger) *StreamReactor { + ctx, cancel := context.WithCancel(ctx) + if l == nil { + l = logger.NOP() + } + return &StreamReactor{ + ctx: ctx, + cancel: cancel, + logger: l, + } +} + +func (r *StreamReactor) getFDPoller() (*FDPoller, error) { + r.fdPollerOnce.Do(func() { + r.fdPoller, r.fdPollerErr = NewFDPoller(r.ctx) + }) + return r.fdPoller, r.fdPollerErr +} + +func (r *StreamReactor) Close() error { + r.cancel() + if r.fdPoller != nil { + return r.fdPoller.Close() + } + return nil +} + +type streamConnection struct { + ctx context.Context + cancel context.CancelFunc + reactor *StreamReactor + onClose N.CloseHandlerFunc + upload *streamDirection + download *streamDirection + stopReactorWatch func() bool + + closeOnce sync.Once + done chan struct{} + err error +} + +type deadlineSetter interface { + SetReadDeadline(time.Time) error +} + +type streamDirection struct { + connection 
*streamConnection + + source io.Reader + destination io.Writer + originSource net.Conn + + pollable N.StreamPollable + deadlineSetter deadlineSetter + options N.ReadWaitOptions + readWaiter N.ReadWaiter + readCounters []N.CountFunc + writeCounters []N.CountFunc + + isUpload bool + state atomic.Int32 +} + +// Copy initiates bidirectional TCP copy with reactor optimization +// It uses splice when available for zero-copy, otherwise falls back to reactor mode +func (r *StreamReactor) Copy(ctx context.Context, source net.Conn, destination net.Conn, onClose N.CloseHandlerFunc) { + // Try splice first (zero-copy optimization) + if r.trySplice(ctx, source, destination, onClose) { + r.logger.Trace("stream copy: using splice for zero-copy") + return + } + + // Fall back to reactor mode + r.logger.Trace("stream copy: using reactor mode") + ctx, cancel := context.WithCancel(ctx) + conn := &streamConnection{ + ctx: ctx, + cancel: cancel, + reactor: r, + onClose: onClose, + done: make(chan struct{}), + } + conn.stopReactorWatch = common.ContextAfterFunc(r.ctx, func() { + conn.closeWithError(r.ctx.Err()) + }) + + conn.upload = r.prepareDirection(conn, source, destination, source, true) + select { + case <-conn.done: + return + default: + } + + conn.download = r.prepareDirection(conn, destination, source, destination, false) + select { + case <-conn.done: + return + default: + } + + r.registerDirection(conn.upload) + r.registerDirection(conn.download) +} + +func (r *StreamReactor) trySplice(ctx context.Context, source net.Conn, destination net.Conn, onClose N.CloseHandlerFunc) bool { + if !N.SyscallAvailableForRead(source) || !N.SyscallAvailableForWrite(destination) { + return false + } + + // Both ends support syscall, use traditional copy with splice + go func() { + err := CopyConn(ctx, source, destination) + if onClose != nil { + onClose(err) + } + }() + return true +} + +func (r *StreamReactor) prepareDirection(conn *streamConnection, source io.Reader, destination io.Writer, originConn net.Conn, isUpload bool) *streamDirection { + direction := &streamDirection{ + connection: conn, + source: source, + destination: destination, + originSource: originConn, + isUpload: isUpload, + } + + // Unwrap counters and handle cached data + for { + source, direction.readCounters = N.UnwrapCountReader(source, direction.readCounters) + destination, direction.writeCounters = N.UnwrapCountWriter(destination, direction.writeCounters) + if cachedReader, isCached := source.(N.CachedReader); isCached { + cachedBuffer := cachedReader.ReadCached() + if cachedBuffer != nil { + dataLen := cachedBuffer.Len() + _, err := destination.Write(cachedBuffer.Bytes()) + cachedBuffer.Release() + if err != nil { + conn.closeWithError(err) + return direction + } + for _, counter := range direction.readCounters { + counter(int64(dataLen)) + } + for _, counter := range direction.writeCounters { + counter(int64(dataLen)) + } + continue + } + } + break + } + direction.source = source + direction.destination = destination + + direction.options = N.NewReadWaitOptions(source, destination) + + direction.readWaiter, _ = CreateReadWaiter(source) + if direction.readWaiter != nil { + needCopy := direction.readWaiter.InitializeReadWaiter(direction.options) + if needCopy { + direction.readWaiter = nil + } + } + + // Try to get stream pollable for FD-based idle detection + direction.pollable, _ = CreateStreamPollable(source) + + return direction +} + +func (r *StreamReactor) registerDirection(direction *streamDirection) { + if 
needAdditionalReadDeadline(direction.source) { + r.logger.Trace("stream direction: needs additional deadline handling, using legacy copy") + go direction.runLegacyCopy() + return + } + + // Check if deadline setter is available and functional + if setter, ok := direction.originSource.(deadlineSetter); ok { + err := setter.SetReadDeadline(time.Time{}) + if err != nil { + r.logger.Trace("stream direction: SetReadDeadline not supported, using legacy copy") + go direction.runLegacyCopy() + return + } + direction.deadlineSetter = setter + } else { + r.logger.Trace("stream direction: no deadline setter, using legacy copy") + go direction.runLegacyCopy() + return + } + + // Check if there's buffered data that needs processing first + if direction.pollable != nil && direction.pollable.Buffered() > 0 { + r.logger.Trace("stream direction: has buffered data, starting active loop") + go direction.runActiveLoop() + return + } + + // Try to register with FD poller + if direction.pollable != nil { + fdPoller, err := r.getFDPoller() + if err == nil { + err = fdPoller.Add(direction, direction.pollable.FD()) + if err == nil { + r.logger.Trace("stream direction: registered with FD poller") + return + } + } + } + + // Fall back to legacy goroutine copy + r.logger.Trace("stream direction: using legacy copy") + go direction.runLegacyCopy() +} + +func (d *streamDirection) runActiveLoop() { + if d.source == nil { + return + } + if !d.state.CompareAndSwap(stateIdle, stateActive) { + return + } + + notFirstTime := false + + for { + if d.state.Load() == stateClosed { + return + } + + // Set batch read timeout + deadlineErr := d.deadlineSetter.SetReadDeadline(time.Now().Add(streamBatchReadTimeout)) + if deadlineErr != nil { + d.connection.reactor.logger.Trace("stream direction: SetReadDeadline failed, switching to legacy copy") + d.state.Store(stateIdle) + go d.runLegacyCopy() + return + } + + var ( + buffer *buf.Buffer + err error + ) + + if d.readWaiter != nil { + buffer, err = d.readWaiter.WaitReadBuffer() + } else { + buffer = d.options.NewBuffer() + err = NewExtendedReader(d.source).ReadBuffer(buffer) + if err != nil { + buffer.Release() + buffer = nil + } + } + + if err != nil { + if E.IsTimeout(err) { + // Timeout: check buffer and return to pool + d.deadlineSetter.SetReadDeadline(time.Time{}) + if d.state.CompareAndSwap(stateActive, stateIdle) { + d.connection.reactor.logger.Trace("stream direction: timeout, returning to idle pool") + d.returnToPool() + } + return + } + // EOF or error + if !notFirstTime { + err = N.ReportHandshakeFailure(d.originSource, err) + } + d.handleEOFOrError(err) + return + } + + err = d.writeBufferWithCounters(buffer) + if err != nil { + if !notFirstTime { + err = N.ReportHandshakeFailure(d.originSource, err) + } + d.closeWithError(err) + return + } + notFirstTime = true + } +} + +func (d *streamDirection) writeBufferWithCounters(buffer *buf.Buffer) error { + dataLen := buffer.Len() + d.options.PostReturn(buffer) + err := NewExtendedWriter(d.destination).WriteBuffer(buffer) + if err != nil { + buffer.Leak() + return err + } + + for _, counter := range d.readCounters { + counter(int64(dataLen)) + } + for _, counter := range d.writeCounters { + counter(int64(dataLen)) + } + return nil +} + +func (d *streamDirection) returnToPool() { + if d.state.Load() != stateIdle { + return + } + + // Critical: check if there's buffered data before returning to idle + if d.pollable != nil && d.pollable.Buffered() > 0 { + go d.runActiveLoop() + return + } + + // Safe to wait for FD events + if 
d.pollable != nil && d.connection.reactor.fdPoller != nil { + err := d.connection.reactor.fdPoller.Add(d, d.pollable.FD()) + if err != nil { + d.closeWithError(err) + return + } + if d.state.Load() != stateIdle { + d.connection.reactor.fdPoller.Remove(d.pollable.FD()) + } + } +} + +func (d *streamDirection) HandleFDEvent() { + d.runActiveLoop() +} + +func (d *streamDirection) runLegacyCopy() { + _, err := CopyWithCounters(d.destination, d.source, d.originSource, d.readCounters, d.writeCounters, DefaultIncreaseBufferAfter, DefaultBatchSize) + d.handleEOFOrError(err) +} + +func (d *streamDirection) handleEOFOrError(err error) { + if err == nil || err == io.EOF { + // Graceful EOF: close write direction only (half-close) + d.connection.reactor.logger.Trace("stream direction: graceful EOF, half-closing") + d.state.Store(stateClosed) + + // Try half-close on destination + if d.isUpload { + if d.connection.download != nil { + N.CloseWrite(d.connection.download.originSource) + } + } else { + if d.connection.upload != nil { + N.CloseWrite(d.connection.upload.originSource) + } + } + + d.connection.checkBothClosed() + return + } + + // Error: close entire connection + d.connection.reactor.logger.Trace("stream direction: error occurred: ", err) + d.closeWithError(err) +} + +func (d *streamDirection) closeWithError(err error) { + d.connection.closeWithError(err) +} + +func (c *streamConnection) checkBothClosed() { + uploadClosed := c.upload != nil && c.upload.state.Load() == stateClosed + downloadClosed := c.download != nil && c.download.state.Load() == stateClosed + + if uploadClosed && downloadClosed { + c.closeOnce.Do(func() { + defer close(c.done) + c.reactor.logger.Trace("stream connection: both directions closed gracefully") + + if c.stopReactorWatch != nil { + c.stopReactorWatch() + } + + c.cancel() + c.removeFromPoller() + + common.Close(c.upload.originSource) + common.Close(c.download.originSource) + + if c.onClose != nil { + c.onClose(nil) + } + }) + } +} + +func (c *streamConnection) closeWithError(err error) { + c.closeOnce.Do(func() { + defer close(c.done) + c.reactor.logger.Trace("stream connection: closing with error: ", err) + + if c.stopReactorWatch != nil { + c.stopReactorWatch() + } + + c.err = err + c.cancel() + + if c.upload != nil { + c.upload.state.Store(stateClosed) + } + if c.download != nil { + c.download.state.Store(stateClosed) + } + + c.removeFromPoller() + + if c.upload != nil { + common.Close(c.upload.originSource) + } + if c.download != nil { + common.Close(c.download.originSource) + } + + if c.onClose != nil { + c.onClose(c.err) + } + }) +} + +func (c *streamConnection) removeFromPoller() { + if c.reactor.fdPoller == nil { + return + } + + if c.upload != nil && c.upload.pollable != nil { + c.reactor.fdPoller.Remove(c.upload.pollable.FD()) + } + if c.download != nil && c.download.pollable != nil { + c.reactor.fdPoller.Remove(c.download.pollable.FD()) + } +} diff --git a/common/bufio/stream_reactor_test.go b/common/bufio/stream_reactor_test.go new file mode 100644 index 00000000..f19b652c --- /dev/null +++ b/common/bufio/stream_reactor_test.go @@ -0,0 +1,902 @@ +//go:build darwin || linux || windows + +package bufio + +import ( + "context" + "crypto/md5" + "crypto/rand" + "errors" + "io" + "net" + "sync" + "sync/atomic" + "syscall" + "testing" + "time" + + "github.com/sagernet/sing/common/buf" + N "github.com/sagernet/sing/common/network" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// fdConn wraps a net.Conn to implement 
StreamNotifier +type fdConn struct { + net.Conn + fd int +} + +func newFDConn(t *testing.T, conn net.Conn) *fdConn { + syscallConn, ok := conn.(syscall.Conn) + require.True(t, ok, "connection must implement syscall.Conn") + rawConn, err := syscallConn.SyscallConn() + require.NoError(t, err) + var fd int + err = rawConn.Control(func(f uintptr) { fd = int(f) }) + require.NoError(t, err) + return &fdConn{ + Conn: conn, + fd: fd, + } +} + +func (c *fdConn) FD() int { + return c.fd +} + +func (c *fdConn) Buffered() int { + return 0 +} + +// bufferedConn wraps a net.Conn with a buffer for testing StreamNotifier +type bufferedConn struct { + net.Conn + buffer *buf.Buffer + bufferMu sync.Mutex + fd int +} + +func newBufferedConn(t *testing.T, conn net.Conn) *bufferedConn { + bc := &bufferedConn{ + Conn: conn, + buffer: buf.New(), + } + if syscallConn, ok := conn.(syscall.Conn); ok { + rawConn, err := syscallConn.SyscallConn() + if err == nil { + rawConn.Control(func(f uintptr) { bc.fd = int(f) }) + } + } + return bc +} + +func (c *bufferedConn) Read(p []byte) (n int, err error) { + c.bufferMu.Lock() + if c.buffer.Len() > 0 { + n = copy(p, c.buffer.Bytes()) + c.buffer.Advance(n) + c.bufferMu.Unlock() + return n, nil + } + c.bufferMu.Unlock() + return c.Conn.Read(p) +} + +func (c *bufferedConn) FD() int { + return c.fd +} + +func (c *bufferedConn) Buffered() int { + c.bufferMu.Lock() + defer c.bufferMu.Unlock() + return c.buffer.Len() +} + +func (c *bufferedConn) Close() error { + c.buffer.Release() + return c.Conn.Close() +} + +func TestStreamReactor_Basic(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reactor := NewStreamReactor(ctx, nil) + defer reactor.Close() + + // Create a pair of connected TCP connections + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + var serverConn net.Conn + var serverErr error + serverDone := make(chan struct{}) + go func() { + serverConn, serverErr = listener.Accept() + close(serverDone) + }() + + clientConn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer clientConn.Close() + + <-serverDone + require.NoError(t, serverErr) + defer serverConn.Close() + + // Create another pair for the destination + listener2, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener2.Close() + + var destServerConn net.Conn + var destServerErr error + destServerDone := make(chan struct{}) + go func() { + destServerConn, destServerErr = listener2.Accept() + close(destServerDone) + }() + + destClientConn, err := net.Dial("tcp", listener2.Addr().String()) + require.NoError(t, err) + defer destClientConn.Close() + + <-destServerDone + require.NoError(t, destServerErr) + defer destServerConn.Close() + + // Test data transfer + testData := make([]byte, 1024) + rand.Read(testData) + + closeDone := make(chan struct{}) + reactor.Copy(ctx, serverConn, destClientConn, func(err error) { + close(closeDone) + }) + + // Write from client to server, should pass through to dest + _, err = clientConn.Write(testData) + require.NoError(t, err) + + // Read from destServerConn + received := make([]byte, len(testData)) + _, err = io.ReadFull(destServerConn, received) + require.NoError(t, err) + assert.Equal(t, testData, received) + + // Test reverse direction + reverseData := make([]byte, 512) + rand.Read(reverseData) + + _, err = destServerConn.Write(reverseData) + require.NoError(t, err) + + reverseReceived := make([]byte, len(reverseData)) + _, err = io.ReadFull(clientConn, 
reverseReceived) + require.NoError(t, err) + assert.Equal(t, reverseData, reverseReceived) + + // Close and wait + clientConn.Close() + destServerConn.Close() + + select { + case <-closeDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for close callback") + } +} + +func TestStreamReactor_FDNotifier(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reactor := NewStreamReactor(ctx, nil) + defer reactor.Close() + + // Create TCP connection pairs + server1, client1 := createTCPPair(t) + defer server1.Close() + defer client1.Close() + + server2, client2 := createTCPPair(t) + defer server2.Close() + defer client2.Close() + + // Wrap with FD notifier + fdServer1 := newFDConn(t, server1) + fdClient2 := newFDConn(t, client2) + + closeDone := make(chan struct{}) + reactor.Copy(ctx, fdServer1, fdClient2, func(err error) { + close(closeDone) + }) + + // Test data transfer + testData := make([]byte, 2048) + rand.Read(testData) + + _, err := client1.Write(testData) + require.NoError(t, err) + + received := make([]byte, len(testData)) + _, err = io.ReadFull(server2, received) + require.NoError(t, err) + assert.Equal(t, testData, received) + + client1.Close() + server2.Close() + + select { + case <-closeDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for close callback") + } +} + +func TestStreamReactor_BufferedReader(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reactor := NewStreamReactor(ctx, nil) + defer reactor.Close() + + server1, client1 := createTCPPair(t) + defer server1.Close() + defer client1.Close() + + server2, client2 := createTCPPair(t) + defer server2.Close() + defer client2.Close() + + // Use buffered conn + bufferedServer1 := newBufferedConn(t, server1) + defer bufferedServer1.Close() + + closeDone := make(chan struct{}) + reactor.Copy(ctx, bufferedServer1, client2, func(err error) { + close(closeDone) + }) + + // Send data + testData := make([]byte, 1024) + rand.Read(testData) + + _, err := client1.Write(testData) + require.NoError(t, err) + + received := make([]byte, len(testData)) + _, err = io.ReadFull(server2, received) + require.NoError(t, err) + assert.Equal(t, testData, received) + + client1.Close() + server2.Close() + + select { + case <-closeDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for close callback") + } +} + +func TestStreamReactor_HalfClose(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reactor := NewStreamReactor(ctx, nil) + defer reactor.Close() + + server1, client1 := createTCPPair(t) + defer server1.Close() + defer client1.Close() + + server2, client2 := createTCPPair(t) + defer server2.Close() + defer client2.Close() + + closeDone := make(chan struct{}) + reactor.Copy(ctx, server1, client2, func(err error) { + close(closeDone) + }) + + // Send data in one direction + testData := make([]byte, 512) + rand.Read(testData) + + _, err := client1.Write(testData) + require.NoError(t, err) + + received := make([]byte, len(testData)) + _, err = io.ReadFull(server2, received) + require.NoError(t, err) + assert.Equal(t, testData, received) + + // Close client1's write side + if tcpConn, ok := client1.(*net.TCPConn); ok { + tcpConn.CloseWrite() + } else { + client1.Close() + } + + // The other direction should still work for a moment + reverseData := make([]byte, 256) + rand.Read(reverseData) + + _, err = server2.Write(reverseData) + require.NoError(t, err) + + // Eventually both will close + server2.Close() + client1.Close() + + select { + case 
<-closeDone: + // closeErr should be nil for graceful close + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for close callback") + } +} + +func TestStreamReactor_MultipleConnections(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reactor := NewStreamReactor(ctx, nil) + defer reactor.Close() + + const numConnections = 10 + var wg sync.WaitGroup + var completedCount atomic.Int32 + + for i := 0; i < numConnections; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + + server1, client1 := createTCPPair(t) + defer server1.Close() + defer client1.Close() + + server2, client2 := createTCPPair(t) + defer server2.Close() + defer client2.Close() + + closeDone := make(chan struct{}) + reactor.Copy(ctx, server1, client2, func(err error) { + close(closeDone) + }) + + // Send unique data + testData := make([]byte, 256) + rand.Read(testData) + + _, err := client1.Write(testData) + require.NoError(t, err) + + received := make([]byte, len(testData)) + _, err = io.ReadFull(server2, received) + require.NoError(t, err) + assert.Equal(t, testData, received) + + client1.Close() + server2.Close() + + select { + case <-closeDone: + completedCount.Add(1) + case <-time.After(5 * time.Second): + t.Errorf("connection %d: timeout waiting for close callback", idx) + } + }(i) + } + + wg.Wait() + assert.Equal(t, int32(numConnections), completedCount.Load()) +} + +func TestStreamReactor_ReactorClose(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reactor := NewStreamReactor(ctx, nil) + + server1, client1 := createTCPPair(t) + defer server1.Close() + defer client1.Close() + + server2, client2 := createTCPPair(t) + defer server2.Close() + defer client2.Close() + + closeDone := make(chan struct{}) + reactor.Copy(ctx, server1, client2, func(err error) { + close(closeDone) + }) + + // Send some data first + testData := make([]byte, 128) + rand.Read(testData) + + _, err := client1.Write(testData) + require.NoError(t, err) + + received := make([]byte, len(testData)) + _, err = io.ReadFull(server2, received) + require.NoError(t, err) + + // Close the reactor + reactor.Close() + + // Close connections + client1.Close() + server2.Close() + + select { + case <-closeDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for close callback after reactor close") + } +} + +// Helper function to create a connected TCP pair +func createTCPPair(t *testing.T) (net.Conn, net.Conn) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + var serverConn net.Conn + var serverErr error + serverDone := make(chan struct{}) + go func() { + serverConn, serverErr = listener.Accept() + close(serverDone) + }() + + clientConn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + + <-serverDone + require.NoError(t, serverErr) + + return serverConn, clientConn +} + +// failingConn wraps net.Conn and fails writes after N calls +type failingConn struct { + net.Conn + failAfter int + writeCount atomic.Int32 +} + +func (c *failingConn) Write(p []byte) (int, error) { + if c.writeCount.Add(1) > int32(c.failAfter) { + return 0, errors.New("simulated write error") + } + return c.Conn.Write(p) +} + +// errorConn returns error on Read +type errorConn struct { + net.Conn + readError error +} + +func (c *errorConn) Read(p []byte) (int, error) { + return 0, c.readError +} + +// countingConn wraps net.Conn with read/write counters +type countingConn struct { + net.Conn + readCount atomic.Int64 + writeCount atomic.Int64 +} 
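+// UnwrapReader/UnwrapWriter below hand the wrapped conn and CountFunc callbacks to the reactor's counter unwrapping in prepareDirection (N.UnwrapCountReader/UnwrapCountWriter), so byte counters can keep working even when the reactor bypasses this wrapper's Read/Write.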
+ +func (c *countingConn) Read(p []byte) (int, error) { + n, err := c.Conn.Read(p) + c.readCount.Add(int64(n)) + return n, err +} + +func (c *countingConn) Write(p []byte) (int, error) { + n, err := c.Conn.Write(p) + c.writeCount.Add(int64(n)) + return n, err +} + +func (c *countingConn) UnwrapReader() (io.Reader, []N.CountFunc) { + return c.Conn, []N.CountFunc{func(n int64) { c.readCount.Add(n) }} +} + +func (c *countingConn) UnwrapWriter() (io.Writer, []N.CountFunc) { + return c.Conn, []N.CountFunc{func(n int64) { c.writeCount.Add(n) }} +} + +func TestStreamReactor_WriteError(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reactor := NewStreamReactor(ctx, nil) + defer reactor.Close() + + server1, client1 := createTCPPair(t) + defer server1.Close() + defer client1.Close() + + server2, client2 := createTCPPair(t) + defer server2.Close() + defer client2.Close() + + // Wrap destination with failing writer that fails after 3 writes + failingDest := &failingConn{Conn: client2, failAfter: 3} + + var capturedErr error + closeDone := make(chan struct{}) + reactor.Copy(ctx, server1, failingDest, func(err error) { + capturedErr = err + close(closeDone) + }) + + // Send multiple chunks of data to trigger the write failure + testData := make([]byte, 1024) + rand.Read(testData) + + for i := 0; i < 10; i++ { + _, err := client1.Write(testData) + if err != nil { + break + } + time.Sleep(10 * time.Millisecond) + } + + // Close source to trigger cleanup + client1.Close() + server2.Close() + + select { + case <-closeDone: + // Verify error was propagated + assert.NotNil(t, capturedErr, "expected error to be propagated") + assert.Contains(t, capturedErr.Error(), "simulated write error") + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for close callback") + } +} + +func TestStreamReactor_ReadError(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reactor := NewStreamReactor(ctx, nil) + defer reactor.Close() + + server1, client1 := createTCPPair(t) + defer server1.Close() + defer client1.Close() + + server2, client2 := createTCPPair(t) + defer server2.Close() + defer client2.Close() + + // Wrap source with error-returning reader + readErr := errors.New("simulated read error") + errorSrc := &errorConn{Conn: server1, readError: readErr} + + var capturedErr error + closeDone := make(chan struct{}) + reactor.Copy(ctx, errorSrc, client2, func(err error) { + capturedErr = err + close(closeDone) + }) + + select { + case <-closeDone: + // Verify error was propagated + assert.NotNil(t, capturedErr, "expected error to be propagated") + assert.Contains(t, capturedErr.Error(), "simulated read error") + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for close callback") + } +} + +func TestStreamReactor_Counters(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reactor := NewStreamReactor(ctx, nil) + defer reactor.Close() + + server1, client1 := createTCPPair(t) + defer server1.Close() + defer client1.Close() + + server2, client2 := createTCPPair(t) + defer server2.Close() + defer client2.Close() + + // Wrap with counting connections + countingSrc := &countingConn{Conn: server1} + countingDst := &countingConn{Conn: client2} + + closeDone := make(chan struct{}) + reactor.Copy(ctx, countingSrc, countingDst, func(err error) { + close(closeDone) + }) + + // Send data in both directions + const dataSize = 4096 + uploadData := make([]byte, dataSize) + downloadData := make([]byte, dataSize) + rand.Read(uploadData) + rand.Read(downloadData) + + // 
Upload: client1 -> server1 -> client2 -> server2 + _, err := client1.Write(uploadData) + require.NoError(t, err) + + received := make([]byte, dataSize) + _, err = io.ReadFull(server2, received) + require.NoError(t, err) + assert.Equal(t, uploadData, received) + + // Download: server2 -> client2 -> server1 -> client1 + _, err = server2.Write(downloadData) + require.NoError(t, err) + + received2 := make([]byte, dataSize) + _, err = io.ReadFull(client1, received2) + require.NoError(t, err) + assert.Equal(t, downloadData, received2) + + // Close connections + client1.Close() + server2.Close() + + select { + case <-closeDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for close callback") + } + + // Verify counters (read from source, write to destination) + // Note: The countingConn tracks actual reads/writes at the connection level + assert.True(t, countingSrc.readCount.Load() >= dataSize, "source should have read at least %d bytes, got %d", dataSize, countingSrc.readCount.Load()) + assert.True(t, countingDst.writeCount.Load() >= dataSize, "destination should have written at least %d bytes, got %d", dataSize, countingDst.writeCount.Load()) +} + +func TestStreamReactor_CachedReader(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reactor := NewStreamReactor(ctx, nil) + defer reactor.Close() + + server1, client1 := createTCPPair(t) + defer server1.Close() + defer client1.Close() + + server2, client2 := createTCPPair(t) + defer server2.Close() + defer client2.Close() + + // Create cached data + cachedData := make([]byte, 512) + rand.Read(cachedData) + cachedBuffer := buf.As(cachedData) + + // Wrap source with cached conn + cachedSrc := NewCachedConn(server1, cachedBuffer) + defer cachedSrc.Close() + + closeDone := make(chan struct{}) + reactor.Copy(ctx, cachedSrc, client2, func(err error) { + close(closeDone) + }) + + // The cached data should be sent first before any new data + // Read cached data from destination + receivedCached := make([]byte, len(cachedData)) + _, err := io.ReadFull(server2, receivedCached) + require.NoError(t, err) + assert.Equal(t, cachedData, receivedCached, "cached data should be received first") + + // Now send new data through the connection + newData := make([]byte, 256) + rand.Read(newData) + + _, err = client1.Write(newData) + require.NoError(t, err) + + receivedNew := make([]byte, len(newData)) + _, err = io.ReadFull(server2, receivedNew) + require.NoError(t, err) + assert.Equal(t, newData, receivedNew, "new data should be received after cached data") + + // Cleanup + client1.Close() + server2.Close() + + select { + case <-closeDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for close callback") + } +} + +func TestStreamReactor_LargeData(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reactor := NewStreamReactor(ctx, nil) + defer reactor.Close() + + server1, client1 := createTCPPair(t) + defer server1.Close() + defer client1.Close() + + server2, client2 := createTCPPair(t) + defer server2.Close() + defer client2.Close() + + fdServer1 := newFDConn(t, server1) + fdClient2 := newFDConn(t, client2) + + closeDone := make(chan struct{}) + reactor.Copy(ctx, fdServer1, fdClient2, func(err error) { + close(closeDone) + }) + + // Test with 10MB of data + const dataSize = 10 * 1024 * 1024 + uploadData := make([]byte, dataSize) + rand.Read(uploadData) + uploadHash := md5.Sum(uploadData) + + downloadData := make([]byte, dataSize) + rand.Read(downloadData) + downloadHash := md5.Sum(downloadData) + + 
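// Run upload and download concurrently: one writer and one reader goroutine per direction, with errChan sized for all four. +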
var wg sync.WaitGroup + errChan := make(chan error, 4) + + // Upload goroutine + wg.Add(1) + go func() { + defer wg.Done() + _, err := client1.Write(uploadData) + if err != nil { + errChan <- err + } + }() + + // Upload receiver goroutine + wg.Add(1) + go func() { + defer wg.Done() + received := make([]byte, dataSize) + _, err := io.ReadFull(server2, received) + if err != nil { + errChan <- err + return + } + receivedHash := md5.Sum(received) + if receivedHash != uploadHash { + errChan <- errors.New("upload data mismatch") + } + }() + + // Download goroutine + wg.Add(1) + go func() { + defer wg.Done() + _, err := server2.Write(downloadData) + if err != nil { + errChan <- err + } + }() + + // Download receiver goroutine + wg.Add(1) + go func() { + defer wg.Done() + received := make([]byte, dataSize) + _, err := io.ReadFull(client1, received) + if err != nil { + errChan <- err + return + } + receivedHash := md5.Sum(received) + if receivedHash != downloadHash { + errChan <- errors.New("download data mismatch") + } + }() + + // Wait for completion + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case err := <-errChan: + t.Fatalf("transfer error: %v", err) + case <-done: + // Success + case <-time.After(60 * time.Second): + t.Fatal("timeout during large data transfer") + } + + // Cleanup + client1.Close() + server2.Close() + + select { + case <-closeDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for close callback") + } +} + +func TestStreamReactor_BufferedAtRegistration(t *testing.T) { + t.Parallel() + + ctx := context.Background() + reactor := NewStreamReactor(ctx, nil) + defer reactor.Close() + + server1, client1 := createTCPPair(t) + defer server1.Close() + defer client1.Close() + + server2, client2 := createTCPPair(t) + defer server2.Close() + defer client2.Close() + + // Create buffered conn with pre-populated buffer + bufferedServer1 := newBufferedConn(t, server1) + defer bufferedServer1.Close() + + // Pre-populate the buffer with data + preBufferedData := []byte("pre-buffered data that should be sent immediately") + bufferedServer1.bufferMu.Lock() + bufferedServer1.buffer.Write(preBufferedData) + bufferedServer1.bufferMu.Unlock() + + closeDone := make(chan struct{}) + reactor.Copy(ctx, bufferedServer1, client2, func(err error) { + close(closeDone) + }) + + // The pre-buffered data should be processed immediately + received := make([]byte, len(preBufferedData)) + server2.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, err := io.ReadFull(server2, received) + require.NoError(t, err) + assert.Equal(t, preBufferedData, received, "pre-buffered data should be received") + + // Now send additional data + additionalData := []byte("additional data") + _, err = client1.Write(additionalData) + require.NoError(t, err) + + additionalReceived := make([]byte, len(additionalData)) + server2.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, err = io.ReadFull(server2, additionalReceived) + require.NoError(t, err) + assert.Equal(t, additionalData, additionalReceived) + + // Cleanup + client1.Close() + server2.Close() + + select { + case <-closeDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for close callback") + } +} diff --git a/common/context_afterfunc.go b/common/context_afterfunc.go new file mode 100644 index 00000000..887dc148 --- /dev/null +++ b/common/context_afterfunc.go @@ -0,0 +1,11 @@ +//go:build go1.21 + +package common + +import "context" + +// ContextAfterFunc arranges to call f in its own 
goroutine after ctx is done. +// Returns a stop function that reports whether it prevented f from being run. +func ContextAfterFunc(ctx context.Context, f func()) (stop func() bool) { + return context.AfterFunc(ctx, f) +} diff --git a/common/context_afterfunc_compat.go b/common/context_afterfunc_compat.go new file mode 100644 index 00000000..0a09ced7 --- /dev/null +++ b/common/context_afterfunc_compat.go @@ -0,0 +1,33 @@ +//go:build go1.20 && !go1.21 + +package common + +import ( + "context" + "sync" +) + +// ContextAfterFunc arranges to call f in its own goroutine after ctx is done. +// Returns a stop function that reports whether it prevented f from being run. +func ContextAfterFunc(ctx context.Context, f func()) (stop func() bool) { + stopCh := make(chan struct{}) + var once sync.Once + + go func() { + select { + case <-ctx.Done(): + // The shared once ensures f runs at most once and never after a successful stop. + once.Do(f) + case <-stopCh: + } + }() + + return func() bool { + stopped := false + once.Do(func() { + close(stopCh) + stopped = true + }) + return stopped + } +} diff --git a/common/network/handshake.go b/common/network/handshake.go index c86a1b57..a032ab6f 100644 --- a/common/network/handshake.go +++ b/common/network/handshake.go @@ -30,7 +30,7 @@ func ReportHandshakeFailure(reporter any, err error) error { return E.Cause(err, "write handshake failure") }) } - return nil + return err } func CloseOnHandshakeFailure(reporter io.Closer, onClose CloseHandlerFunc, err error) error { diff --git a/common/network/packet_pollable.go b/common/network/packet_pollable.go new file mode 100644 index 00000000..80fb9f65 --- /dev/null +++ b/common/network/packet_pollable.go @@ -0,0 +1,19 @@ +package network + +// PacketPushable represents a packet source that receives pushed data +// from external code and notifies the reactor via a callback. +type PacketPushable interface { + SetOnDataReady(callback func()) + HasPendingData() bool +} + +// PacketPollable provides FD-based polling for packet connections. +// Mirrors StreamPollable for consistency. +type PacketPollable interface { + FD() int +} + +// PacketPollableCreator creates a PacketPollable dynamically.
+type PacketPollableCreator interface { + CreatePacketPollable() (PacketPollable, bool) +} diff --git a/common/network/stream_pollable.go b/common/network/stream_pollable.go new file mode 100644 index 00000000..6dfd58da --- /dev/null +++ b/common/network/stream_pollable.go @@ -0,0 +1,17 @@ +package network + +// StreamPollable provides reactor support for TCP stream connections +// Used by StreamReactor for idle detection via epoll/kqueue/IOCP +type StreamPollable interface { + // FD returns the file descriptor for reactor registration + FD() int + // Buffered returns the number of bytes in internal buffer + // Reactor must check this before returning to idle state + Buffered() int +} + +// StreamPollableCreator creates a StreamPollable dynamically +// Optional interface - prefer direct implementation of StreamPollable +type StreamPollableCreator interface { + CreateStreamPollable() (StreamPollable, bool) +} diff --git a/common/udpnat2/conn.go b/common/udpnat2/conn.go index 3c0cda38..57260f64 100644 --- a/common/udpnat2/conn.go +++ b/common/udpnat2/conn.go @@ -13,7 +13,6 @@ import ( "github.com/sagernet/sing/common/canceler" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/pipe" "github.com/sagernet/sing/contrab/freelru" ) @@ -23,7 +22,11 @@ type Conn interface { canceler.PacketConn } -var _ Conn = (*natConn)(nil) +var ( + _ Conn = (*natConn)(nil) + _ N.PacketPushable = (*natConn)(nil) + _ N.PacketReadWaiter = (*natConn)(nil) +) type natConn struct { cache freelru.Cache[netip.AddrPort, *natConn] @@ -31,25 +34,50 @@ type natConn struct { localAddr M.Socksaddr handlerAccess sync.RWMutex handler N.UDPHandlerEx - packetChan chan *N.PacketBuffer - closeOnce sync.Once - doneChan chan struct{} - readDeadline pipe.Deadline readWaitOptions N.ReadWaitOptions + + dataQueue []*N.PacketBuffer + queueMutex sync.Mutex + onDataReady func() + + deadlineMutex sync.Mutex + deadlineTimer *time.Timer + deadlineChan chan struct{} + dataSignal chan struct{} + + closeOnce sync.Once + doneChan chan struct{} } -func (c *natConn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) { - select { - case p := <-c.packetChan: - _, err = buffer.ReadOnceFrom(p.Buffer) - destination := p.Destination - p.Buffer.Release() - N.PutPacketBuffer(p) - return destination, err - case <-c.doneChan: - return M.Socksaddr{}, io.ErrClosedPipe - case <-c.readDeadline.Wait(): - return M.Socksaddr{}, os.ErrDeadlineExceeded +func (c *natConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + for { + select { + case <-c.doneChan: + return M.Socksaddr{}, io.ErrClosedPipe + default: + } + + c.queueMutex.Lock() + if len(c.dataQueue) > 0 { + packet := c.dataQueue[0] + c.dataQueue = c.dataQueue[1:] + c.queueMutex.Unlock() + _, err = buffer.ReadOnceFrom(packet.Buffer) + destination = packet.Destination + packet.Buffer.Release() + N.PutPacketBuffer(packet) + return + } + c.queueMutex.Unlock() + + select { + case <-c.doneChan: + return M.Socksaddr{}, io.ErrClosedPipe + case <-c.waitDeadline(): + return M.Socksaddr{}, os.ErrDeadlineExceeded + case <-c.dataSignal: + continue + } } } @@ -65,16 +93,33 @@ func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool } func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { - select { - case packet := <-c.packetChan: - buffer = c.readWaitOptions.Copy(packet.Buffer) - destination = packet.Destination - N.PutPacketBuffer(packet) - return - case 
<-c.doneChan: - return nil, M.Socksaddr{}, io.ErrClosedPipe - case <-c.readDeadline.Wait(): - return nil, M.Socksaddr{}, os.ErrDeadlineExceeded + for { + select { + case <-c.doneChan: + return nil, M.Socksaddr{}, io.ErrClosedPipe + default: + } + + c.queueMutex.Lock() + if len(c.dataQueue) > 0 { + packet := c.dataQueue[0] + c.dataQueue = c.dataQueue[1:] + c.queueMutex.Unlock() + buffer = c.readWaitOptions.Copy(packet.Buffer) + destination = packet.Destination + N.PutPacketBuffer(packet) + return + } + c.queueMutex.Unlock() + + select { + case <-c.doneChan: + return nil, M.Socksaddr{}, io.ErrClosedPipe + case <-c.waitDeadline(): + return nil, M.Socksaddr{}, os.ErrDeadlineExceeded + case <-c.dataSignal: + continue + } } } @@ -83,16 +128,49 @@ func (c *natConn) SetHandler(handler N.UDPHandlerEx) { c.handler = handler c.readWaitOptions = N.NewReadWaitOptions(c.writer, handler) c.handlerAccess.Unlock() -fetch: - for { - select { - case packet := <-c.packetChan: - c.handler.NewPacketEx(packet.Buffer, packet.Destination) - N.PutPacketBuffer(packet) - continue fetch - default: - break fetch - } + + c.queueMutex.Lock() + pending := c.dataQueue + c.dataQueue = nil + c.queueMutex.Unlock() + + for _, packet := range pending { + handler.NewPacketEx(packet.Buffer, packet.Destination) + N.PutPacketBuffer(packet) + } +} + +func (c *natConn) SetOnDataReady(callback func()) { + c.queueMutex.Lock() + c.onDataReady = callback + c.queueMutex.Unlock() +} + +func (c *natConn) HasPendingData() bool { + c.queueMutex.Lock() + defer c.queueMutex.Unlock() + return len(c.dataQueue) > 0 +} + +func (c *natConn) PushPacket(packet *N.PacketBuffer) { + c.queueMutex.Lock() + if len(c.dataQueue) >= 64 { + c.queueMutex.Unlock() + packet.Buffer.Release() + N.PutPacketBuffer(packet) + return + } + c.dataQueue = append(c.dataQueue, packet) + callback := c.onDataReady + c.queueMutex.Unlock() + + select { + case c.dataSignal <- struct{}{}: + default: + } + + if callback != nil { + callback() } } @@ -111,6 +189,14 @@ func (c *natConn) SetTimeout(timeout time.Duration) bool { func (c *natConn) Close() error { c.closeOnce.Do(func() { close(c.doneChan) + + c.queueMutex.Lock() + pending := c.dataQueue + c.dataQueue = nil + c.onDataReady = nil + c.queueMutex.Unlock() + + N.ReleaseMultiPacketBuffer(pending) common.Close(c.handler) }) return nil @@ -129,10 +215,52 @@ func (c *natConn) SetDeadline(t time.Time) error { } func (c *natConn) SetReadDeadline(t time.Time) error { - c.readDeadline.Set(t) + c.deadlineMutex.Lock() + defer c.deadlineMutex.Unlock() + + if c.deadlineTimer != nil && !c.deadlineTimer.Stop() { + <-c.deadlineChan + } + c.deadlineTimer = nil + + if t.IsZero() { + if isClosedChan(c.deadlineChan) { + c.deadlineChan = make(chan struct{}) + } + return nil + } + + if duration := time.Until(t); duration > 0 { + if isClosedChan(c.deadlineChan) { + c.deadlineChan = make(chan struct{}) + } + c.deadlineTimer = time.AfterFunc(duration, func() { + close(c.deadlineChan) + }) + return nil + } + + if !isClosedChan(c.deadlineChan) { + close(c.deadlineChan) + } return nil } +func (c *natConn) waitDeadline() chan struct{} { + c.deadlineMutex.Lock() + defer c.deadlineMutex.Unlock() + return c.deadlineChan +} + +func isClosedChan(channel <-chan struct{}) bool { + select { + case <-channel: + return true + default: + return false + } +} + func (c *natConn) SetWriteDeadline(t time.Time) error { return os.ErrInvalid } diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go index 3e3ce7d1..4370651f 100644 --- 
a/common/udpnat2/service.go +++ b/common/udpnat2/service.go @@ -8,7 +8,6 @@ import ( "github.com/sagernet/sing/common" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/pipe" "github.com/sagernet/sing/contrab/freelru" "github.com/sagernet/sing/contrab/maphash" ) @@ -60,9 +59,9 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati cache: s.cache, writer: writer, localAddr: source, - packetChan: make(chan *N.PacketBuffer, 64), + deadlineChan: make(chan struct{}), + dataSignal: make(chan struct{}, 1), doneChan: make(chan struct{}), - readDeadline: pipe.MakeDeadline(), } go s.handler.NewPacketConnectionEx(ctx, newConn, source, destination, onClose) return newConn, true @@ -87,12 +86,7 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati Buffer: buffer, Destination: destination, } - select { - case conn.packetChan <- packet: - default: - packet.Buffer.Release() - N.PutPacketBuffer(packet) - } + conn.PushPacket(packet) } func (s *Service) Purge() { diff --git a/common/wepoll/afd_windows.go b/common/wepoll/afd_windows.go new file mode 100644 index 00000000..1547cc8b --- /dev/null +++ b/common/wepoll/afd_windows.go @@ -0,0 +1,123 @@ +//go:build windows + +package wepoll + +import ( + "math" + "unsafe" + + "golang.org/x/sys/windows" +) + +type AFD struct { + handle windows.Handle +} + +func NewAFD(iocp windows.Handle, name string) (*AFD, error) { + deviceName := `\Device\Afd\` + name + deviceNameUTF16, err := windows.UTF16FromString(deviceName) + if err != nil { + return nil, err + } + + unicodeString := UnicodeString{ + Length: uint16(len(deviceName) * 2), + MaximumLength: uint16(len(deviceName) * 2), + Buffer: &deviceNameUTF16[0], + } + + objectAttributes := ObjectAttributes{ + Length: uint32(unsafe.Sizeof(ObjectAttributes{})), + ObjectName: &unicodeString, + Attributes: OBJ_CASE_INSENSITIVE, + } + + var handle windows.Handle + var ioStatusBlock windows.IO_STATUS_BLOCK + + err = NtCreateFile( + &handle, + windows.SYNCHRONIZE, + &objectAttributes, + &ioStatusBlock, + nil, + 0, + windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, + FILE_OPEN, + 0, + 0, + 0, + ) + if err != nil { + return nil, err + } + + _, err = windows.CreateIoCompletionPort(handle, iocp, 0, 0) + if err != nil { + windows.CloseHandle(handle) + return nil, err + } + + err = windows.SetFileCompletionNotificationModes(handle, windows.FILE_SKIP_SET_EVENT_ON_HANDLE) + if err != nil { + windows.CloseHandle(handle) + return nil, err + } + + return &AFD{handle: handle}, nil +} + +func (a *AFD) Poll(baseSocket windows.Handle, events uint32, iosb *windows.IO_STATUS_BLOCK, pollInfo *AFDPollInfo) error { + pollInfo.Timeout = math.MaxInt64 + pollInfo.NumberOfHandles = 1 + pollInfo.Exclusive = 0 + pollInfo.Handles[0].Handle = baseSocket + pollInfo.Handles[0].Events = events + pollInfo.Handles[0].Status = 0 + + iosb.Status = windows.NTStatus(STATUS_PENDING) + size := uint32(unsafe.Sizeof(*pollInfo)) + + err := NtDeviceIoControlFile( + a.handle, + 0, + 0, + uintptr(unsafe.Pointer(iosb)), + iosb, + IOCTL_AFD_POLL, + unsafe.Pointer(pollInfo), + size, + unsafe.Pointer(pollInfo), + size, + ) + if err != nil { + if ntstatus, ok := err.(windows.NTStatus); ok { + if uint32(ntstatus) == STATUS_PENDING { + return nil + } + } + return err + } + return nil +} + +func (a *AFD) Cancel(ioStatusBlock *windows.IO_STATUS_BLOCK) error { + if uint32(ioStatusBlock.Status) != STATUS_PENDING { + return nil + } + var cancelIOStatusBlock 
windows.IO_STATUS_BLOCK + err := NtCancelIoFileEx(a.handle, ioStatusBlock, &cancelIOStatusBlock) + if err != nil { + if ntstatus, ok := err.(windows.NTStatus); ok { + if uint32(ntstatus) == STATUS_CANCELLED || uint32(ntstatus) == STATUS_NOT_FOUND { + return nil + } + } + return err + } + return nil +} + +func (a *AFD) Close() error { + return windows.CloseHandle(a.handle) +} diff --git a/common/wepoll/pinner.go b/common/wepoll/pinner.go new file mode 100644 index 00000000..58b76686 --- /dev/null +++ b/common/wepoll/pinner.go @@ -0,0 +1,7 @@ +//go:build go1.21 + +package wepoll + +import "runtime" + +type Pinner = runtime.Pinner diff --git a/common/wepoll/pinner_compat.go b/common/wepoll/pinner_compat.go new file mode 100644 index 00000000..a51a9fa6 --- /dev/null +++ b/common/wepoll/pinner_compat.go @@ -0,0 +1,9 @@ +//go:build !go1.21 + +package wepoll + +type Pinner struct{} + +func (p *Pinner) Pin(pointer any) {} + +func (p *Pinner) Unpin() {} diff --git a/common/wepoll/socket_windows.go b/common/wepoll/socket_windows.go new file mode 100644 index 00000000..e5655990 --- /dev/null +++ b/common/wepoll/socket_windows.go @@ -0,0 +1,49 @@ +//go:build windows + +package wepoll + +import ( + "unsafe" + + "golang.org/x/sys/windows" +) + +func GetBaseSocket(socket windows.Handle) (windows.Handle, error) { + var baseSocket windows.Handle + var bytesReturned uint32 + + for { + err := windows.WSAIoctl( + socket, + SIO_BASE_HANDLE, + nil, + 0, + (*byte)(unsafe.Pointer(&baseSocket)), + uint32(unsafe.Sizeof(baseSocket)), + &bytesReturned, + nil, + 0, + ) + if err != nil { + err = windows.WSAIoctl( + socket, + SIO_BSP_HANDLE_POLL, + nil, + 0, + (*byte)(unsafe.Pointer(&baseSocket)), + uint32(unsafe.Sizeof(baseSocket)), + &bytesReturned, + nil, + 0, + ) + if err != nil { + return socket, nil + } + } + + if baseSocket == socket { + return baseSocket, nil + } + socket = baseSocket + } +} diff --git a/common/wepoll/syscall_windows.go b/common/wepoll/syscall_windows.go new file mode 100644 index 00000000..948dd00e --- /dev/null +++ b/common/wepoll/syscall_windows.go @@ -0,0 +1,8 @@ +package wepoll + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go + +//sys NtCreateFile(handle *windows.Handle, access uint32, oa *ObjectAttributes, iosb *windows.IO_STATUS_BLOCK, allocationSize *int64, attributes uint32, share uint32, disposition uint32, options uint32, eaBuffer uintptr, eaLength uint32) (ntstatus error) = ntdll.NtCreateFile +//sys NtDeviceIoControlFile(handle windows.Handle, event windows.Handle, apcRoutine uintptr, apcContext uintptr, ioStatusBlock *windows.IO_STATUS_BLOCK, ioControlCode uint32, inputBuffer unsafe.Pointer, inputBufferLength uint32, outputBuffer unsafe.Pointer, outputBufferLength uint32) (ntstatus error) = ntdll.NtDeviceIoControlFile +//sys NtCancelIoFileEx(handle windows.Handle, ioRequestToCancel *windows.IO_STATUS_BLOCK, ioStatusBlock *windows.IO_STATUS_BLOCK) (ntstatus error) = ntdll.NtCancelIoFileEx +//sys GetQueuedCompletionStatusEx(cphandle windows.Handle, entries *OverlappedEntry, count uint32, numRemoved *uint32, timeout uint32, alertable bool) (err error) = kernel32.GetQueuedCompletionStatusEx diff --git a/common/wepoll/types_windows.go b/common/wepoll/types_windows.go new file mode 100644 index 00000000..aad2d79a --- /dev/null +++ b/common/wepoll/types_windows.go @@ -0,0 +1,64 @@ +//go:build windows + +package wepoll + +import "golang.org/x/sys/windows" + +const ( + IOCTL_AFD_POLL = 0x00012024 + + AFD_POLL_RECEIVE = 0x0001 + 
AFD_POLL_RECEIVE_EXPEDITED = 0x0002 + AFD_POLL_SEND = 0x0004 + AFD_POLL_DISCONNECT = 0x0008 + AFD_POLL_ABORT = 0x0010 + AFD_POLL_LOCAL_CLOSE = 0x0020 + AFD_POLL_ACCEPT = 0x0080 + AFD_POLL_CONNECT_FAIL = 0x0100 + + SIO_BASE_HANDLE = 0x48000022 + SIO_BSP_HANDLE_POLL = 0x4800001D + + STATUS_PENDING = 0x00000103 + STATUS_CANCELLED = 0xC0000120 + STATUS_NOT_FOUND = 0xC0000225 + + FILE_OPEN = 0x00000001 + + OBJ_CASE_INSENSITIVE = 0x00000040 +) + +type AFDPollHandleInfo struct { + Handle windows.Handle + Events uint32 + Status uint32 +} + +type AFDPollInfo struct { + Timeout int64 + NumberOfHandles uint32 + Exclusive uint32 + Handles [1]AFDPollHandleInfo +} + +type OverlappedEntry struct { + CompletionKey uintptr + Overlapped *windows.Overlapped + Internal uintptr + NumberOfBytesTransferred uint32 +} + +type UnicodeString struct { + Length uint16 + MaximumLength uint16 + Buffer *uint16 +} + +type ObjectAttributes struct { + Length uint32 + RootDirectory windows.Handle + ObjectName *UnicodeString + Attributes uint32 + SecurityDescriptor uintptr + SecurityQualityOfService uintptr +} diff --git a/common/wepoll/wepoll_test.go b/common/wepoll/wepoll_test.go new file mode 100644 index 00000000..0a0ec023 --- /dev/null +++ b/common/wepoll/wepoll_test.go @@ -0,0 +1,335 @@ +//go:build windows + +package wepoll + +import ( + "net" + "syscall" + "testing" + "time" + "unsafe" + + "github.com/stretchr/testify/require" + "golang.org/x/sys/windows" +) + +func createTestIOCP(t *testing.T) windows.Handle { + iocp, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) + require.NoError(t, err) + t.Cleanup(func() { + windows.CloseHandle(iocp) + }) + return iocp +} + +func getSocketHandle(t *testing.T, conn net.PacketConn) windows.Handle { + syscallConn, ok := conn.(syscall.Conn) + require.True(t, ok) + rawConn, err := syscallConn.SyscallConn() + require.NoError(t, err) + var fd uintptr + err = rawConn.Control(func(f uintptr) { fd = f }) + require.NoError(t, err) + return windows.Handle(fd) +} + +func getTCPSocketHandle(t *testing.T, conn net.Conn) windows.Handle { + syscallConn, ok := conn.(syscall.Conn) + require.True(t, ok) + rawConn, err := syscallConn.SyscallConn() + require.NoError(t, err) + var fd uintptr + err = rawConn.Control(func(f uintptr) { fd = f }) + require.NoError(t, err) + return windows.Handle(fd) +} + +func TestNewAFD(t *testing.T) { + t.Parallel() + + iocp := createTestIOCP(t) + + afd, err := NewAFD(iocp, "test") + require.NoError(t, err) + require.NotNil(t, afd) + + err = afd.Close() + require.NoError(t, err) +} + +func TestNewAFD_MultipleTimes(t *testing.T) { + t.Parallel() + + iocp := createTestIOCP(t) + + afd1, err := NewAFD(iocp, "test1") + require.NoError(t, err) + defer afd1.Close() + + afd2, err := NewAFD(iocp, "test2") + require.NoError(t, err) + defer afd2.Close() + + afd3, err := NewAFD(iocp, "test3") + require.NoError(t, err) + defer afd3.Close() +} + +func TestGetBaseSocket_UDP(t *testing.T) { + t.Parallel() + + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + + handle := getSocketHandle(t, conn) + baseHandle, err := GetBaseSocket(handle) + require.NoError(t, err) + require.NotEqual(t, windows.InvalidHandle, baseHandle) +} + +func TestGetBaseSocket_TCP(t *testing.T) { + t.Parallel() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + go func() { + conn, err := listener.Accept() + if err == nil { + conn.Close() + } + }() + + conn, err := net.Dial("tcp", 
listener.Addr().String()) + require.NoError(t, err) + defer conn.Close() + + handle := getTCPSocketHandle(t, conn) + baseHandle, err := GetBaseSocket(handle) + require.NoError(t, err) + require.NotEqual(t, windows.InvalidHandle, baseHandle) +} + +func TestAFD_Poll_ReceiveEvent(t *testing.T) { + t.Parallel() + + iocp := createTestIOCP(t) + + afd, err := NewAFD(iocp, "poll_test") + require.NoError(t, err) + defer afd.Close() + + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + + handle := getSocketHandle(t, conn) + baseHandle, err := GetBaseSocket(handle) + require.NoError(t, err) + + var state struct { + iosb windows.IO_STATUS_BLOCK + pollInfo AFDPollInfo + } + + var pinner Pinner + pinner.Pin(&state) + defer pinner.Unpin() + + events := uint32(AFD_POLL_RECEIVE | AFD_POLL_DISCONNECT | AFD_POLL_ABORT) + err = afd.Poll(baseHandle, events, &state.iosb, &state.pollInfo) + require.NoError(t, err) + + sender, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer sender.Close() + + _, err = sender.WriteTo([]byte("test data"), conn.LocalAddr()) + require.NoError(t, err) + + entries := make([]OverlappedEntry, 1) + var numRemoved uint32 + err = GetQueuedCompletionStatusEx(iocp, &entries[0], 1, &numRemoved, 5000, false) + require.NoError(t, err) + require.Equal(t, uint32(1), numRemoved) + require.Equal(t, uintptr(unsafe.Pointer(&state.iosb)), uintptr(unsafe.Pointer(entries[0].Overlapped))) +} + +func TestAFD_Cancel(t *testing.T) { + t.Parallel() + + iocp := createTestIOCP(t) + + afd, err := NewAFD(iocp, "cancel_test") + require.NoError(t, err) + defer afd.Close() + + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer conn.Close() + + handle := getSocketHandle(t, conn) + baseHandle, err := GetBaseSocket(handle) + require.NoError(t, err) + + var state struct { + iosb windows.IO_STATUS_BLOCK + pollInfo AFDPollInfo + } + + var pinner Pinner + pinner.Pin(&state) + defer pinner.Unpin() + + events := uint32(AFD_POLL_RECEIVE) + err = afd.Poll(baseHandle, events, &state.iosb, &state.pollInfo) + require.NoError(t, err) + + err = afd.Cancel(&state.iosb) + require.NoError(t, err) + + entries := make([]OverlappedEntry, 1) + var numRemoved uint32 + err = GetQueuedCompletionStatusEx(iocp, &entries[0], 1, &numRemoved, 1000, false) + require.NoError(t, err) + require.Equal(t, uint32(1), numRemoved) +} + +func TestAFD_Close(t *testing.T) { + t.Parallel() + + iocp := createTestIOCP(t) + + afd, err := NewAFD(iocp, "close_test") + require.NoError(t, err) + + err = afd.Close() + require.NoError(t, err) +} + +func TestGetQueuedCompletionStatusEx_Timeout(t *testing.T) { + t.Parallel() + + iocp := createTestIOCP(t) + + entries := make([]OverlappedEntry, 1) + var numRemoved uint32 + + start := time.Now() + err := GetQueuedCompletionStatusEx(iocp, &entries[0], 1, &numRemoved, 100, false) + elapsed := time.Since(start) + + require.Error(t, err) + require.GreaterOrEqual(t, elapsed, 50*time.Millisecond) +} + +func TestGetQueuedCompletionStatusEx_MultipleEntries(t *testing.T) { + t.Parallel() + + iocp := createTestIOCP(t) + + afd, err := NewAFD(iocp, "multi_test") + require.NoError(t, err) + defer afd.Close() + + const numConns = 3 + conns := make([]net.PacketConn, numConns) + states := make([]struct { + iosb windows.IO_STATUS_BLOCK + pollInfo AFDPollInfo + }, numConns) + pinners := make([]Pinner, numConns) + + for i := 0; i < numConns; i++ { + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) 
+ defer conn.Close() + conns[i] = conn + + handle := getSocketHandle(t, conn) + baseHandle, err := GetBaseSocket(handle) + require.NoError(t, err) + + pinners[i].Pin(&states[i]) + defer pinners[i].Unpin() + + events := uint32(AFD_POLL_RECEIVE) + err = afd.Poll(baseHandle, events, &states[i].iosb, &states[i].pollInfo) + require.NoError(t, err) + } + + sender, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer sender.Close() + + for i := 0; i < numConns; i++ { + _, err = sender.WriteTo([]byte("test"), conns[i].LocalAddr()) + require.NoError(t, err) + } + + entries := make([]OverlappedEntry, 8) + var numRemoved uint32 + received := 0 + for received < numConns { + err = GetQueuedCompletionStatusEx(iocp, &entries[0], 8, &numRemoved, 5000, false) + require.NoError(t, err) + received += int(numRemoved) + } + require.Equal(t, numConns, received) +} + +func TestAFD_Poll_DisconnectEvent(t *testing.T) { + t.Parallel() + + iocp := createTestIOCP(t) + + afd, err := NewAFD(iocp, "disconnect_test") + require.NoError(t, err) + defer afd.Close() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + serverDone := make(chan struct{}) + go func() { + defer close(serverDone) + conn, err := listener.Accept() + if err != nil { + return + } + time.Sleep(100 * time.Millisecond) + conn.Close() + }() + + client, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer client.Close() + + handle := getTCPSocketHandle(t, client) + baseHandle, err := GetBaseSocket(handle) + require.NoError(t, err) + + var state struct { + iosb windows.IO_STATUS_BLOCK + pollInfo AFDPollInfo + } + + var pinner Pinner + pinner.Pin(&state) + defer pinner.Unpin() + + events := uint32(AFD_POLL_RECEIVE | AFD_POLL_DISCONNECT | AFD_POLL_ABORT) + err = afd.Poll(baseHandle, events, &state.iosb, &state.pollInfo) + require.NoError(t, err) + + entries := make([]OverlappedEntry, 1) + var numRemoved uint32 + err = GetQueuedCompletionStatusEx(iocp, &entries[0], 1, &numRemoved, 5000, false) + require.NoError(t, err) + require.Equal(t, uint32(1), numRemoved) + + <-serverDone +} diff --git a/common/wepoll/zsyscall_windows.go b/common/wepoll/zsyscall_windows.go new file mode 100644 index 00000000..dac75d17 --- /dev/null +++ b/common/wepoll/zsyscall_windows.go @@ -0,0 +1,84 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package wepoll + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) 
+ return e +} + +var ( + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + modntdll = windows.NewLazySystemDLL("ntdll.dll") + + procGetQueuedCompletionStatusEx = modkernel32.NewProc("GetQueuedCompletionStatusEx") + procNtCancelIoFileEx = modntdll.NewProc("NtCancelIoFileEx") + procNtCreateFile = modntdll.NewProc("NtCreateFile") + procNtDeviceIoControlFile = modntdll.NewProc("NtDeviceIoControlFile") +) + +func GetQueuedCompletionStatusEx(cphandle windows.Handle, entries *OverlappedEntry, count uint32, numRemoved *uint32, timeout uint32, alertable bool) (err error) { + var _p0 uint32 + if alertable { + _p0 = 1 + } + r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatusEx.Addr(), 6, uintptr(cphandle), uintptr(unsafe.Pointer(entries)), uintptr(count), uintptr(unsafe.Pointer(numRemoved)), uintptr(timeout), uintptr(_p0)) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + +func NtCancelIoFileEx(handle windows.Handle, ioRequestToCancel *windows.IO_STATUS_BLOCK, ioStatusBlock *windows.IO_STATUS_BLOCK) (ntstatus error) { + r0, _, _ := syscall.Syscall(procNtCancelIoFileEx.Addr(), 3, uintptr(handle), uintptr(unsafe.Pointer(ioRequestToCancel)), uintptr(unsafe.Pointer(ioStatusBlock))) + if r0 != 0 { + ntstatus = windows.NTStatus(r0) + } + return +} + +func NtCreateFile(handle *windows.Handle, access uint32, oa *ObjectAttributes, iosb *windows.IO_STATUS_BLOCK, allocationSize *int64, attributes uint32, share uint32, disposition uint32, options uint32, eaBuffer uintptr, eaLength uint32) (ntstatus error) { + r0, _, _ := syscall.Syscall12(procNtCreateFile.Addr(), 11, uintptr(unsafe.Pointer(handle)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(unsafe.Pointer(allocationSize)), uintptr(attributes), uintptr(share), uintptr(disposition), uintptr(options), uintptr(eaBuffer), uintptr(eaLength), 0) + if r0 != 0 { + ntstatus = windows.NTStatus(r0) + } + return +} + +func NtDeviceIoControlFile(handle windows.Handle, event windows.Handle, apcRoutine uintptr, apcContext uintptr, ioStatusBlock *windows.IO_STATUS_BLOCK, ioControlCode uint32, inputBuffer unsafe.Pointer, inputBufferLength uint32, outputBuffer unsafe.Pointer, outputBufferLength uint32) (ntstatus error) { + r0, _, _ := syscall.Syscall12(procNtDeviceIoControlFile.Addr(), 10, uintptr(handle), uintptr(event), uintptr(apcRoutine), uintptr(apcContext), uintptr(unsafe.Pointer(ioStatusBlock)), uintptr(ioControlCode), uintptr(inputBuffer), uintptr(inputBufferLength), uintptr(outputBuffer), uintptr(outputBufferLength), 0, 0) + if r0 != 0 { + ntstatus = windows.NTStatus(r0) + } + return +}
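Not part of the diff above: a minimal, hypothetical usage sketch of the new common.ContextAfterFunc helper, assuming only the semantics documented in common/context_afterfunc.go (f is called in its own goroutine once ctx is done; the returned stop function prevents f from running if it is called first, and reports whether it did so). The program below is illustrative, not taken from the repository.

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/sagernet/sing/common"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())

	// Arrange a cleanup callback to run once ctx is done.
	stop := common.ContextAfterFunc(ctx, func() {
		fmt.Println("context done, running cleanup")
	})

	cancel()                          // ctx is now done; the callback runs in its own goroutine
	time.Sleep(10 * time.Millisecond) // give the callback a moment to execute

	// stop reports false here: the callback was already scheduled, so nothing was prevented.
	fmt.Println("stopped before run:", stop())
}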