From 50ba8c0442437367eb643383c98df3194ae3a5ba Mon Sep 17 00:00:00 2001 From: Harald Nordgren Date: Tue, 28 Oct 2025 09:48:51 +0100 Subject: [PATCH 1/7] integration test: longer counting to avoid flakiness --- internal/integration/server/server.go | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/internal/integration/server/server.go b/internal/integration/server/server.go index d8071ae1..5079baaf 100644 --- a/internal/integration/server/server.go +++ b/internal/integration/server/server.go @@ -157,33 +157,32 @@ func (m mutationResolver) CreateUser(ctx context.Context, input NewUser) (*User, return &newUser, nil } -func (s *subscriptionResolver) Count(ctx context.Context) (<-chan int, error) { +func countTo(stopCount int) (<-chan int, error) { respChan := make(chan int, 1) go func(respChan chan int) { defer close(respChan) - counter := 0 - for { - if counter == 10 { - return - } + for counter := range stopCount { respChan <- counter - counter++ time.Sleep(100 * time.Millisecond) } }(respChan) return respChan, nil } +func (s *subscriptionResolver) Count(ctx context.Context) (<-chan int, error) { + return countTo(10) +} + func (s *subscriptionResolver) CountAuthorized(ctx context.Context) (<-chan int, error) { if getAuthToken(ctx) != "authorized-user-token" { return nil, fmt.Errorf("unauthorized") } - return s.Count(ctx) + return countTo(10) } func (s *subscriptionResolver) CountClose(ctx context.Context) (<-chan int, error) { - return s.Count(ctx) + return countTo(1000) } const AuthKey = "authToken" From 8a982d731013f0fad38fa7587fa86b3cc387687a Mon Sep 17 00:00:00 2001 From: Harald Nordgren Date: Tue, 28 Oct 2025 14:47:49 +0100 Subject: [PATCH 2/7] race --- .github/workflows/go.yml | 2 +- Makefile | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 2467191d..eb52df90 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -29,7 +29,7 @@ jobs: # Needed for the example-test to run. GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: | - go test -cover -v ./... + go test -race -cover -v ./... lint: name: Lint diff --git a/Makefile b/Makefile index c1117278..45114a04 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ lint: internal/lint/golangci-lint run ./... --fix check: lint - go test -cover ./... + go test -race -cover ./... go mod tidy .PHONY: example From 122aba24eb441eb84f211f95809eb0e04b55fad8 Mon Sep 17 00:00:00 2001 From: Harald Nordgren Date: Tue, 28 Oct 2025 16:03:26 +0100 Subject: [PATCH 3/7] closing --- graphql/websocket.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/graphql/websocket.go b/graphql/websocket.go index 07e209d4..52cea917 100644 --- a/graphql/websocket.go +++ b/graphql/websocket.go @@ -9,6 +9,7 @@ import ( "reflect" "strings" "sync" + "sync/atomic" "time" "github.com/google/uuid" @@ -51,7 +52,7 @@ type webSocketClient struct { connParams map[string]interface{} errChan chan error subscriptions subscriptionMap - isClosing bool + isClosing atomic.Bool sync.Mutex } @@ -107,14 +108,14 @@ func (w *webSocketClient) waitForConnAck() error { func (w *webSocketClient) handleErr(err error) { w.Lock() defer w.Unlock() - if !w.isClosing { + if !w.isClosing.Load() { w.errChan <- err } } func (w *webSocketClient) listenWebSocket() { for { - if w.isClosing { + if w.isClosing.Load() { return } _, message, err := w.conn.ReadMessage() @@ -208,7 +209,7 @@ func (w *webSocketClient) Close() error { } w.Lock() defer w.Unlock() - w.isClosing = true + w.isClosing.Store(true) close(w.errChan) return w.conn.Close() } From 2523e2a5adbde6f2bdc40bcf47f09d7a5702937b Mon Sep 17 00:00:00 2001 From: Harald Nordgren Date: Tue, 28 Oct 2025 20:19:54 +0100 Subject: [PATCH 4/7] test timing --- internal/integration/integration_test.go | 10 ++++++---- internal/integration/server/server.go | 18 ++++++++++++------ 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/internal/integration/integration_test.go b/internal/integration/integration_test.go index 88e92380..174b5a30 100644 --- a/internal/integration/integration_test.go +++ b/internal/integration/integration_test.go @@ -252,10 +252,6 @@ func TestSubscriptionClose(t *testing.T) { _ = `# @genqlient subscription countClose { countClose }` - ctx := context.Background() - server := server.RunServer() - defer server.Close() - cases := []struct { name string unsub bool @@ -272,6 +268,12 @@ func TestSubscriptionClose(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + server := server.RunServer() + defer server.Close() + wsClient := newRoundtripWebSocketClient(t, server.URL) _, err := wsClient.Start(ctx) diff --git a/internal/integration/server/server.go b/internal/integration/server/server.go index 5079baaf..668b4b4d 100644 --- a/internal/integration/server/server.go +++ b/internal/integration/server/server.go @@ -157,20 +157,26 @@ func (m mutationResolver) CreateUser(ctx context.Context, input NewUser) (*User, return &newUser, nil } -func countTo(stopCount int) (<-chan int, error) { +func countTo(ctx context.Context, stopCount int) (<-chan int, error) { respChan := make(chan int, 1) go func(respChan chan int) { defer close(respChan) for counter := range stopCount { - respChan <- counter - time.Sleep(100 * time.Millisecond) + select { + case <-ctx.Done(): + fmt.Println("ctx done:", ctx.Err()) + return + default: + respChan <- counter + time.Sleep(100 * time.Millisecond) + } } }(respChan) return respChan, nil } func (s *subscriptionResolver) Count(ctx context.Context) (<-chan int, error) { - return countTo(10) + return countTo(ctx, 10) } func (s *subscriptionResolver) CountAuthorized(ctx context.Context) (<-chan int, error) { @@ -178,11 +184,11 @@ func (s *subscriptionResolver) CountAuthorized(ctx context.Context) (<-chan int, return nil, fmt.Errorf("unauthorized") } - return countTo(10) + return countTo(ctx, 10) } func (s *subscriptionResolver) CountClose(ctx context.Context) (<-chan int, error) { - return countTo(1000) + return countTo(ctx, 1000) } const AuthKey = "authToken" From cd5d2bc9214eb5d591c8eb8379e4312f58bdad5d Mon Sep 17 00:00:00 2001 From: Harald Nordgren Date: Tue, 28 Oct 2025 23:16:22 +0100 Subject: [PATCH 5/7] recover from panic --- generate/operation.go.tmpl | 2 +- ...tion.graphql-SimpleSubscription.graphql.go | 2 +- graphql/client.go | 7 ++++ graphql/websocket.go | 37 ++++++++++++------- internal/integration/generated.go | 6 +-- internal/integration/server/server.go | 1 - 6 files changed, 36 insertions(+), 19 deletions(-) diff --git a/generate/operation.go.tmpl b/generate/operation.go.tmpl index 4fe01137..2d989825 100644 --- a/generate/operation.go.tmpl +++ b/generate/operation.go.tmpl @@ -74,7 +74,7 @@ func {{.Name}}ForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage) if !ok { return errors.New("failed to cast interface into 'chan {{.Name}}WsResponse'") } - dataChan_ <- wsResp + graphql.WriteToChannelOrRecover(dataChan_, wsResp) return nil } {{end}} diff --git a/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go b/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go index 158aa87b..ac40ab67 100644 --- a/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go +++ b/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go @@ -60,7 +60,7 @@ func SimpleSubscriptionForwardData(interfaceChan interface{}, jsonRawMsg json.Ra if !ok { return errors.New("failed to cast interface into 'chan SimpleSubscriptionWsResponse'") } - dataChan_ <- wsResp + graphql.WriteToChannelOrRecover(dataChan_, wsResp) return nil } diff --git a/graphql/client.go b/graphql/client.go index d2a0181d..ca583672 100644 --- a/graphql/client.go +++ b/graphql/client.go @@ -363,3 +363,10 @@ func (c *client) createGetRequest(req *Request) (*http.Request, error) { return httpReq, nil } + +func WriteToChannelOrRecover[T any](dataChan_ chan T, wsResp T) { + defer func() { + _ = recover() + }() + dataChan_ <- wsResp +} diff --git a/graphql/websocket.go b/graphql/websocket.go index 52cea917..462894ae 100644 --- a/graphql/websocket.go +++ b/graphql/websocket.go @@ -131,6 +131,26 @@ func (w *webSocketClient) listenWebSocket() { } } +func (w *webSocketClient) getSubscriptionOrHandleComplete(subscriptionID string, subscriptionType string) (*subscription, error) { + w.subscriptions.Lock() + defer w.subscriptions.Unlock() + sub, success := w.subscriptions.map_[subscriptionID] + if !success { + return nil, fmt.Errorf("received message for unknown subscription ID '%s'", subscriptionID) + } + if sub.hasBeenUnsubscribed { + return nil, nil + } + if subscriptionType == webSocketTypeComplete { + sub.hasBeenUnsubscribed = true + w.subscriptions.map_[subscriptionID] = sub + reflect.ValueOf(sub.interfaceChan).Close() + return nil, nil + } + + return &sub, nil +} + func (w *webSocketClient) forwardWebSocketData(message []byte) error { var wsMsg webSocketReceiveMessage err := json.Unmarshal(message, &wsMsg) @@ -140,22 +160,13 @@ func (w *webSocketClient) forwardWebSocketData(message []byte) error { if wsMsg.ID == "" { // e.g. keep-alive messages return nil } - w.subscriptions.Lock() - defer w.subscriptions.Unlock() - sub, success := w.subscriptions.map_[wsMsg.ID] - if !success { - return fmt.Errorf("received message for unknown subscription ID '%s'", wsMsg.ID) - } - if sub.hasBeenUnsubscribed { - return nil + sub, err := w.getSubscriptionOrHandleComplete(wsMsg.ID, wsMsg.Type) + if err != nil { + return err } - if wsMsg.Type == webSocketTypeComplete { - sub.hasBeenUnsubscribed = true - w.subscriptions.map_[wsMsg.ID] = sub - reflect.ValueOf(sub.interfaceChan).Close() + if sub == nil { return nil } - return sub.forwardDataFunc(sub.interfaceChan, wsMsg.Payload) } diff --git a/internal/integration/generated.go b/internal/integration/generated.go index fd034833..e72a817b 100644 --- a/internal/integration/generated.go +++ b/internal/integration/generated.go @@ -3156,7 +3156,7 @@ func countForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage) err if !ok { return errors.New("failed to cast interface into 'chan countWsResponse'") } - dataChan_ <- wsResp + graphql.WriteToChannelOrRecover(dataChan_, wsResp) return nil } @@ -3204,7 +3204,7 @@ func countAuthorizedForwardData(interfaceChan interface{}, jsonRawMsg json.RawMe if !ok { return errors.New("failed to cast interface into 'chan countAuthorizedWsResponse'") } - dataChan_ <- wsResp + graphql.WriteToChannelOrRecover(dataChan_, wsResp) return nil } @@ -3252,7 +3252,7 @@ func countCloseForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage if !ok { return errors.New("failed to cast interface into 'chan countCloseWsResponse'") } - dataChan_ <- wsResp + graphql.WriteToChannelOrRecover(dataChan_, wsResp) return nil } diff --git a/internal/integration/server/server.go b/internal/integration/server/server.go index 668b4b4d..89745c26 100644 --- a/internal/integration/server/server.go +++ b/internal/integration/server/server.go @@ -164,7 +164,6 @@ func countTo(ctx context.Context, stopCount int) (<-chan int, error) { for counter := range stopCount { select { case <-ctx.Done(): - fmt.Println("ctx done:", ctx.Err()) return default: respChan <- counter From f9a5e5dfc7155ecf59f2511a9f3606cae1aedfe4 Mon Sep 17 00:00:00 2001 From: Harald Nordgren Date: Tue, 28 Oct 2025 23:51:29 +0100 Subject: [PATCH 6/7] dont close channel on Unsubscribe --- graphql/subscription.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/graphql/subscription.go b/graphql/subscription.go index 9d39d791..b7c42a01 100644 --- a/graphql/subscription.go +++ b/graphql/subscription.go @@ -2,7 +2,6 @@ package graphql import ( "fmt" - "reflect" "sync" ) @@ -37,13 +36,8 @@ func (s *subscriptionMap) Unsubscribe(subscriptionID string) error { if !success { return fmt.Errorf("tried to unsubscribe from unknown subscription with ID '%s'", subscriptionID) } - hasBeenUnsubscribed := unsub.hasBeenUnsubscribed unsub.hasBeenUnsubscribed = true s.map_[subscriptionID] = unsub - - if !hasBeenUnsubscribed { - reflect.ValueOf(s.map_[subscriptionID].interfaceChan).Close() - } return nil } From 9f24ffcf6cbcd3b575c30cac47374ca384b8bda8 Mon Sep 17 00:00:00 2001 From: Harald Nordgren Date: Wed, 29 Oct 2025 00:43:23 +0100 Subject: [PATCH 7/7] refactor --- generate/operation.go.tmpl | 2 +- ...tion.graphql-SimpleSubscription.graphql.go | 2 +- graphql/client.go | 2 +- graphql/subscription.go | 21 +++++++++++++++++ graphql/websocket.go | 23 +------------------ internal/integration/generated.go | 6 ++--- 6 files changed, 28 insertions(+), 28 deletions(-) diff --git a/generate/operation.go.tmpl b/generate/operation.go.tmpl index 2d989825..fe0a1080 100644 --- a/generate/operation.go.tmpl +++ b/generate/operation.go.tmpl @@ -74,7 +74,7 @@ func {{.Name}}ForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage) if !ok { return errors.New("failed to cast interface into 'chan {{.Name}}WsResponse'") } - graphql.WriteToChannelOrRecover(dataChan_, wsResp) + graphql.SafeSend(dataChan_, wsResp) return nil } {{end}} diff --git a/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go b/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go index ac40ab67..76621526 100644 --- a/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go +++ b/generate/testdata/snapshots/TestGenerate-SimpleSubscription.graphql-SimpleSubscription.graphql.go @@ -60,7 +60,7 @@ func SimpleSubscriptionForwardData(interfaceChan interface{}, jsonRawMsg json.Ra if !ok { return errors.New("failed to cast interface into 'chan SimpleSubscriptionWsResponse'") } - graphql.WriteToChannelOrRecover(dataChan_, wsResp) + graphql.SafeSend(dataChan_, wsResp) return nil } diff --git a/graphql/client.go b/graphql/client.go index ca583672..f2271ab6 100644 --- a/graphql/client.go +++ b/graphql/client.go @@ -364,7 +364,7 @@ func (c *client) createGetRequest(req *Request) (*http.Request, error) { return httpReq, nil } -func WriteToChannelOrRecover[T any](dataChan_ chan T, wsResp T) { +func SafeSend[T any](dataChan_ chan T, wsResp T) { defer func() { _ = recover() }() diff --git a/graphql/subscription.go b/graphql/subscription.go index b7c42a01..6c241e67 100644 --- a/graphql/subscription.go +++ b/graphql/subscription.go @@ -2,6 +2,7 @@ package graphql import ( "fmt" + "reflect" "sync" ) @@ -55,3 +56,23 @@ func (s *subscriptionMap) Delete(subscriptionID string) { defer s.Unlock() delete(s.map_, subscriptionID) } + +func (s *subscriptionMap) GetOrClose(subscriptionID string, subscriptionType string) (*subscription, error) { + s.Lock() + defer s.Unlock() + sub, success := s.map_[subscriptionID] + if !success { + return nil, fmt.Errorf("received message for unknown subscription ID '%s'", subscriptionID) + } + if sub.hasBeenUnsubscribed { + return nil, nil + } + if subscriptionType == webSocketTypeComplete { + sub.hasBeenUnsubscribed = true + s.map_[subscriptionID] = sub + reflect.ValueOf(sub.interfaceChan).Close() + return nil, nil + } + + return &sub, nil +} diff --git a/graphql/websocket.go b/graphql/websocket.go index 462894ae..abc63295 100644 --- a/graphql/websocket.go +++ b/graphql/websocket.go @@ -6,7 +6,6 @@ import ( "encoding/json" "fmt" "net/http" - "reflect" "strings" "sync" "sync/atomic" @@ -131,26 +130,6 @@ func (w *webSocketClient) listenWebSocket() { } } -func (w *webSocketClient) getSubscriptionOrHandleComplete(subscriptionID string, subscriptionType string) (*subscription, error) { - w.subscriptions.Lock() - defer w.subscriptions.Unlock() - sub, success := w.subscriptions.map_[subscriptionID] - if !success { - return nil, fmt.Errorf("received message for unknown subscription ID '%s'", subscriptionID) - } - if sub.hasBeenUnsubscribed { - return nil, nil - } - if subscriptionType == webSocketTypeComplete { - sub.hasBeenUnsubscribed = true - w.subscriptions.map_[subscriptionID] = sub - reflect.ValueOf(sub.interfaceChan).Close() - return nil, nil - } - - return &sub, nil -} - func (w *webSocketClient) forwardWebSocketData(message []byte) error { var wsMsg webSocketReceiveMessage err := json.Unmarshal(message, &wsMsg) @@ -160,7 +139,7 @@ func (w *webSocketClient) forwardWebSocketData(message []byte) error { if wsMsg.ID == "" { // e.g. keep-alive messages return nil } - sub, err := w.getSubscriptionOrHandleComplete(wsMsg.ID, wsMsg.Type) + sub, err := w.subscriptions.GetOrClose(wsMsg.ID, wsMsg.Type) if err != nil { return err } diff --git a/internal/integration/generated.go b/internal/integration/generated.go index e72a817b..68573354 100644 --- a/internal/integration/generated.go +++ b/internal/integration/generated.go @@ -3156,7 +3156,7 @@ func countForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage) err if !ok { return errors.New("failed to cast interface into 'chan countWsResponse'") } - graphql.WriteToChannelOrRecover(dataChan_, wsResp) + graphql.SafeSend(dataChan_, wsResp) return nil } @@ -3204,7 +3204,7 @@ func countAuthorizedForwardData(interfaceChan interface{}, jsonRawMsg json.RawMe if !ok { return errors.New("failed to cast interface into 'chan countAuthorizedWsResponse'") } - graphql.WriteToChannelOrRecover(dataChan_, wsResp) + graphql.SafeSend(dataChan_, wsResp) return nil } @@ -3252,7 +3252,7 @@ func countCloseForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage if !ok { return errors.New("failed to cast interface into 'chan countCloseWsResponse'") } - graphql.WriteToChannelOrRecover(dataChan_, wsResp) + graphql.SafeSend(dataChan_, wsResp) return nil }