diff --git a/common/bufio/copy.go b/common/bufio/copy.go index 5c99af50f..192e34991 100644 --- a/common/bufio/copy.go +++ b/common/bufio/copy.go @@ -5,7 +5,6 @@ import ( "errors" "io" "net" - "syscall" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -59,14 +58,10 @@ func CopyWithIncreateBuffer(destination io.Writer, source io.Reader, increaseBuf } func CopyWithCounters(destination io.Writer, source io.Reader, originSource io.Reader, readCounters []N.CountFunc, writeCounters []N.CountFunc, increaseBufferAfter int64, batchSize int) (n int64, err error) { - srcSyscallConn, srcIsSyscall := source.(syscall.Conn) - dstSyscallConn, dstIsSyscall := destination.(syscall.Conn) - if srcIsSyscall && dstIsSyscall { - var handled bool - handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters) - if handled { - return - } + var handled bool + handled, n, err = copyDirect(source, destination, readCounters, writeCounters) + if handled { + return } return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters, increaseBufferAfter, batchSize) } diff --git a/common/bufio/copy_direct.go b/common/bufio/copy_direct.go index bb4d7834d..7cd834122 100644 --- a/common/bufio/copy_direct.go +++ b/common/bufio/copy_direct.go @@ -3,23 +3,22 @@ package bufio import ( "errors" "io" - "syscall" "github.com/sagernet/sing/common/buf" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) -func copyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) { - rawSource, err := source.SyscallConn() - if err != nil { +func copyDirect(source io.Reader, destination io.Writer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) { + if !N.SyscallAvailableForRead(source) || !N.SyscallAvailableForWrite(destination) { return } - rawDestination, err := destination.SyscallConn() - if err != nil { + sourceReader, sourceConn := N.SyscallConnForRead(source) + destinationWriter, destinationConn := N.SyscallConnForWrite(destination) + if sourceConn == nil || destinationConn == nil { return } - handed, n, err = splice(rawSource, rawDestination, readCounters, writeCounters) + handed, n, err = splice(sourceConn, sourceReader, destinationConn, destinationWriter, readCounters, writeCounters) return } diff --git a/common/bufio/splice_linux.go b/common/bufio/splice_linux.go index 1b08898ad..501543408 100644 --- a/common/bufio/splice_linux.go +++ b/common/bufio/splice_linux.go @@ -11,7 +11,7 @@ import ( const maxSpliceSize = 1 << 20 -func splice(source syscall.RawConn, destination syscall.RawConn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) { +func splice(source syscall.RawConn, sourceReader N.SyscallReader, destination syscall.RawConn, destinationWriter N.SyscallWriter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) { handed = true var pipeFDs [2]int err = unix.Pipe2(pipeFDs[:], syscall.O_CLOEXEC|syscall.O_NONBLOCK) @@ -20,12 +20,14 @@ func splice(source syscall.RawConn, destination syscall.RawConn, readCounters [] } defer unix.Close(pipeFDs[0]) defer unix.Close(pipeFDs[1]) - _, _ = unix.FcntlInt(uintptr(pipeFDs[0]), unix.F_SETPIPE_SZ, maxSpliceSize) - var readN int - var readErr error - var writeSize int - var writeErr error + var ( + readN int + readErr error + writeSize int + writeErr error + notFirstTime bool + ) readFunc := func(fd uintptr) (done bool) { p0, p1 := unix.Splice(int(fd), nil, pipeFDs[1], nil, maxSpliceSize, unix.SPLICE_F_NONBLOCK) readN = int(p0) @@ -46,15 +48,28 @@ func splice(source syscall.RawConn, destination syscall.RawConn, readCounters [] } for { err = source.Read(readFunc) - if err != nil { - readErr = err - } if readErr != nil { - if readErr == unix.EINVAL || readErr == unix.ENOSYS { + err = readErr + } + if err != nil { + if sourceReader != nil { + newBuffer, newErr := sourceReader.HandleSyscallReadError(err) + if newErr != nil { + err = newErr + } else { + err = nil + if len(newBuffer) > 0 { + readN, readErr = unix.Write(pipeFDs[1], newBuffer) + if readErr != nil { + err = E.Cause(err, "write handled data") + } + } + } + } else if !notFirstTime && E.IsMulti(err, unix.EINVAL, unix.ENOSYS) { handed = false return } - err = E.Cause(readErr, "splice read") + err = E.Cause(err, "splice read") return } if readN == 0 { @@ -62,18 +77,20 @@ func splice(source syscall.RawConn, destination syscall.RawConn, readCounters [] } writeSize = readN err = destination.Write(writeFunc) - if err != nil { - writeErr = err - } if writeErr != nil { - err = E.Cause(writeErr, "splice write") + err = writeErr + } + if err != nil { + err = E.Cause(err, "splice write") return } + n += int64(readN) for _, readCounter := range readCounters { readCounter(int64(readN)) } for _, writeCounter := range writeCounters { writeCounter(int64(readN)) } + notFirstTime = true } } diff --git a/common/bufio/splice_stub.go b/common/bufio/splice_stub.go index 44c93b5ab..b120284b8 100644 --- a/common/bufio/splice_stub.go +++ b/common/bufio/splice_stub.go @@ -8,6 +8,6 @@ import ( N "github.com/sagernet/sing/common/network" ) -func splice(source syscall.RawConn, destination syscall.RawConn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) { +func splice(source syscall.RawConn, sourceReader N.SyscallReader, destination syscall.RawConn, destinationWriter N.SyscallWriter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) { return } diff --git a/common/bufio/wait.go b/common/bufio/wait.go index c32fa3039..6ac0821c7 100644 --- a/common/bufio/wait.go +++ b/common/bufio/wait.go @@ -3,11 +3,11 @@ package bufio import ( "io" + "github.com/sagernet/sing/common" N "github.com/sagernet/sing/common/network" ) func CreateReadWaiter(reader io.Reader) (N.ReadWaiter, bool) { - reader = N.UnwrapReader(reader) if readWaiter, isReadWaiter := reader.(N.ReadWaiter); isReadWaiter { return readWaiter, true } @@ -17,11 +17,19 @@ func CreateReadWaiter(reader io.Reader) (N.ReadWaiter, bool) { if readWaiter, created := createSyscallReadWaiter(reader); created { return readWaiter, true } + if u, ok := reader.(N.ReaderWithUpstream); !ok || !u.ReaderReplaceable() { + return nil, false + } + if u, ok := reader.(N.WithUpstreamReader); ok { + return CreateReadWaiter(u.UpstreamReader().(io.Reader)) + } + if u, ok := reader.(common.WithUpstream); ok { + return CreateReadWaiter(u.Upstream().(io.Reader)) + } return nil, false } func CreateVectorisedReadWaiter(reader io.Reader) (N.VectorisedReadWaiter, bool) { - reader = N.UnwrapReader(reader) if vectorisedReadWaiter, isVectorised := reader.(N.VectorisedReadWaiter); isVectorised { return vectorisedReadWaiter, true } @@ -31,11 +39,19 @@ func CreateVectorisedReadWaiter(reader io.Reader) (N.VectorisedReadWaiter, bool) if vectorisedReadWaiter, created := createVectorisedSyscallReadWaiter(reader); created { return vectorisedReadWaiter, true } + if u, ok := reader.(N.ReaderWithUpstream); !ok || !u.ReaderReplaceable() { + return nil, false + } + if u, ok := reader.(N.WithUpstreamReader); ok { + return CreateVectorisedReadWaiter(u.UpstreamReader().(io.Reader)) + } + if u, ok := reader.(common.WithUpstream); ok { + return CreateVectorisedReadWaiter(u.Upstream().(io.Reader)) + } return nil, false } func CreatePacketReadWaiter(reader N.PacketReader) (N.PacketReadWaiter, bool) { - reader = N.UnwrapPacketReader(reader) if readWaiter, isReadWaiter := reader.(N.PacketReadWaiter); isReadWaiter { return readWaiter, true } @@ -45,6 +61,15 @@ func CreatePacketReadWaiter(reader N.PacketReader) (N.PacketReadWaiter, bool) { if readWaiter, created := createSyscallPacketReadWaiter(reader); created { return readWaiter, true } + if u, ok := reader.(N.ReaderWithUpstream); !ok || !u.ReaderReplaceable() { + return nil, false + } + if u, ok := reader.(N.WithUpstreamReader); ok { + return CreatePacketReadWaiter(u.UpstreamReader().(N.PacketReader)) + } + if u, ok := reader.(common.WithUpstream); ok { + return CreatePacketReadWaiter(u.Upstream().(N.PacketReader)) + } return nil, false } diff --git a/common/metadata/domain.go b/common/metadata/domain.go index 30ce41157..e4f7b193e 100644 --- a/common/metadata/domain.go +++ b/common/metadata/domain.go @@ -1,6 +1,70 @@ package metadata -import _ "unsafe" // for linkname +// IsDomainName checks if a string is a presentation-format domain name +// (currently restricted to hostname-compatible "preferred name" LDH labels and +// SRV-like "underscore labels"; see golang.org/issue/12421). +// +// This function was originally created here: +// +// https://cs.opensource.google/go/go/+/master:src/net/dnsclient.go;l=76-146;drc=05cbbf985fed823a174bf95cc78a7d44f948fdab +// +// and it's being copy-pasted in order to use the same functionality. In the original package, +// this is a private function that cannot be accessed externally +func IsDomainName(s string) bool { + // The root domain name is valid. See golang.org/issue/45715. + if s == "." { + return true + } -//go:linkname IsDomainName net.isDomainName -func IsDomainName(domain string) bool + // See RFC 1035, RFC 3696. + // Presentation format has dots before every label except the first, and the + // terminal empty label is optional here because we assume fully-qualified + // (absolute) input. We must therefore reserve space for the first and last + // labels' length octets in wire format, where they are necessary and the + // maximum total length is 255. + // So our _effective_ maximum is 253, but 254 is not rejected if the last + // character is a dot. + l := len(s) + if l == 0 || l > 254 || l == 254 && s[l-1] != '.' { + return false + } + + last := byte('.') + nonNumeric := false // true once we've seen a letter or hyphen + partlen := 0 + for i := 0; i < len(s); i++ { + c := s[i] + switch { + default: + return false + case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_': + nonNumeric = true + partlen++ + case '0' <= c && c <= '9': + // fine + partlen++ + case c == '-': + // Byte before dash cannot be dot. + if last == '.' { + return false + } + partlen++ + nonNumeric = true + case c == '.': + // Byte before dot cannot be dot, dash. + if last == '.' || last == '-' { + return false + } + if partlen > 63 || partlen == 0 { + return false + } + partlen = 0 + } + last = c + } + if last == '-' || partlen > 63 { + return false + } + + return nonNumeric +} diff --git a/common/network/counter.go b/common/network/counter.go index a20c4d9d1..dc15f3204 100644 --- a/common/network/counter.go +++ b/common/network/counter.go @@ -2,6 +2,9 @@ package network import ( "io" + "syscall" + + "github.com/sagernet/sing/common" ) type CountFunc func(n int64) @@ -27,32 +30,65 @@ type PacketWriteCounter interface { } func UnwrapCountReader(reader io.Reader, countFunc []CountFunc) (io.Reader, []CountFunc) { - reader = UnwrapReader(reader) if counter, isCounter := reader.(ReadCounter); isCounter { upstreamReader, upstreamCountFunc := counter.UnwrapReader() countFunc = append(countFunc, upstreamCountFunc...) return UnwrapCountReader(upstreamReader, countFunc) } + if u, ok := reader.(ReaderWithUpstream); !ok || !u.ReaderReplaceable() { + return reader, countFunc + } + switch u := reader.(type) { + case ReadWaiter, ReadWaitCreator, syscall.Conn, SyscallReader: + // In our use cases, counters is always at the top, so we stop when we encounter ReadWaiter + return reader, countFunc + case WithUpstreamReader: + return UnwrapCountReader(u.UpstreamReader().(io.Reader), countFunc) + case common.WithUpstream: + return UnwrapCountReader(u.Upstream().(io.Reader), countFunc) + } return reader, countFunc } func UnwrapCountWriter(writer io.Writer, countFunc []CountFunc) (io.Writer, []CountFunc) { - writer = UnwrapWriter(writer) if counter, isCounter := writer.(WriteCounter); isCounter { upstreamWriter, upstreamCountFunc := counter.UnwrapWriter() countFunc = append(countFunc, upstreamCountFunc...) return UnwrapCountWriter(upstreamWriter, countFunc) } + if u, ok := writer.(WriterWithUpstream); !ok || !u.WriterReplaceable() { + return writer, countFunc + } + switch u := writer.(type) { + case syscall.Conn, SyscallWriter: + // In our use cases, counters is always at the top, so we stop when we encounter syscall conn + return writer, countFunc + case WithUpstreamWriter: + return UnwrapCountWriter(u.UpstreamWriter().(io.Writer), countFunc) + case common.WithUpstream: + return UnwrapCountWriter(u.Upstream().(io.Writer), countFunc) + } return writer, countFunc } func UnwrapCountPacketReader(reader PacketReader, countFunc []CountFunc) (PacketReader, []CountFunc) { - reader = UnwrapPacketReader(reader) if counter, isCounter := reader.(PacketReadCounter); isCounter { upstreamReader, upstreamCountFunc := counter.UnwrapPacketReader() countFunc = append(countFunc, upstreamCountFunc...) return UnwrapCountPacketReader(upstreamReader, countFunc) } + if u, ok := reader.(ReaderWithUpstream); !ok || !u.ReaderReplaceable() { + return reader, countFunc + } + switch u := reader.(type) { + case PacketReadWaiter, PacketReadWaitCreator, syscall.Conn: + // In our use cases, counters is always at the top, so we stop when we encounter ReadWaiter + return reader, countFunc + case WithUpstreamReader: + return UnwrapCountPacketReader(u.UpstreamReader().(PacketReader), countFunc) + case common.WithUpstream: + return UnwrapCountPacketReader(u.Upstream().(PacketReader), countFunc) + } return reader, countFunc } @@ -63,5 +99,17 @@ func UnwrapCountPacketWriter(writer PacketWriter, countFunc []CountFunc) (Packet countFunc = append(countFunc, upstreamCountFunc...) return UnwrapCountPacketWriter(upstreamWriter, countFunc) } + if u, ok := writer.(WriterWithUpstream); !ok || !u.WriterReplaceable() { + return writer, countFunc + } + switch u := writer.(type) { + case syscall.Conn: + // In our use cases, counters is always at the top, so we stop when we encounter syscall conn + return writer, countFunc + case WithUpstreamWriter: + return UnwrapCountPacketWriter(u.UpstreamWriter().(PacketWriter), countFunc) + case common.WithUpstream: + return UnwrapCountPacketWriter(u.Upstream().(PacketWriter), countFunc) + } return writer, countFunc } diff --git a/common/network/direct.go b/common/network/direct.go index e5b5a8324..586560b9b 100644 --- a/common/network/direct.go +++ b/common/network/direct.go @@ -1,6 +1,10 @@ package network import ( + "io" + "syscall" + + "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" M "github.com/sagernet/sing/common/metadata" ) @@ -109,3 +113,90 @@ type VectorisedPacketReadWaiter interface { type VectorisedPacketReadWaitCreator interface { CreateVectorisedPacketReadWaiter() (VectorisedPacketReadWaiter, bool) } + +type SyscallReader interface { + SyscallConnForRead() syscall.RawConn + HandleSyscallReadError(inputErr error) ([]byte, error) +} + +func SyscallAvailableForRead(reader io.Reader) bool { + if _, ok := reader.(syscall.Conn); ok { + return true + } + if _, ok := reader.(SyscallReader); ok { + return true + } + if u, ok := reader.(ReaderWithUpstream); !ok || !u.ReaderReplaceable() { + return false + } + if u, ok := reader.(WithUpstreamReader); ok { + return SyscallAvailableForRead(u.UpstreamReader().(io.Reader)) + } + if u, ok := reader.(common.WithUpstream); ok { + return SyscallAvailableForRead(u.Upstream().(io.Reader)) + } + return false +} + +func SyscallConnForRead(reader io.Reader) (SyscallReader, syscall.RawConn) { + if c, ok := reader.(syscall.Conn); ok { + conn, _ := c.SyscallConn() + return nil, conn + } + if c, ok := reader.(SyscallReader); ok { + return c, c.SyscallConnForRead() + } + if u, ok := reader.(ReaderWithUpstream); !ok || u.ReaderReplaceable() { + return nil, nil + } + if u, ok := reader.(WithUpstreamReader); ok { + return SyscallConnForRead(u.UpstreamReader().(io.Reader)) + } + if u, ok := reader.(common.WithUpstream); ok { + return SyscallConnForRead(u.Upstream().(io.Reader)) + } + return nil, nil +} + +type SyscallWriter interface { + SyscallConnForWrite() syscall.RawConn +} + +func SyscallAvailableForWrite(writer io.Writer) bool { + if _, ok := writer.(syscall.Conn); ok { + return true + } + if _, ok := writer.(SyscallWriter); ok { + return true + } + if u, ok := writer.(WriterWithUpstream); !ok || !u.WriterReplaceable() { + return false + } + if u, ok := writer.(WithUpstreamWriter); ok { + return SyscallAvailableForWrite(u.UpstreamWriter().(io.Writer)) + } + if u, ok := writer.(common.WithUpstream); ok { + return SyscallAvailableForWrite(u.Upstream().(io.Writer)) + } + return false +} + +func SyscallConnForWrite(writer io.Writer) (SyscallWriter, syscall.RawConn) { + if c, ok := writer.(syscall.Conn); ok { + conn, _ := c.SyscallConn() + return nil, conn + } + if c, ok := writer.(SyscallWriter); ok { + return c, c.SyscallConnForWrite() + } + if u, ok := writer.(WriterWithUpstream); !ok || !u.WriterReplaceable() { + return nil, nil + } + if u, ok := writer.(WithUpstreamWriter); ok { + return SyscallConnForWrite(u.UpstreamWriter().(io.Writer)) + } + if u, ok := writer.(common.WithUpstream); ok { + return SyscallConnForWrite(u.Upstream().(io.Writer)) + } + return nil, nil +} diff --git a/common/network/early.go b/common/network/early.go index 365f292ab..4016dcb71 100644 --- a/common/network/early.go +++ b/common/network/early.go @@ -1,5 +1,38 @@ package network +import ( + "io" + + "github.com/sagernet/sing/common" +) + +// Deprecated: use EarlyReader and EarlyWriter instead. type EarlyConn interface { NeedHandshake() bool } + +type EarlyReader interface { + NeedHandshakeForRead() bool +} + +func NeedHandshakeForRead(reader io.Reader) bool { + if earlyReader, isEarlyReader := common.Cast[EarlyReader](reader); isEarlyReader && earlyReader.NeedHandshakeForRead() { + return true + } + return false +} + +type EarlyWriter interface { + NeedHandshakeForWrite() bool +} + +func NeedHandshakeForWrite(writer io.Writer) bool { + if //goland:noinspection GoDeprecation + earlyConn, isEarlyConn := writer.(EarlyConn); isEarlyConn { + return earlyConn.NeedHandshake() + } + if earlyWriter, isEarlyWriter := common.Cast[EarlyWriter](writer); isEarlyWriter && earlyWriter.NeedHandshakeForWrite() { + return true + } + return false +} diff --git a/common/network/name.go b/common/network/name.go index a0284749f..d74509916 100644 --- a/common/network/name.go +++ b/common/network/name.go @@ -10,11 +10,10 @@ var ErrUnknownNetwork = E.New("unknown network") //goland:noinspection GoNameStartsWithPackageName const ( - NetworkIP = "ip" - NetworkTCP = "tcp" - NetworkUDP = "udp" - NetworkICMPv4 = "icmpv4" - NetworkICMPv6 = "icmpv6" + NetworkIP = "ip" + NetworkTCP = "tcp" + NetworkUDP = "udp" + NetworkICMP = "icmp" ) //goland:noinspection GoNameStartsWithPackageName @@ -23,6 +22,8 @@ func NetworkName(network string) string { return NetworkTCP } else if strings.HasPrefix(network, "udp") { return NetworkUDP + } else if strings.HasPrefix(network, "icmp") { + return NetworkICMP } else if strings.HasPrefix(network, "ip") { return NetworkIP } else { diff --git a/common/tls/config.go b/common/tls/config.go index 7f8be0d70..3bb16416f 100644 --- a/common/tls/config.go +++ b/common/tls/config.go @@ -17,7 +17,7 @@ type Config interface { SetServerName(serverName string) NextProtos() []string SetNextProtos(nextProto []string) - Config() (*STDConfig, error) + STDConfig() (*STDConfig, error) Client(conn net.Conn) (Conn, error) Clone() Config }