Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 28 additions & 14 deletions connwrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build !windows
// +build !windows

package grpcfd

import (
"context"
"fmt"
"log"
"net"
"os"
"runtime"
"runtime/debug"
"syscall"

"github.com/edwarnicke/serialize"
Expand Down Expand Up @@ -170,7 +173,11 @@ func (w *connWrap) Write(b []byte) (int, error) {
}

func (w *connWrap) SendFD(fd uintptr) <-chan error {
errCh := make(chan error, 1)
log.Default().Println("grpcfd: SendFD start: " + fmt.Sprint(goid()))
debug.PrintStack()
defer log.Default().Println("grpcfd: SendFD end: " + fmt.Sprint(goid()))

errCh := make(chan error, 10)
// Dup the fd because we have no way of knowing what the caller will do with it between
// now and when we can send it
fd, _, err := syscall.Syscall(syscall.SYS_FCNTL, fd, uintptr(syscall.F_DUPFD), 0)
Expand All @@ -187,7 +194,11 @@ func (w *connWrap) SendFD(fd uintptr) <-chan error {
}

func (w *connWrap) SendFile(file SyscallConn) <-chan error {
errCh := make(chan error, 1)
log.Default().Println("grpcfd: SendFile start: " + fmt.Sprint(goid()))
debug.PrintStack()
defer log.Default().Println("grpcfd: SendFile end: " + fmt.Sprint(goid()))

errCh := make(chan error, 10)
raw, err := file.SyscallConn()
if err != nil {
errCh <- errors.Wrapf(err, "unable to retrieve syscall.RawConn for src %+v", file)
Expand All @@ -202,16 +213,17 @@ func (w *connWrap) SendFile(file SyscallConn) <-chan error {
close(errCh)
return
}
go func(errChIn <-chan error, errChOut chan<- error) {
for err := range errChIn {
errChOut <- err
}
close(errChOut)
}(w.SendFD(fd), errCh)
go joinErrChs(w.SendFD(fd), errCh)
})

if err != nil {
errCh <- err
close(errCh)
// Return a separate channel to not conflict with goroutine from the raw.Control
// As an alternative, mutex can be used, but it can affect performance.
//
// In some cases, errCh can't be closed, but it's fine. https://groups.google.com/g/golang-nuts/c/pZwdYRGxCIk/m/qpbHxRRPJdUJ
var resCh = make(chan error, 1)
resCh <- err
return resCh
}
return errCh
}
Expand All @@ -229,7 +241,7 @@ func (w *connWrap) String() string {
}

func (w *connWrap) RecvFD(dev, ino uint64) <-chan uintptr {
fdCh := make(chan uintptr, 1)
fdCh := make(chan uintptr, 10)
w.recvExecutor.AsyncExec(func() {
key := inodeKey{
dev: dev,
Expand Down Expand Up @@ -266,7 +278,7 @@ func (w *connWrap) RecvFDByURL(urlStr string) (<-chan uintptr, error) {
}

func (w *connWrap) RecvFile(dev, ino uint64) <-chan *os.File {
fileCh := make(chan *os.File, 1)
fileCh := make(chan *os.File, 10)
go func(fdCh <-chan uintptr, fileCh chan<- *os.File) {
for fd := range fdCh {
if runtime.GOOS == "linux" {
Expand Down Expand Up @@ -352,14 +364,16 @@ func (w *connWrap) Read(b []byte) (n int, err error) {
}

// FromPeer - return grpcfd.FDTransceiver from peer.Peer
// ok is true of successful, false otherwise
//
// ok is true of successful, false otherwise
func FromPeer(p *peer.Peer) (transceiver FDTransceiver, ok bool) {
transceiver, ok = p.Addr.(FDTransceiver)
return transceiver, ok
}

// FromContext - return grpcfd.FDTransceiver from context.Context
// ok is true of successful, false otherwise
//
// ok is true of successful, false otherwise
func FromContext(ctx context.Context) (transceiver FDTransceiver, ok bool) {
p, ok := peer.FromContext(ctx)
if !ok {
Expand Down
21 changes: 11 additions & 10 deletions connwrap_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,35 @@
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build linux
// +build linux

package grpcfd

import (
"fmt"
"log"
"os"

"golang.org/x/sys/unix"
)

func (w *connWrap) SendFilename(filename string) <-chan error {
errCh := make(chan error, 1)
log.Default().Println("grpcfd: SendFilename start" + fmt.Sprint(goid()))
defer log.Default().Println("grpcfd: SendFilename end" + fmt.Sprint(goid()))

errCh := make(chan error, 10)
file, err := os.OpenFile(filename, unix.O_PATH, 0) // #nosec
if err != nil {
errCh <- err
close(errCh)
return errCh
}
go func(errChIn <-chan error, errChOut chan<- error) {
for err := range errChIn {
errChOut <- err
}
err := file.Close()
if err != nil {
errChOut <- err
}
close(errChOut)
log.Default().Println("grpcfd: SendFilename goroutine start: " + fmt.Sprint(goid()))
defer log.Default().Println("grpcfd: SendFilename end: " + fmt.Sprint(goid()))
joinErrChs(errChIn, errChOut)
_ = file.Close()
}(w.SendFile(file), errCh)
_ = file.Close()
return errCh
}
14 changes: 4 additions & 10 deletions connwrap_notlinux.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build !linux && !windows
// +build !linux,!windows

package grpcfd
Expand All @@ -23,7 +24,7 @@ import (
)

func (w *connWrap) SendFilename(filename string) <-chan error {
errCh := make(chan error, 1)
errCh := make(chan error, 10)
// Note: this will fail in most cases for 'unopenable' files (like unix file sockets). See use of O_PATH in connwrap_linux.go for
// the trick that makes this work in Linux
file, err := os.Open(filename) // #nosec
Expand All @@ -33,15 +34,8 @@ func (w *connWrap) SendFilename(filename string) <-chan error {
return errCh
}
go func(errChIn <-chan error, errChOut chan<- error) {
for err := range errChIn {
errChOut <- err
}
err := file.Close()
if err != nil {
errChOut <- err
}
close(errChOut)
joinErrChs(errChIn, errChOut)
_ = file.Close()
}(w.SendFile(file), errCh)
_ = file.Close()
return errCh
}
18 changes: 18 additions & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,21 @@
// Package grpcfd provides a TransportCredential that can wrap other TransportCredentials and cause the
// peer.Addr to be a FDSender or FDRecver such that it can send or receive files over unix file sockets (if available).
package grpcfd

import (
"fmt"
"runtime"
"strconv"
"strings"
)

func goid() int {
var buf [64]byte
n := runtime.Stack(buf[:], false)
idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0]
id, err := strconv.Atoi(idField)
if err != nil {
panic(fmt.Sprintf("cannot get goroutine id: %v", err))
}
return id
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module github.com/edwarnicke/grpcfd
module github.com/denis-tingaikin/grpcfd

go 1.14

Expand Down
25 changes: 18 additions & 7 deletions per_rpc_transport_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build !windows
// +build !windows

package grpcfd

import (
context "context"
"fmt"
"log"
"os"

"github.com/edwarnicke/serialize"
Expand Down Expand Up @@ -61,7 +64,7 @@ func (w *wrapPerRPCCredentials) RequireTransportSecurity() bool {
}

func (w *wrapPerRPCCredentials) SendFD(fd uintptr) <-chan error {
out := make(chan error, 1)
out := make(chan error, 10)
w.executor.AsyncExec(func() {
if w.FDTransceiver != nil {
go joinErrChs(w.FDTransceiver.SendFD(fd), out)
Expand All @@ -75,7 +78,7 @@ func (w *wrapPerRPCCredentials) SendFD(fd uintptr) <-chan error {
}

func (w *wrapPerRPCCredentials) SendFile(file SyscallConn) <-chan error {
out := make(chan error, 1)
out := make(chan error, 10)
w.executor.AsyncExec(func() {
if w.FDTransceiver != nil {
go joinErrChs(w.FDTransceiver.SendFile(file), out)
Expand All @@ -89,7 +92,7 @@ func (w *wrapPerRPCCredentials) SendFile(file SyscallConn) <-chan error {
}

func (w *wrapPerRPCCredentials) RecvFD(dev, inode uint64) <-chan uintptr {
out := make(chan uintptr, 1)
out := make(chan uintptr, 10)
w.executor.AsyncExec(func() {
if w.FDTransceiver != nil {
go joinFDChs(w.FDTransceiver.RecvFD(dev, inode), out)
Expand All @@ -103,7 +106,7 @@ func (w *wrapPerRPCCredentials) RecvFD(dev, inode uint64) <-chan uintptr {
}

func (w *wrapPerRPCCredentials) RecvFile(dev, ino uint64) <-chan *os.File {
out := make(chan *os.File, 1)
out := make(chan *os.File, 10)
w.executor.AsyncExec(func() {
if w.FDTransceiver != nil {
go joinFileChs(w.FDTransceiver.RecvFile(dev, ino), out)
Expand All @@ -121,7 +124,7 @@ func (w *wrapPerRPCCredentials) RecvFileByURL(urlStr string) (<-chan *os.File, e
if err != nil {
return nil, err
}
out := make(chan *os.File, 1)
out := make(chan *os.File, 10)
w.executor.AsyncExec(func() {
if w.FDTransceiver != nil {
go joinFileChs(w.FDTransceiver.RecvFile(dev, ino), out)
Expand All @@ -139,7 +142,7 @@ func (w *wrapPerRPCCredentials) RecvFDByURL(urlStr string) (<-chan uintptr, erro
if err != nil {
return nil, err
}
out := make(chan uintptr, 1)
out := make(chan uintptr, 10)
w.executor.AsyncExec(func() {
if w.FDTransceiver != nil {
go joinFDChs(w.FDTransceiver.RecvFD(dev, ino), out)
Expand All @@ -153,20 +156,27 @@ func (w *wrapPerRPCCredentials) RecvFDByURL(urlStr string) (<-chan uintptr, erro
}

func joinErrChs(in <-chan error, out chan<- error) {
log.Default().Println("grpcfd: joinErrChs start: " + fmt.Sprint(goid()))
defer log.Default().Println("grpcfd: joinErrChs end: " + fmt.Sprint(goid()))
for err := range in {
out <- err
}
close(out)
}

func joinFileChs(in <-chan *os.File, out chan<- *os.File) {
log.Default().Println("grpcfd: joinFileChs start: " + fmt.Sprint(goid()))
defer log.Default().Println("grpcfd: joinFileChs end: " + fmt.Sprint(goid()))
for file := range in {
out <- file
}
close(out)
}

func joinFDChs(in <-chan uintptr, out chan<- uintptr) {
log.Default().Println("grpcfd: joinFDChs start: " + fmt.Sprint(goid()))
defer log.Default().Println("grpcfd: joinFDChs end: " + fmt.Sprint(goid()))

for fd := range in {
out <- fd
}
Expand Down Expand Up @@ -195,7 +205,8 @@ func PerRPCCredentialsFromCallOptions(opts ...grpc.CallOption) credentials.PerRP
}

// FromPerRPCCredentials - return grpcfd.FDTransceiver from credentials.PerRPCCredentials
// ok is true of successful, false otherwise
//
// ok is true of successful, false otherwise
func FromPerRPCCredentials(rpcCredentials credentials.PerRPCCredentials) (transceiver FDTransceiver, ok bool) {
if transceiver, ok = rpcCredentials.(FDTransceiver); ok {
return transceiver, true
Expand Down
3 changes: 2 additions & 1 deletion per_rpc_transport_credentials_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build linux
// +build linux

package grpcfd
Expand All @@ -25,7 +26,7 @@ import (
)

func (w *wrapPerRPCCredentials) SendFilename(filename string) <-chan error {
out := make(chan error, 1)
out := make(chan error, 10)
file, err := os.OpenFile(filename, unix.O_PATH, 0) // #nosec
if err != nil {
out <- err
Expand Down
3 changes: 2 additions & 1 deletion per_rpc_transport_credentials_notlinux.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build !linux && !windows
// +build !linux,!windows

package grpcfd

import "os"

func (w *wrapPerRPCCredentials) SendFilename(filename string) <-chan error {
out := make(chan error, 1)
out := make(chan error, 10)
// Note: this will fail in most cases for 'unopenable' files (like unix file sockets). See use of O_PATH in per_rpc_transport_credentials_linux.go for
// the trick that makes this work in Linux
file, err := os.Open(filename) // #nosec
Expand Down