diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 2467191d..eb52df90 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -29,7 +29,7 @@ jobs: # Needed for the example-test to run. GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: | - go test -cover -v ./... + go test -race -cover -v ./... lint: name: Lint diff --git a/Makefile b/Makefile index c1117278..45114a04 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ lint: internal/lint/golangci-lint run ./... --fix check: lint - go test -cover ./... + go test -race -cover ./... go mod tidy .PHONY: example diff --git a/generate/operation.go.tmpl b/generate/operation.go.tmpl index 4fe01137..fe0a1080 100644 --- a/generate/operation.go.tmpl +++ b/generate/operation.go.tmpl @@ -74,7 +74,7 @@ func {{.Name}}ForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage) if !ok { return errors.New("failed to cast interface into 'chan {{.Name}}WsResponse'") } - dataChan_ <- wsResp + graphql.SafeSend(dataChan_, wsResp) return nil } {{end}} diff --git a/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go b/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go index 158aa87b..76621526 100644 --- a/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go +++ b/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go @@ -60,7 +60,7 @@ func SimpleSubscriptionForwardData(interfaceChan interface{}, jsonRawMsg json.Ra if !ok { return errors.New("failed to cast interface into 'chan SimpleSubscriptionWsResponse'") } - dataChan_ <- wsResp + graphql.SafeSend(dataChan_, wsResp) return nil } diff --git a/graphql/client.go b/graphql/client.go index d2a0181d..f2271ab6 100644 --- a/graphql/client.go +++ b/graphql/client.go @@ -363,3 +363,10 @@ func (c *client) createGetRequest(req *Request) (*http.Request, error) { return httpReq, nil } + +func SafeSend[T any](dataChan_ chan T, wsResp T) { + defer func() { + _ = recover() + }() + dataChan_ <- wsResp +} diff --git a/graphql/subscription.go b/graphql/subscription.go index 9d39d791..6c241e67 100644 --- a/graphql/subscription.go +++ b/graphql/subscription.go @@ -37,13 +37,8 @@ func (s *subscriptionMap) Unsubscribe(subscriptionID string) error { if !success { return fmt.Errorf("tried to unsubscribe from unknown subscription with ID '%s'", subscriptionID) } - hasBeenUnsubscribed := unsub.hasBeenUnsubscribed unsub.hasBeenUnsubscribed = true s.map_[subscriptionID] = unsub - - if !hasBeenUnsubscribed { - reflect.ValueOf(s.map_[subscriptionID].interfaceChan).Close() - } return nil } @@ -61,3 +56,23 @@ func (s *subscriptionMap) Delete(subscriptionID string) { defer s.Unlock() delete(s.map_, subscriptionID) } + +func (s *subscriptionMap) GetOrClose(subscriptionID string, subscriptionType string) (*subscription, error) { + s.Lock() + defer s.Unlock() + sub, success := s.map_[subscriptionID] + if !success { + return nil, fmt.Errorf("received message for unknown subscription ID '%s'", subscriptionID) + } + if sub.hasBeenUnsubscribed { + return nil, nil + } + if subscriptionType == webSocketTypeComplete { + sub.hasBeenUnsubscribed = true + s.map_[subscriptionID] = sub + reflect.ValueOf(sub.interfaceChan).Close() + return nil, nil + } + + return &sub, nil +} diff --git a/graphql/websocket.go b/graphql/websocket.go index 07e209d4..abc63295 100644 --- a/graphql/websocket.go +++ b/graphql/websocket.go @@ -6,9 +6,9 @@ import ( "encoding/json" "fmt" "net/http" - "reflect" "strings" "sync" + "sync/atomic" "time" "github.com/google/uuid" @@ -51,7 +51,7 @@ type webSocketClient struct { connParams map[string]interface{} errChan chan error subscriptions subscriptionMap - isClosing bool + isClosing atomic.Bool sync.Mutex } @@ -107,14 +107,14 @@ func (w *webSocketClient) waitForConnAck() error { func (w *webSocketClient) handleErr(err error) { w.Lock() defer w.Unlock() - if !w.isClosing { + if !w.isClosing.Load() { w.errChan <- err } } func (w *webSocketClient) listenWebSocket() { for { - if w.isClosing { + if w.isClosing.Load() { return } _, message, err := w.conn.ReadMessage() @@ -139,22 +139,13 @@ func (w *webSocketClient) forwardWebSocketData(message []byte) error { if wsMsg.ID == "" { // e.g. keep-alive messages return nil } - w.subscriptions.Lock() - defer w.subscriptions.Unlock() - sub, success := w.subscriptions.map_[wsMsg.ID] - if !success { - return fmt.Errorf("received message for unknown subscription ID '%s'", wsMsg.ID) - } - if sub.hasBeenUnsubscribed { - return nil + sub, err := w.subscriptions.GetOrClose(wsMsg.ID, wsMsg.Type) + if err != nil { + return err } - if wsMsg.Type == webSocketTypeComplete { - sub.hasBeenUnsubscribed = true - w.subscriptions.map_[wsMsg.ID] = sub - reflect.ValueOf(sub.interfaceChan).Close() + if sub == nil { return nil } - return sub.forwardDataFunc(sub.interfaceChan, wsMsg.Payload) } @@ -208,7 +199,7 @@ func (w *webSocketClient) Close() error { } w.Lock() defer w.Unlock() - w.isClosing = true + w.isClosing.Store(true) close(w.errChan) return w.conn.Close() } diff --git a/internal/integration/generated.go b/internal/integration/generated.go index fd034833..68573354 100644 --- a/internal/integration/generated.go +++ b/internal/integration/generated.go @@ -3156,7 +3156,7 @@ func countForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage) err if !ok { return errors.New("failed to cast interface into 'chan countWsResponse'") } - dataChan_ <- wsResp + graphql.SafeSend(dataChan_, wsResp) return nil } @@ -3204,7 +3204,7 @@ func countAuthorizedForwardData(interfaceChan interface{}, jsonRawMsg json.RawMe if !ok { return errors.New("failed to cast interface into 'chan countAuthorizedWsResponse'") } - dataChan_ <- wsResp + graphql.SafeSend(dataChan_, wsResp) return nil } @@ -3252,7 +3252,7 @@ func countCloseForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage if !ok { return errors.New("failed to cast interface into 'chan countCloseWsResponse'") } - dataChan_ <- wsResp + graphql.SafeSend(dataChan_, wsResp) return nil } diff --git a/internal/integration/integration_test.go b/internal/integration/integration_test.go index 88e92380..174b5a30 100644 --- a/internal/integration/integration_test.go +++ b/internal/integration/integration_test.go @@ -252,10 +252,6 @@ func TestSubscriptionClose(t *testing.T) { _ = `# @genqlient subscription countClose { countClose }` - ctx := context.Background() - server := server.RunServer() - defer server.Close() - cases := []struct { name string unsub bool @@ -272,6 +268,12 @@ func TestSubscriptionClose(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + server := server.RunServer() + defer server.Close() + wsClient := newRoundtripWebSocketClient(t, server.URL) _, err := wsClient.Start(ctx) diff --git a/internal/integration/server/server.go b/internal/integration/server/server.go index d8071ae1..89745c26 100644 --- a/internal/integration/server/server.go +++ b/internal/integration/server/server.go @@ -157,33 +157,37 @@ func (m mutationResolver) CreateUser(ctx context.Context, input NewUser) (*User, return &newUser, nil } -func (s *subscriptionResolver) Count(ctx context.Context) (<-chan int, error) { +func countTo(ctx context.Context, stopCount int) (<-chan int, error) { respChan := make(chan int, 1) go func(respChan chan int) { defer close(respChan) - counter := 0 - for { - if counter == 10 { + for counter := range stopCount { + select { + case <-ctx.Done(): return + default: + respChan <- counter + time.Sleep(100 * time.Millisecond) } - respChan <- counter - counter++ - time.Sleep(100 * time.Millisecond) } }(respChan) return respChan, nil } +func (s *subscriptionResolver) Count(ctx context.Context) (<-chan int, error) { + return countTo(ctx, 10) +} + func (s *subscriptionResolver) CountAuthorized(ctx context.Context) (<-chan int, error) { if getAuthToken(ctx) != "authorized-user-token" { return nil, fmt.Errorf("unauthorized") } - return s.Count(ctx) + return countTo(ctx, 10) } func (s *subscriptionResolver) CountClose(ctx context.Context) (<-chan int, error) { - return s.Count(ctx) + return countTo(ctx, 1000) } const AuthKey = "authToken"