From 2bb4425f8872f47dd997edab283e497eed75081a Mon Sep 17 00:00:00 2001 From: Denis Tingaikin Date: Wed, 22 May 2024 17:10:41 +0300 Subject: [PATCH 1/4] Protect grpcfd from edge cases Signed-off-by: Denis Tingaikin --- connwrap.go | 32 +++++++++++++---------- connwrap_linux.go | 14 +++------- connwrap_notlinux.go | 14 +++------- per_rpc_transport_credentials.go | 16 +++++++----- per_rpc_transport_credentials_linux.go | 3 ++- per_rpc_transport_credentials_notlinux.go | 3 ++- 6 files changed, 39 insertions(+), 43 deletions(-) diff --git a/connwrap.go b/connwrap.go index 9a0bdf4..24b3792 100644 --- a/connwrap.go +++ b/connwrap.go @@ -14,6 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build !windows // +build !windows package grpcfd @@ -170,7 +171,7 @@ func (w *connWrap) Write(b []byte) (int, error) { } func (w *connWrap) SendFD(fd uintptr) <-chan error { - errCh := make(chan error, 1) + errCh := make(chan error, 10) // Dup the fd because we have no way of knowing what the caller will do with it between // now and when we can send it fd, _, err := syscall.Syscall(syscall.SYS_FCNTL, fd, uintptr(syscall.F_DUPFD), 0) @@ -187,7 +188,7 @@ func (w *connWrap) SendFD(fd uintptr) <-chan error { } func (w *connWrap) SendFile(file SyscallConn) <-chan error { - errCh := make(chan error, 1) + errCh := make(chan error, 10) raw, err := file.SyscallConn() if err != nil { errCh <- errors.Wrapf(err, "unable to retrieve syscall.RawConn for src %+v", file) @@ -202,16 +203,17 @@ func (w *connWrap) SendFile(file SyscallConn) <-chan error { close(errCh) return } - go func(errChIn <-chan error, errChOut chan<- error) { - for err := range errChIn { - errChOut <- err - } - close(errChOut) - }(w.SendFD(fd), errCh) + go joinErrChs(w.SendFD(fd), errCh) }) + if err != nil { - errCh <- err - close(errCh) + // Return a separate channel to not conflict with goroutine from the raw.Control + // As an alternative, mutex can be used, but it can affect performance. + // + // In some cases, errCh can't be closed, but it's fine. https://groups.google.com/g/golang-nuts/c/pZwdYRGxCIk/m/qpbHxRRPJdUJ + var resCh = make(chan error, 1) + resCh <- err + return resCh } return errCh } @@ -229,7 +231,7 @@ func (w *connWrap) String() string { } func (w *connWrap) RecvFD(dev, ino uint64) <-chan uintptr { - fdCh := make(chan uintptr, 1) + fdCh := make(chan uintptr, 10) w.recvExecutor.AsyncExec(func() { key := inodeKey{ dev: dev, @@ -266,7 +268,7 @@ func (w *connWrap) RecvFDByURL(urlStr string) (<-chan uintptr, error) { } func (w *connWrap) RecvFile(dev, ino uint64) <-chan *os.File { - fileCh := make(chan *os.File, 1) + fileCh := make(chan *os.File, 10) go func(fdCh <-chan uintptr, fileCh chan<- *os.File) { for fd := range fdCh { if runtime.GOOS == "linux" { @@ -352,14 +354,16 @@ func (w *connWrap) Read(b []byte) (n int, err error) { } // FromPeer - return grpcfd.FDTransceiver from peer.Peer -// ok is true of successful, false otherwise +// +// ok is true of successful, false otherwise func FromPeer(p *peer.Peer) (transceiver FDTransceiver, ok bool) { transceiver, ok = p.Addr.(FDTransceiver) return transceiver, ok } // FromContext - return grpcfd.FDTransceiver from context.Context -// ok is true of successful, false otherwise +// +// ok is true of successful, false otherwise func FromContext(ctx context.Context) (transceiver FDTransceiver, ok bool) { p, ok := peer.FromContext(ctx) if !ok { diff --git a/connwrap_linux.go b/connwrap_linux.go index 5a1f7a8..bb61711 100644 --- a/connwrap_linux.go +++ b/connwrap_linux.go @@ -14,6 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build linux // +build linux package grpcfd @@ -25,7 +26,7 @@ import ( ) func (w *connWrap) SendFilename(filename string) <-chan error { - errCh := make(chan error, 1) + errCh := make(chan error, 10) file, err := os.OpenFile(filename, unix.O_PATH, 0) // #nosec if err != nil { errCh <- err @@ -33,15 +34,8 @@ func (w *connWrap) SendFilename(filename string) <-chan error { return errCh } go func(errChIn <-chan error, errChOut chan<- error) { - for err := range errChIn { - errChOut <- err - } - err := file.Close() - if err != nil { - errChOut <- err - } - close(errChOut) + joinErrChs(errChIn, errChOut) + _ = file.Close() }(w.SendFile(file), errCh) - _ = file.Close() return errCh } diff --git a/connwrap_notlinux.go b/connwrap_notlinux.go index 613ae1b..f45ac84 100644 --- a/connwrap_notlinux.go +++ b/connwrap_notlinux.go @@ -14,6 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build !linux && !windows // +build !linux,!windows package grpcfd @@ -23,7 +24,7 @@ import ( ) func (w *connWrap) SendFilename(filename string) <-chan error { - errCh := make(chan error, 1) + errCh := make(chan error, 10) // Note: this will fail in most cases for 'unopenable' files (like unix file sockets). See use of O_PATH in connwrap_linux.go for // the trick that makes this work in Linux file, err := os.Open(filename) // #nosec @@ -33,15 +34,8 @@ func (w *connWrap) SendFilename(filename string) <-chan error { return errCh } go func(errChIn <-chan error, errChOut chan<- error) { - for err := range errChIn { - errChOut <- err - } - err := file.Close() - if err != nil { - errChOut <- err - } - close(errChOut) + joinErrChs(errChIn, errChOut) + _ = file.Close() }(w.SendFile(file), errCh) - _ = file.Close() return errCh } diff --git a/per_rpc_transport_credentials.go b/per_rpc_transport_credentials.go index da4bec0..be0979f 100644 --- a/per_rpc_transport_credentials.go +++ b/per_rpc_transport_credentials.go @@ -14,6 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build !windows // +build !windows package grpcfd @@ -61,7 +62,7 @@ func (w *wrapPerRPCCredentials) RequireTransportSecurity() bool { } func (w *wrapPerRPCCredentials) SendFD(fd uintptr) <-chan error { - out := make(chan error, 1) + out := make(chan error, 10) w.executor.AsyncExec(func() { if w.FDTransceiver != nil { go joinErrChs(w.FDTransceiver.SendFD(fd), out) @@ -75,7 +76,7 @@ func (w *wrapPerRPCCredentials) SendFD(fd uintptr) <-chan error { } func (w *wrapPerRPCCredentials) SendFile(file SyscallConn) <-chan error { - out := make(chan error, 1) + out := make(chan error, 10) w.executor.AsyncExec(func() { if w.FDTransceiver != nil { go joinErrChs(w.FDTransceiver.SendFile(file), out) @@ -89,7 +90,7 @@ func (w *wrapPerRPCCredentials) SendFile(file SyscallConn) <-chan error { } func (w *wrapPerRPCCredentials) RecvFD(dev, inode uint64) <-chan uintptr { - out := make(chan uintptr, 1) + out := make(chan uintptr, 10) w.executor.AsyncExec(func() { if w.FDTransceiver != nil { go joinFDChs(w.FDTransceiver.RecvFD(dev, inode), out) @@ -103,7 +104,7 @@ func (w *wrapPerRPCCredentials) RecvFD(dev, inode uint64) <-chan uintptr { } func (w *wrapPerRPCCredentials) RecvFile(dev, ino uint64) <-chan *os.File { - out := make(chan *os.File, 1) + out := make(chan *os.File, 10) w.executor.AsyncExec(func() { if w.FDTransceiver != nil { go joinFileChs(w.FDTransceiver.RecvFile(dev, ino), out) @@ -121,7 +122,7 @@ func (w *wrapPerRPCCredentials) RecvFileByURL(urlStr string) (<-chan *os.File, e if err != nil { return nil, err } - out := make(chan *os.File, 1) + out := make(chan *os.File, 10) w.executor.AsyncExec(func() { if w.FDTransceiver != nil { go joinFileChs(w.FDTransceiver.RecvFile(dev, ino), out) @@ -139,7 +140,7 @@ func (w *wrapPerRPCCredentials) RecvFDByURL(urlStr string) (<-chan uintptr, erro if err != nil { return nil, err } - out := make(chan uintptr, 1) + out := make(chan uintptr, 10) w.executor.AsyncExec(func() { if w.FDTransceiver != nil { go joinFDChs(w.FDTransceiver.RecvFD(dev, ino), out) @@ -195,7 +196,8 @@ func PerRPCCredentialsFromCallOptions(opts ...grpc.CallOption) credentials.PerRP } // FromPerRPCCredentials - return grpcfd.FDTransceiver from credentials.PerRPCCredentials -// ok is true of successful, false otherwise +// +// ok is true of successful, false otherwise func FromPerRPCCredentials(rpcCredentials credentials.PerRPCCredentials) (transceiver FDTransceiver, ok bool) { if transceiver, ok = rpcCredentials.(FDTransceiver); ok { return transceiver, true diff --git a/per_rpc_transport_credentials_linux.go b/per_rpc_transport_credentials_linux.go index fca1abe..5b76738 100644 --- a/per_rpc_transport_credentials_linux.go +++ b/per_rpc_transport_credentials_linux.go @@ -14,6 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build linux // +build linux package grpcfd @@ -25,7 +26,7 @@ import ( ) func (w *wrapPerRPCCredentials) SendFilename(filename string) <-chan error { - out := make(chan error, 1) + out := make(chan error, 10) file, err := os.OpenFile(filename, unix.O_PATH, 0) // #nosec if err != nil { out <- err diff --git a/per_rpc_transport_credentials_notlinux.go b/per_rpc_transport_credentials_notlinux.go index e2de30e..6730070 100644 --- a/per_rpc_transport_credentials_notlinux.go +++ b/per_rpc_transport_credentials_notlinux.go @@ -14,6 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build !linux && !windows // +build !linux,!windows package grpcfd @@ -21,7 +22,7 @@ package grpcfd import "os" func (w *wrapPerRPCCredentials) SendFilename(filename string) <-chan error { - out := make(chan error, 1) + out := make(chan error, 10) // Note: this will fail in most cases for 'unopenable' files (like unix file sockets). See use of O_PATH in per_rpc_transport_credentials_linux.go for // the trick that makes this work in Linux file, err := os.Open(filename) // #nosec From 7d43500ce7ec253e440f4699237093222c549665 Mon Sep 17 00:00:00 2001 From: Denis Tingaikin Date: Wed, 22 May 2024 19:01:06 +0300 Subject: [PATCH 2/4] add debug logs Signed-off-by: Denis Tingaikin --- connwrap.go | 10 ++++++++++ connwrap_linux.go | 7 +++++++ doc.go | 18 ++++++++++++++++++ per_rpc_transport_credentials.go | 9 +++++++++ 4 files changed, 44 insertions(+) diff --git a/connwrap.go b/connwrap.go index 24b3792..36b7c50 100644 --- a/connwrap.go +++ b/connwrap.go @@ -22,9 +22,11 @@ package grpcfd import ( "context" "fmt" + "log" "net" "os" "runtime" + "runtime/debug" "syscall" "github.com/edwarnicke/serialize" @@ -171,6 +173,10 @@ func (w *connWrap) Write(b []byte) (int, error) { } func (w *connWrap) SendFD(fd uintptr) <-chan error { + log.Default().Panicln("grpcfd: SendFD start: " + fmt.Sprint(goid())) + debug.PrintStack() + defer log.Default().Panicln("grpcfd: SendFD end: " + fmt.Sprint(goid())) + errCh := make(chan error, 10) // Dup the fd because we have no way of knowing what the caller will do with it between // now and when we can send it @@ -188,6 +194,10 @@ func (w *connWrap) SendFD(fd uintptr) <-chan error { } func (w *connWrap) SendFile(file SyscallConn) <-chan error { + log.Default().Panicln("grpcfd: SendFile goroutine start: " + fmt.Sprint(goid())) + debug.PrintStack() + defer log.Default().Panicln("grpcfd: SendFile: " + fmt.Sprint(goid())) + errCh := make(chan error, 10) raw, err := file.SyscallConn() if err != nil { diff --git a/connwrap_linux.go b/connwrap_linux.go index bb61711..2910a54 100644 --- a/connwrap_linux.go +++ b/connwrap_linux.go @@ -20,12 +20,17 @@ package grpcfd import ( + "fmt" + "log" "os" "golang.org/x/sys/unix" ) func (w *connWrap) SendFilename(filename string) <-chan error { + log.Default().Panicln("grpcfd: SendFilename start" + fmt.Sprint(goid())) + defer log.Default().Panicln("grpcfd: SendFilename end" + fmt.Sprint(goid())) + errCh := make(chan error, 10) file, err := os.OpenFile(filename, unix.O_PATH, 0) // #nosec if err != nil { @@ -34,6 +39,8 @@ func (w *connWrap) SendFilename(filename string) <-chan error { return errCh } go func(errChIn <-chan error, errChOut chan<- error) { + log.Default().Panicln("grpcfd: SendFilename goroutine start: " + fmt.Sprint(goid())) + defer log.Default().Panicln("grpcfd: SendFilename end: " + fmt.Sprint(goid())) joinErrChs(errChIn, errChOut) _ = file.Close() }(w.SendFile(file), errCh) diff --git a/doc.go b/doc.go index 48daf9c..bfeab1c 100644 --- a/doc.go +++ b/doc.go @@ -17,3 +17,21 @@ // Package grpcfd provides a TransportCredential that can wrap other TransportCredentials and cause the // peer.Addr to be a FDSender or FDRecver such that it can send or receive files over unix file sockets (if available). package grpcfd + +import ( + "fmt" + "runtime" + "strconv" + "strings" +) + +func goid() int { + var buf [64]byte + n := runtime.Stack(buf[:], false) + idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0] + id, err := strconv.Atoi(idField) + if err != nil { + panic(fmt.Sprintf("cannot get goroutine id: %v", err)) + } + return id +} diff --git a/per_rpc_transport_credentials.go b/per_rpc_transport_credentials.go index be0979f..153698a 100644 --- a/per_rpc_transport_credentials.go +++ b/per_rpc_transport_credentials.go @@ -21,6 +21,8 @@ package grpcfd import ( context "context" + "fmt" + "log" "os" "github.com/edwarnicke/serialize" @@ -154,6 +156,8 @@ func (w *wrapPerRPCCredentials) RecvFDByURL(urlStr string) (<-chan uintptr, erro } func joinErrChs(in <-chan error, out chan<- error) { + log.Default().Panicln("grpcfd: joinErrChs start: " + fmt.Sprint(goid())) + defer log.Default().Panicln("grpcfd: joinErrChs end: " + fmt.Sprint(goid())) for err := range in { out <- err } @@ -161,6 +165,8 @@ func joinErrChs(in <-chan error, out chan<- error) { } func joinFileChs(in <-chan *os.File, out chan<- *os.File) { + log.Default().Panicln("grpcfd: joinFileChs start: " + fmt.Sprint(goid())) + defer log.Default().Panicln("grpcfd: joinFileChs end: " + fmt.Sprint(goid())) for file := range in { out <- file } @@ -168,6 +174,9 @@ func joinFileChs(in <-chan *os.File, out chan<- *os.File) { } func joinFDChs(in <-chan uintptr, out chan<- uintptr) { + log.Default().Panicln("grpcfd: joinFDChs start: " + fmt.Sprint(goid())) + defer log.Default().Panicln("grpcfd: joinFDChs end: " + fmt.Sprint(goid())) + for fd := range in { out <- fd } From da0e6dc723ce296ae7d85d8c5ae9066472850a6f Mon Sep 17 00:00:00 2001 From: Denis Tingaikin Date: Wed, 22 May 2024 19:02:46 +0300 Subject: [PATCH 3/4] fix module name Signed-off-by: Denis Tingaikin --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 59abe4b..d447a88 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/edwarnicke/grpcfd +module github.com/denis-tingaikin/grpcfd go 1.14 From b6f9ce063b8d34ec168d75f173cb0e2ba2d8a6ae Mon Sep 17 00:00:00 2001 From: Denis Tingaikin Date: Wed, 22 May 2024 19:59:40 +0300 Subject: [PATCH 4/4] fix logs Signed-off-by: Denis Tingaikin --- connwrap.go | 8 ++++---- connwrap_linux.go | 8 ++++---- per_rpc_transport_credentials.go | 12 ++++++------ 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/connwrap.go b/connwrap.go index 36b7c50..0c92626 100644 --- a/connwrap.go +++ b/connwrap.go @@ -173,9 +173,9 @@ func (w *connWrap) Write(b []byte) (int, error) { } func (w *connWrap) SendFD(fd uintptr) <-chan error { - log.Default().Panicln("grpcfd: SendFD start: " + fmt.Sprint(goid())) + log.Default().Println("grpcfd: SendFD start: " + fmt.Sprint(goid())) debug.PrintStack() - defer log.Default().Panicln("grpcfd: SendFD end: " + fmt.Sprint(goid())) + defer log.Default().Println("grpcfd: SendFD end: " + fmt.Sprint(goid())) errCh := make(chan error, 10) // Dup the fd because we have no way of knowing what the caller will do with it between @@ -194,9 +194,9 @@ func (w *connWrap) SendFD(fd uintptr) <-chan error { } func (w *connWrap) SendFile(file SyscallConn) <-chan error { - log.Default().Panicln("grpcfd: SendFile goroutine start: " + fmt.Sprint(goid())) + log.Default().Println("grpcfd: SendFile start: " + fmt.Sprint(goid())) debug.PrintStack() - defer log.Default().Panicln("grpcfd: SendFile: " + fmt.Sprint(goid())) + defer log.Default().Println("grpcfd: SendFile end: " + fmt.Sprint(goid())) errCh := make(chan error, 10) raw, err := file.SyscallConn() diff --git a/connwrap_linux.go b/connwrap_linux.go index 2910a54..00712d7 100644 --- a/connwrap_linux.go +++ b/connwrap_linux.go @@ -28,8 +28,8 @@ import ( ) func (w *connWrap) SendFilename(filename string) <-chan error { - log.Default().Panicln("grpcfd: SendFilename start" + fmt.Sprint(goid())) - defer log.Default().Panicln("grpcfd: SendFilename end" + fmt.Sprint(goid())) + log.Default().Println("grpcfd: SendFilename start" + fmt.Sprint(goid())) + defer log.Default().Println("grpcfd: SendFilename end" + fmt.Sprint(goid())) errCh := make(chan error, 10) file, err := os.OpenFile(filename, unix.O_PATH, 0) // #nosec @@ -39,8 +39,8 @@ func (w *connWrap) SendFilename(filename string) <-chan error { return errCh } go func(errChIn <-chan error, errChOut chan<- error) { - log.Default().Panicln("grpcfd: SendFilename goroutine start: " + fmt.Sprint(goid())) - defer log.Default().Panicln("grpcfd: SendFilename end: " + fmt.Sprint(goid())) + log.Default().Println("grpcfd: SendFilename goroutine start: " + fmt.Sprint(goid())) + defer log.Default().Println("grpcfd: SendFilename end: " + fmt.Sprint(goid())) joinErrChs(errChIn, errChOut) _ = file.Close() }(w.SendFile(file), errCh) diff --git a/per_rpc_transport_credentials.go b/per_rpc_transport_credentials.go index 153698a..20f0679 100644 --- a/per_rpc_transport_credentials.go +++ b/per_rpc_transport_credentials.go @@ -156,8 +156,8 @@ func (w *wrapPerRPCCredentials) RecvFDByURL(urlStr string) (<-chan uintptr, erro } func joinErrChs(in <-chan error, out chan<- error) { - log.Default().Panicln("grpcfd: joinErrChs start: " + fmt.Sprint(goid())) - defer log.Default().Panicln("grpcfd: joinErrChs end: " + fmt.Sprint(goid())) + log.Default().Println("grpcfd: joinErrChs start: " + fmt.Sprint(goid())) + defer log.Default().Println("grpcfd: joinErrChs end: " + fmt.Sprint(goid())) for err := range in { out <- err } @@ -165,8 +165,8 @@ func joinErrChs(in <-chan error, out chan<- error) { } func joinFileChs(in <-chan *os.File, out chan<- *os.File) { - log.Default().Panicln("grpcfd: joinFileChs start: " + fmt.Sprint(goid())) - defer log.Default().Panicln("grpcfd: joinFileChs end: " + fmt.Sprint(goid())) + log.Default().Println("grpcfd: joinFileChs start: " + fmt.Sprint(goid())) + defer log.Default().Println("grpcfd: joinFileChs end: " + fmt.Sprint(goid())) for file := range in { out <- file } @@ -174,8 +174,8 @@ func joinFileChs(in <-chan *os.File, out chan<- *os.File) { } func joinFDChs(in <-chan uintptr, out chan<- uintptr) { - log.Default().Panicln("grpcfd: joinFDChs start: " + fmt.Sprint(goid())) - defer log.Default().Panicln("grpcfd: joinFDChs end: " + fmt.Sprint(goid())) + log.Default().Println("grpcfd: joinFDChs start: " + fmt.Sprint(goid())) + defer log.Default().Println("grpcfd: joinFDChs end: " + fmt.Sprint(goid())) for fd := range in { out <- fd