diff --git a/common/bufio/copy.go b/common/bufio/copy.go index 5c99af50f..192e34991 100644 --- a/common/bufio/copy.go +++ b/common/bufio/copy.go @@ -5,7 +5,6 @@ import ( "errors" "io" "net" - "syscall" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -59,14 +58,10 @@ func CopyWithIncreateBuffer(destination io.Writer, source io.Reader, increaseBuf } func CopyWithCounters(destination io.Writer, source io.Reader, originSource io.Reader, readCounters []N.CountFunc, writeCounters []N.CountFunc, increaseBufferAfter int64, batchSize int) (n int64, err error) { - srcSyscallConn, srcIsSyscall := source.(syscall.Conn) - dstSyscallConn, dstIsSyscall := destination.(syscall.Conn) - if srcIsSyscall && dstIsSyscall { - var handled bool - handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters) - if handled { - return - } + var handled bool + handled, n, err = copyDirect(source, destination, readCounters, writeCounters) + if handled { + return } return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters, increaseBufferAfter, batchSize) } diff --git a/common/bufio/copy_direct.go b/common/bufio/copy_direct.go index bb4d7834d..7cd834122 100644 --- a/common/bufio/copy_direct.go +++ b/common/bufio/copy_direct.go @@ -3,23 +3,22 @@ package bufio import ( "errors" "io" - "syscall" "github.com/sagernet/sing/common/buf" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) -func copyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) { - rawSource, err := source.SyscallConn() - if err != nil { +func copyDirect(source io.Reader, destination io.Writer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) { + if !N.SyscallAvailableForRead(source) || !N.SyscallAvailableForWrite(destination) { return } - rawDestination, err := destination.SyscallConn() - if err != nil { + sourceReader, sourceConn := N.SyscallConnForRead(source) + destinationWriter, destinationConn := N.SyscallConnForWrite(destination) + if sourceConn == nil || destinationConn == nil { return } - handed, n, err = splice(rawSource, rawDestination, readCounters, writeCounters) + handed, n, err = splice(sourceConn, sourceReader, destinationConn, destinationWriter, readCounters, writeCounters) return } diff --git a/common/bufio/splice_linux.go b/common/bufio/splice_linux.go index 1b08898ad..501543408 100644 --- a/common/bufio/splice_linux.go +++ b/common/bufio/splice_linux.go @@ -11,7 +11,7 @@ import ( const maxSpliceSize = 1 << 20 -func splice(source syscall.RawConn, destination syscall.RawConn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) { +func splice(source syscall.RawConn, sourceReader N.SyscallReader, destination syscall.RawConn, destinationWriter N.SyscallWriter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) { handed = true var pipeFDs [2]int err = unix.Pipe2(pipeFDs[:], syscall.O_CLOEXEC|syscall.O_NONBLOCK) @@ -20,12 +20,14 @@ func splice(source syscall.RawConn, destination syscall.RawConn, readCounters [] } defer unix.Close(pipeFDs[0]) defer unix.Close(pipeFDs[1]) - _, _ = unix.FcntlInt(uintptr(pipeFDs[0]), unix.F_SETPIPE_SZ, maxSpliceSize) - var readN int - var readErr error - var writeSize int - var writeErr error + var ( + readN int + readErr error + writeSize int + writeErr error + notFirstTime bool + ) readFunc := func(fd uintptr) (done bool) { p0, p1 := unix.Splice(int(fd), nil, pipeFDs[1], nil, maxSpliceSize, unix.SPLICE_F_NONBLOCK) readN = int(p0) @@ -46,15 +48,28 @@ func splice(source syscall.RawConn, destination syscall.RawConn, readCounters [] } for { err = source.Read(readFunc) - if err != nil { - readErr = err - } if readErr != nil { - if readErr == unix.EINVAL || readErr == unix.ENOSYS { + err = readErr + } + if err != nil { + if sourceReader != nil { + newBuffer, newErr := sourceReader.HandleSyscallReadError(err) + if newErr != nil { + err = newErr + } else { + err = nil + if len(newBuffer) > 0 { + readN, readErr = unix.Write(pipeFDs[1], newBuffer) + if readErr != nil { + err = E.Cause(err, "write handled data") + } + } + } + } else if !notFirstTime && E.IsMulti(err, unix.EINVAL, unix.ENOSYS) { handed = false return } - err = E.Cause(readErr, "splice read") + err = E.Cause(err, "splice read") return } if readN == 0 { @@ -62,18 +77,20 @@ func splice(source syscall.RawConn, destination syscall.RawConn, readCounters [] } writeSize = readN err = destination.Write(writeFunc) - if err != nil { - writeErr = err - } if writeErr != nil { - err = E.Cause(writeErr, "splice write") + err = writeErr + } + if err != nil { + err = E.Cause(err, "splice write") return } + n += int64(readN) for _, readCounter := range readCounters { readCounter(int64(readN)) } for _, writeCounter := range writeCounters { writeCounter(int64(readN)) } + notFirstTime = true } } diff --git a/common/bufio/splice_stub.go b/common/bufio/splice_stub.go index 44c93b5ab..b120284b8 100644 --- a/common/bufio/splice_stub.go +++ b/common/bufio/splice_stub.go @@ -8,6 +8,6 @@ import ( N "github.com/sagernet/sing/common/network" ) -func splice(source syscall.RawConn, destination syscall.RawConn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) { +func splice(source syscall.RawConn, sourceReader N.SyscallReader, destination syscall.RawConn, destinationWriter N.SyscallWriter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) { return } diff --git a/common/bufio/wait.go b/common/bufio/wait.go index c32fa3039..6ac0821c7 100644 --- a/common/bufio/wait.go +++ b/common/bufio/wait.go @@ -3,11 +3,11 @@ package bufio import ( "io" + "github.com/sagernet/sing/common" N "github.com/sagernet/sing/common/network" ) func CreateReadWaiter(reader io.Reader) (N.ReadWaiter, bool) { - reader = N.UnwrapReader(reader) if readWaiter, isReadWaiter := reader.(N.ReadWaiter); isReadWaiter { return readWaiter, true } @@ -17,11 +17,19 @@ func CreateReadWaiter(reader io.Reader) (N.ReadWaiter, bool) { if readWaiter, created := createSyscallReadWaiter(reader); created { return readWaiter, true } + if u, ok := reader.(N.ReaderWithUpstream); !ok || !u.ReaderReplaceable() { + return nil, false + } + if u, ok := reader.(N.WithUpstreamReader); ok { + return CreateReadWaiter(u.UpstreamReader().(io.Reader)) + } + if u, ok := reader.(common.WithUpstream); ok { + return CreateReadWaiter(u.Upstream().(io.Reader)) + } return nil, false } func CreateVectorisedReadWaiter(reader io.Reader) (N.VectorisedReadWaiter, bool) { - reader = N.UnwrapReader(reader) if vectorisedReadWaiter, isVectorised := reader.(N.VectorisedReadWaiter); isVectorised { return vectorisedReadWaiter, true } @@ -31,11 +39,19 @@ func CreateVectorisedReadWaiter(reader io.Reader) (N.VectorisedReadWaiter, bool) if vectorisedReadWaiter, created := createVectorisedSyscallReadWaiter(reader); created { return vectorisedReadWaiter, true } + if u, ok := reader.(N.ReaderWithUpstream); !ok || !u.ReaderReplaceable() { + return nil, false + } + if u, ok := reader.(N.WithUpstreamReader); ok { + return CreateVectorisedReadWaiter(u.UpstreamReader().(io.Reader)) + } + if u, ok := reader.(common.WithUpstream); ok { + return CreateVectorisedReadWaiter(u.Upstream().(io.Reader)) + } return nil, false } func CreatePacketReadWaiter(reader N.PacketReader) (N.PacketReadWaiter, bool) { - reader = N.UnwrapPacketReader(reader) if readWaiter, isReadWaiter := reader.(N.PacketReadWaiter); isReadWaiter { return readWaiter, true } @@ -45,6 +61,15 @@ func CreatePacketReadWaiter(reader N.PacketReader) (N.PacketReadWaiter, bool) { if readWaiter, created := createSyscallPacketReadWaiter(reader); created { return readWaiter, true } + if u, ok := reader.(N.ReaderWithUpstream); !ok || !u.ReaderReplaceable() { + return nil, false + } + if u, ok := reader.(N.WithUpstreamReader); ok { + return CreatePacketReadWaiter(u.UpstreamReader().(N.PacketReader)) + } + if u, ok := reader.(common.WithUpstream); ok { + return CreatePacketReadWaiter(u.Upstream().(N.PacketReader)) + } return nil, false } diff --git a/common/network/counter.go b/common/network/counter.go index a20c4d9d1..dc15f3204 100644 --- a/common/network/counter.go +++ b/common/network/counter.go @@ -2,6 +2,9 @@ package network import ( "io" + "syscall" + + "github.com/sagernet/sing/common" ) type CountFunc func(n int64) @@ -27,32 +30,65 @@ type PacketWriteCounter interface { } func UnwrapCountReader(reader io.Reader, countFunc []CountFunc) (io.Reader, []CountFunc) { - reader = UnwrapReader(reader) if counter, isCounter := reader.(ReadCounter); isCounter { upstreamReader, upstreamCountFunc := counter.UnwrapReader() countFunc = append(countFunc, upstreamCountFunc...) return UnwrapCountReader(upstreamReader, countFunc) } + if u, ok := reader.(ReaderWithUpstream); !ok || !u.ReaderReplaceable() { + return reader, countFunc + } + switch u := reader.(type) { + case ReadWaiter, ReadWaitCreator, syscall.Conn, SyscallReader: + // In our use cases, counters is always at the top, so we stop when we encounter ReadWaiter + return reader, countFunc + case WithUpstreamReader: + return UnwrapCountReader(u.UpstreamReader().(io.Reader), countFunc) + case common.WithUpstream: + return UnwrapCountReader(u.Upstream().(io.Reader), countFunc) + } return reader, countFunc } func UnwrapCountWriter(writer io.Writer, countFunc []CountFunc) (io.Writer, []CountFunc) { - writer = UnwrapWriter(writer) if counter, isCounter := writer.(WriteCounter); isCounter { upstreamWriter, upstreamCountFunc := counter.UnwrapWriter() countFunc = append(countFunc, upstreamCountFunc...) return UnwrapCountWriter(upstreamWriter, countFunc) } + if u, ok := writer.(WriterWithUpstream); !ok || !u.WriterReplaceable() { + return writer, countFunc + } + switch u := writer.(type) { + case syscall.Conn, SyscallWriter: + // In our use cases, counters is always at the top, so we stop when we encounter syscall conn + return writer, countFunc + case WithUpstreamWriter: + return UnwrapCountWriter(u.UpstreamWriter().(io.Writer), countFunc) + case common.WithUpstream: + return UnwrapCountWriter(u.Upstream().(io.Writer), countFunc) + } return writer, countFunc } func UnwrapCountPacketReader(reader PacketReader, countFunc []CountFunc) (PacketReader, []CountFunc) { - reader = UnwrapPacketReader(reader) if counter, isCounter := reader.(PacketReadCounter); isCounter { upstreamReader, upstreamCountFunc := counter.UnwrapPacketReader() countFunc = append(countFunc, upstreamCountFunc...) return UnwrapCountPacketReader(upstreamReader, countFunc) } + if u, ok := reader.(ReaderWithUpstream); !ok || !u.ReaderReplaceable() { + return reader, countFunc + } + switch u := reader.(type) { + case PacketReadWaiter, PacketReadWaitCreator, syscall.Conn: + // In our use cases, counters is always at the top, so we stop when we encounter ReadWaiter + return reader, countFunc + case WithUpstreamReader: + return UnwrapCountPacketReader(u.UpstreamReader().(PacketReader), countFunc) + case common.WithUpstream: + return UnwrapCountPacketReader(u.Upstream().(PacketReader), countFunc) + } return reader, countFunc } @@ -63,5 +99,17 @@ func UnwrapCountPacketWriter(writer PacketWriter, countFunc []CountFunc) (Packet countFunc = append(countFunc, upstreamCountFunc...) return UnwrapCountPacketWriter(upstreamWriter, countFunc) } + if u, ok := writer.(WriterWithUpstream); !ok || !u.WriterReplaceable() { + return writer, countFunc + } + switch u := writer.(type) { + case syscall.Conn: + // In our use cases, counters is always at the top, so we stop when we encounter syscall conn + return writer, countFunc + case WithUpstreamWriter: + return UnwrapCountPacketWriter(u.UpstreamWriter().(PacketWriter), countFunc) + case common.WithUpstream: + return UnwrapCountPacketWriter(u.Upstream().(PacketWriter), countFunc) + } return writer, countFunc } diff --git a/common/network/direct.go b/common/network/direct.go index e5b5a8324..586560b9b 100644 --- a/common/network/direct.go +++ b/common/network/direct.go @@ -1,6 +1,10 @@ package network import ( + "io" + "syscall" + + "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" M "github.com/sagernet/sing/common/metadata" ) @@ -109,3 +113,90 @@ type VectorisedPacketReadWaiter interface { type VectorisedPacketReadWaitCreator interface { CreateVectorisedPacketReadWaiter() (VectorisedPacketReadWaiter, bool) } + +type SyscallReader interface { + SyscallConnForRead() syscall.RawConn + HandleSyscallReadError(inputErr error) ([]byte, error) +} + +func SyscallAvailableForRead(reader io.Reader) bool { + if _, ok := reader.(syscall.Conn); ok { + return true + } + if _, ok := reader.(SyscallReader); ok { + return true + } + if u, ok := reader.(ReaderWithUpstream); !ok || !u.ReaderReplaceable() { + return false + } + if u, ok := reader.(WithUpstreamReader); ok { + return SyscallAvailableForRead(u.UpstreamReader().(io.Reader)) + } + if u, ok := reader.(common.WithUpstream); ok { + return SyscallAvailableForRead(u.Upstream().(io.Reader)) + } + return false +} + +func SyscallConnForRead(reader io.Reader) (SyscallReader, syscall.RawConn) { + if c, ok := reader.(syscall.Conn); ok { + conn, _ := c.SyscallConn() + return nil, conn + } + if c, ok := reader.(SyscallReader); ok { + return c, c.SyscallConnForRead() + } + if u, ok := reader.(ReaderWithUpstream); !ok || u.ReaderReplaceable() { + return nil, nil + } + if u, ok := reader.(WithUpstreamReader); ok { + return SyscallConnForRead(u.UpstreamReader().(io.Reader)) + } + if u, ok := reader.(common.WithUpstream); ok { + return SyscallConnForRead(u.Upstream().(io.Reader)) + } + return nil, nil +} + +type SyscallWriter interface { + SyscallConnForWrite() syscall.RawConn +} + +func SyscallAvailableForWrite(writer io.Writer) bool { + if _, ok := writer.(syscall.Conn); ok { + return true + } + if _, ok := writer.(SyscallWriter); ok { + return true + } + if u, ok := writer.(WriterWithUpstream); !ok || !u.WriterReplaceable() { + return false + } + if u, ok := writer.(WithUpstreamWriter); ok { + return SyscallAvailableForWrite(u.UpstreamWriter().(io.Writer)) + } + if u, ok := writer.(common.WithUpstream); ok { + return SyscallAvailableForWrite(u.Upstream().(io.Writer)) + } + return false +} + +func SyscallConnForWrite(writer io.Writer) (SyscallWriter, syscall.RawConn) { + if c, ok := writer.(syscall.Conn); ok { + conn, _ := c.SyscallConn() + return nil, conn + } + if c, ok := writer.(SyscallWriter); ok { + return c, c.SyscallConnForWrite() + } + if u, ok := writer.(WriterWithUpstream); !ok || !u.WriterReplaceable() { + return nil, nil + } + if u, ok := writer.(WithUpstreamWriter); ok { + return SyscallConnForWrite(u.UpstreamWriter().(io.Writer)) + } + if u, ok := writer.(common.WithUpstream); ok { + return SyscallConnForWrite(u.Upstream().(io.Writer)) + } + return nil, nil +} diff --git a/common/network/early.go b/common/network/early.go index 365f292ab..4016dcb71 100644 --- a/common/network/early.go +++ b/common/network/early.go @@ -1,5 +1,38 @@ package network +import ( + "io" + + "github.com/sagernet/sing/common" +) + +// Deprecated: use EarlyReader and EarlyWriter instead. type EarlyConn interface { NeedHandshake() bool } + +type EarlyReader interface { + NeedHandshakeForRead() bool +} + +func NeedHandshakeForRead(reader io.Reader) bool { + if earlyReader, isEarlyReader := common.Cast[EarlyReader](reader); isEarlyReader && earlyReader.NeedHandshakeForRead() { + return true + } + return false +} + +type EarlyWriter interface { + NeedHandshakeForWrite() bool +} + +func NeedHandshakeForWrite(writer io.Writer) bool { + if //goland:noinspection GoDeprecation + earlyConn, isEarlyConn := writer.(EarlyConn); isEarlyConn { + return earlyConn.NeedHandshake() + } + if earlyWriter, isEarlyWriter := common.Cast[EarlyWriter](writer); isEarlyWriter && earlyWriter.NeedHandshakeForWrite() { + return true + } + return false +} diff --git a/common/network/name.go b/common/network/name.go index a0284749f..d74509916 100644 --- a/common/network/name.go +++ b/common/network/name.go @@ -10,11 +10,10 @@ var ErrUnknownNetwork = E.New("unknown network") //goland:noinspection GoNameStartsWithPackageName const ( - NetworkIP = "ip" - NetworkTCP = "tcp" - NetworkUDP = "udp" - NetworkICMPv4 = "icmpv4" - NetworkICMPv6 = "icmpv6" + NetworkIP = "ip" + NetworkTCP = "tcp" + NetworkUDP = "udp" + NetworkICMP = "icmp" ) //goland:noinspection GoNameStartsWithPackageName @@ -23,6 +22,8 @@ func NetworkName(network string) string { return NetworkTCP } else if strings.HasPrefix(network, "udp") { return NetworkUDP + } else if strings.HasPrefix(network, "icmp") { + return NetworkICMP } else if strings.HasPrefix(network, "ip") { return NetworkIP } else { diff --git a/common/tls/config.go b/common/tls/config.go index 7f8be0d70..3bb16416f 100644 --- a/common/tls/config.go +++ b/common/tls/config.go @@ -17,7 +17,7 @@ type Config interface { SetServerName(serverName string) NextProtos() []string SetNextProtos(nextProto []string) - Config() (*STDConfig, error) + STDConfig() (*STDConfig, error) Client(conn net.Conn) (Conn, error) Clone() Config } diff --git a/common/winwlanapi/syscall_windows.go b/common/winwlanapi/syscall_windows.go new file mode 100644 index 000000000..848733891 --- /dev/null +++ b/common/winwlanapi/syscall_windows.go @@ -0,0 +1,23 @@ +//go:build windows + +package winwlanapi + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go + +// https://learn.microsoft.com/en-us/windows/win32/api/wlanapi/nf-wlanapi-wlanopenhandle +//sys wlanOpenHandle(clientVersion uint32, reserved uintptr, negotiatedVersion *uint32, clientHandle *windows.Handle) (ret uint32) = wlanapi.WlanOpenHandle + +// https://learn.microsoft.com/en-us/windows/win32/api/wlanapi/nf-wlanapi-wlanclosehandle +//sys wlanCloseHandle(clientHandle windows.Handle, reserved uintptr) (ret uint32) = wlanapi.WlanCloseHandle + +// https://learn.microsoft.com/en-us/windows/win32/api/wlanapi/nf-wlanapi-wlanenuminterfaces +//sys wlanEnumInterfaces(clientHandle windows.Handle, reserved uintptr, interfaceList **InterfaceInfoList) (ret uint32) = wlanapi.WlanEnumInterfaces + +// https://learn.microsoft.com/en-us/windows/win32/api/wlanapi/nf-wlanapi-wlanqueryinterface +//sys wlanQueryInterface(clientHandle windows.Handle, interfaceGuid *windows.GUID, opCode uint32, reserved uintptr, dataSize *uint32, data *uintptr, opcodeValueType *uint32) (ret uint32) = wlanapi.WlanQueryInterface + +// https://learn.microsoft.com/en-us/windows/win32/api/wlanapi/nf-wlanapi-wlanfreememory +//sys wlanFreeMemory(memory uintptr) = wlanapi.WlanFreeMemory + +// https://learn.microsoft.com/en-us/windows/win32/api/wlanapi/nf-wlanapi-wlanregisternotification +//sys wlanRegisterNotification(clientHandle windows.Handle, notificationSource uint32, ignoreDuplicate bool, callback uintptr, callbackContext uintptr, reserved uintptr, prevNotificationSource *uint32) (ret uint32) = wlanapi.WlanRegisterNotification diff --git a/common/winwlanapi/wlanapi.go b/common/winwlanapi/wlanapi.go new file mode 100644 index 000000000..d128dcc0a --- /dev/null +++ b/common/winwlanapi/wlanapi.go @@ -0,0 +1,161 @@ +//go:build windows + +package winwlanapi + +import ( + "os" + "unsafe" + + "golang.org/x/sys/windows" +) + +const ( + ClientVersion2 = 2 + + // InterfaceOpcode for WlanQueryInterface + IntfOpcodeCurrentConnection = 7 + + // NotificationSource for WlanRegisterNotification + NotificationSourceNone = 0 + NotificationSourceACM = 0x00000008 + + // NotificationACM codes + NotificationACMConnectionComplete = 10 + NotificationACMDisconnected = 21 + + // InterfaceState + InterfaceStateNotReady = 0 + InterfaceStateConnected = 1 + InterfaceStateAdHocNetworkFormed = 2 + InterfaceStateDisconnecting = 3 + InterfaceStateDisconnected = 4 + InterfaceStateAssociating = 5 + InterfaceStateDiscovering = 6 + InterfaceStateAuthenticating = 7 + + // DOT11_SSID + Dot11SSIDMaxLength = 32 +) + +type Dot11SSID struct { + Length uint32 + SSID [Dot11SSIDMaxLength]byte +} + +type Dot11MacAddress [6]byte + +type AssociationAttributes struct { + SSID Dot11SSID + BSSType uint32 + BSSID Dot11MacAddress + _ [2]byte // padding for 4-byte alignment + PhyType uint32 + PhyIndex uint32 + SignalQuality uint32 + RxRate uint32 + TxRate uint32 +} + +type SecurityAttributes struct { + SecurityEnabled int32 // Windows BOOL is 4 bytes + OneXEnabled int32 + AuthAlgorithm uint32 + CipherAlgorithm uint32 +} + +type ConnectionAttributes struct { + InterfaceState uint32 + ConnectionMode uint32 + ProfileName [256]uint16 + AssociationAttributes AssociationAttributes + SecurityAttributes SecurityAttributes +} + +type InterfaceInfo struct { + InterfaceGUID windows.GUID + InterfaceDescription [256]uint16 + InterfaceState uint32 +} + +type InterfaceInfoList struct { + NumberOfItems uint32 + Index uint32 + InterfaceInfo [1]InterfaceInfo +} + +type NotificationData struct { + NotificationSource uint32 + NotificationCode uint32 + InterfaceGUID windows.GUID + DataSize uint32 + Data uintptr +} + +// NotificationCallback is the type for notification callback functions. +// Use syscall.NewCallback to create a callback from a Go function. +type NotificationCallback func(data *NotificationData, context uintptr) uintptr + +func OpenHandle() (windows.Handle, error) { + var negotiatedVersion uint32 + var handle windows.Handle + ret := wlanOpenHandle(ClientVersion2, 0, &negotiatedVersion, &handle) + if ret != 0 { + return 0, os.NewSyscallError("WlanOpenHandle", windows.Errno(ret)) + } + return handle, nil +} + +func CloseHandle(handle windows.Handle) error { + ret := wlanCloseHandle(handle, 0) + if ret != 0 { + return os.NewSyscallError("WlanCloseHandle", windows.Errno(ret)) + } + return nil +} + +func EnumInterfaces(handle windows.Handle) ([]InterfaceInfo, error) { + var list *InterfaceInfoList + ret := wlanEnumInterfaces(handle, 0, &list) + if ret != 0 { + return nil, os.NewSyscallError("WlanEnumInterfaces", windows.Errno(ret)) + } + defer wlanFreeMemory(uintptr(unsafe.Pointer(list))) + + if list.NumberOfItems == 0 { + return nil, nil + } + + interfaces := unsafe.Slice(&list.InterfaceInfo[0], list.NumberOfItems) + result := make([]InterfaceInfo, list.NumberOfItems) + copy(result, interfaces) + return result, nil +} + +func QueryCurrentConnection(handle windows.Handle, interfaceGUID *windows.GUID) (*ConnectionAttributes, error) { + var dataSize uint32 + var data uintptr + var opcodeValueType uint32 + + ret := wlanQueryInterface(handle, interfaceGUID, IntfOpcodeCurrentConnection, 0, &dataSize, &data, &opcodeValueType) + if ret != 0 { + return nil, os.NewSyscallError("WlanQueryInterface", windows.Errno(ret)) + } + defer wlanFreeMemory(data) + + attrs := (*ConnectionAttributes)(unsafe.Pointer(data)) + result := *attrs + return &result, nil +} + +func RegisterNotification(handle windows.Handle, notificationSource uint32, callback uintptr, context uintptr) error { + var prevSource uint32 + ret := wlanRegisterNotification(handle, notificationSource, false, callback, context, 0, &prevSource) + if ret != 0 { + return os.NewSyscallError("WlanRegisterNotification", windows.Errno(ret)) + } + return nil +} + +func UnregisterNotification(handle windows.Handle) error { + return RegisterNotification(handle, NotificationSourceNone, 0, 0) +} diff --git a/common/winwlanapi/wlanapi_test.go b/common/winwlanapi/wlanapi_test.go new file mode 100644 index 000000000..acae53893 --- /dev/null +++ b/common/winwlanapi/wlanapi_test.go @@ -0,0 +1,119 @@ +//go:build windows + +package winwlanapi + +import ( + "testing" +) + +func TestOpenHandle(t *testing.T) { + handle, err := OpenHandle() + if err != nil { + t.Skipf("WLAN service not available: %v", err) + } + defer CloseHandle(handle) + + if handle == 0 { + t.Error("expected non-zero handle") + } +} + +func TestEnumInterfaces(t *testing.T) { + handle, err := OpenHandle() + if err != nil { + t.Skipf("WLAN service not available: %v", err) + } + defer CloseHandle(handle) + + interfaces, err := EnumInterfaces(handle) + if err != nil { + t.Fatalf("EnumInterfaces failed: %v", err) + } + + t.Logf("Found %d WLAN interface(s)", len(interfaces)) + for i, iface := range interfaces { + description := windowsUTF16ToString(iface.InterfaceDescription[:]) + t.Logf("Interface %d: %s (state=%d)", i, description, iface.InterfaceState) + } +} + +func TestQueryCurrentConnection(t *testing.T) { + handle, err := OpenHandle() + if err != nil { + t.Skipf("WLAN service not available: %v", err) + } + defer CloseHandle(handle) + + interfaces, err := EnumInterfaces(handle) + if err != nil { + t.Fatalf("EnumInterfaces failed: %v", err) + } + + if len(interfaces) == 0 { + t.Skip("no WLAN interfaces available") + } + + for _, iface := range interfaces { + if iface.InterfaceState != InterfaceStateConnected { + continue + } + + guid := iface.InterfaceGUID + attrs, err := QueryCurrentConnection(handle, &guid) + if err != nil { + t.Errorf("QueryCurrentConnection failed: %v", err) + continue + } + + ssidLen := attrs.AssociationAttributes.SSID.Length + if ssidLen > 0 && ssidLen <= Dot11SSIDMaxLength { + ssid := string(attrs.AssociationAttributes.SSID.SSID[:ssidLen]) + bssid := attrs.AssociationAttributes.BSSID + t.Logf("Connected to SSID: %q, BSSID: %02X:%02X:%02X:%02X:%02X:%02X", + ssid, bssid[0], bssid[1], bssid[2], bssid[3], bssid[4], bssid[5]) + } + return + } + + t.Log("no connected WLAN interface found") +} + +func TestCloseHandle(t *testing.T) { + handle, err := OpenHandle() + if err != nil { + t.Skipf("WLAN service not available: %v", err) + } + + err = CloseHandle(handle) + if err != nil { + t.Errorf("CloseHandle failed: %v", err) + } + + // closing again should fail + err = CloseHandle(handle) + if err == nil { + t.Error("expected error when closing already closed handle") + } +} + +func windowsUTF16ToString(s []uint16) string { + for i, c := range s { + if c == 0 { + return string(utf16ToRunes(s[:i])) + } + } + return string(utf16ToRunes(s)) +} + +func utf16ToRunes(s []uint16) []rune { + runes := make([]rune, 0, len(s)) + for i := 0; i < len(s); i++ { + if s[i] >= 0xD800 && s[i] <= 0xDBFF && i+1 < len(s) && s[i+1] >= 0xDC00 && s[i+1] <= 0xDFFF { + runes = append(runes, rune((int(s[i])-0xD800)<<10+(int(s[i+1])-0xDC00)+0x10000)) + i++ + } else { + runes = append(runes, rune(s[i])) + } + } + return runes +} diff --git a/common/winwlanapi/zsyscall_windows.go b/common/winwlanapi/zsyscall_windows.go new file mode 100644 index 000000000..18360c8a4 --- /dev/null +++ b/common/winwlanapi/zsyscall_windows.go @@ -0,0 +1,88 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package winwlanapi + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modwlanapi = windows.NewLazySystemDLL("wlanapi.dll") + + procWlanCloseHandle = modwlanapi.NewProc("WlanCloseHandle") + procWlanEnumInterfaces = modwlanapi.NewProc("WlanEnumInterfaces") + procWlanFreeMemory = modwlanapi.NewProc("WlanFreeMemory") + procWlanOpenHandle = modwlanapi.NewProc("WlanOpenHandle") + procWlanQueryInterface = modwlanapi.NewProc("WlanQueryInterface") + procWlanRegisterNotification = modwlanapi.NewProc("WlanRegisterNotification") +) + +func wlanCloseHandle(clientHandle windows.Handle, reserved uintptr) (ret uint32) { + r0, _, _ := syscall.Syscall(procWlanCloseHandle.Addr(), 2, uintptr(clientHandle), uintptr(reserved), 0) + ret = uint32(r0) + return +} + +func wlanEnumInterfaces(clientHandle windows.Handle, reserved uintptr, interfaceList **InterfaceInfoList) (ret uint32) { + r0, _, _ := syscall.Syscall(procWlanEnumInterfaces.Addr(), 3, uintptr(clientHandle), uintptr(reserved), uintptr(unsafe.Pointer(interfaceList))) + ret = uint32(r0) + return +} + +func wlanFreeMemory(memory uintptr) { + syscall.Syscall(procWlanFreeMemory.Addr(), 1, uintptr(memory), 0, 0) + return +} + +func wlanOpenHandle(clientVersion uint32, reserved uintptr, negotiatedVersion *uint32, clientHandle *windows.Handle) (ret uint32) { + r0, _, _ := syscall.Syscall6(procWlanOpenHandle.Addr(), 4, uintptr(clientVersion), uintptr(reserved), uintptr(unsafe.Pointer(negotiatedVersion)), uintptr(unsafe.Pointer(clientHandle)), 0, 0) + ret = uint32(r0) + return +} + +func wlanQueryInterface(clientHandle windows.Handle, interfaceGuid *windows.GUID, opCode uint32, reserved uintptr, dataSize *uint32, data *uintptr, opcodeValueType *uint32) (ret uint32) { + r0, _, _ := syscall.Syscall9(procWlanQueryInterface.Addr(), 7, uintptr(clientHandle), uintptr(unsafe.Pointer(interfaceGuid)), uintptr(opCode), uintptr(reserved), uintptr(unsafe.Pointer(dataSize)), uintptr(unsafe.Pointer(data)), uintptr(unsafe.Pointer(opcodeValueType)), 0, 0) + ret = uint32(r0) + return +} + +func wlanRegisterNotification(clientHandle windows.Handle, notificationSource uint32, ignoreDuplicate bool, callback uintptr, callbackContext uintptr, reserved uintptr, prevNotificationSource *uint32) (ret uint32) { + var _p0 uint32 + if ignoreDuplicate { + _p0 = 1 + } + r0, _, _ := syscall.Syscall9(procWlanRegisterNotification.Addr(), 7, uintptr(clientHandle), uintptr(notificationSource), uintptr(_p0), uintptr(callback), uintptr(callbackContext), uintptr(reserved), uintptr(unsafe.Pointer(prevNotificationSource)), 0, 0) + ret = uint32(r0) + return +} diff --git a/service/filemanager/default.go b/service/filemanager/default.go index 372bc0162..717bd0636 100644 --- a/service/filemanager/default.go +++ b/service/filemanager/default.go @@ -4,7 +4,6 @@ import ( "context" "os" "path/filepath" - "strings" "syscall" "github.com/sagernet/sing/common/rw" @@ -36,7 +35,7 @@ func WithDefault(ctx context.Context, basePath string, tempPath string, userID i } func (m *defaultManager) BasePath(name string) string { - if m.basePath == "" || strings.HasPrefix(name, "/") { + if m.basePath == "" || filepath.IsAbs(name) { return name } return filepath.Join(m.basePath, name)