From 95ff7234c8133b41a3f74137a3933b2036cc1c1e Mon Sep 17 00:00:00 2001 From: jeremylaier-tc <116579872+jeremylaier-tc@users.noreply.github.com> Date: Fri, 22 Nov 2024 10:56:16 -0500 Subject: [PATCH] ws fix (#3383) --- graphql/handler/transport/websocket.go | 4 ++ graphql/handler/transport/websocket_test.go | 42 +++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/graphql/handler/transport/websocket.go b/graphql/handler/transport/websocket.go index 32e31c7c75d..d6c174cd1fb 100644 --- a/graphql/handler/transport/websocket.go +++ b/graphql/handler/transport/websocket.go @@ -54,6 +54,7 @@ type ( receivedPong bool exec graphql.GraphExecutor closed bool + headers http.Header initPayload InitPayload } @@ -119,6 +120,7 @@ func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.Graph ctx: r.Context(), exec: exec, me: me, + headers: r.Header, Websocket: t, } @@ -387,6 +389,8 @@ func (c *wsConnection) subscribe(start time.Time, msg *message) { End: graphql.Now(), } + params.Headers = c.headers + rc, err := c.exec.CreateOperationContext(ctx, params) if err != nil { resp := c.exec.DispatchError(graphql.WithOperationContext(ctx, rc), err) diff --git a/graphql/handler/transport/websocket_test.go b/graphql/handler/transport/websocket_test.go index ec47c422d9e..cc02018053d 100644 --- a/graphql/handler/transport/websocket_test.go +++ b/graphql/handler/transport/websocket_test.go @@ -189,6 +189,48 @@ func TestWebsocketWithKeepAlive(t *testing.T) { assert.Equal(t, connectionKeepAliveMsg, msg.Type) } +func TestWebsocketWithPassedHeaders(t *testing.T) { + h := testserver.New() + h.AddTransport(transport.Websocket{ + KeepAlivePingInterval: 100 * time.Millisecond, + }) + + h.AroundOperations(func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler { + assert.NotNil(t, graphql.GetOperationContext(ctx).Headers) + + return next(ctx) + }) + + srv := httptest.NewServer(h) + defer srv.Close() + + c := wsConnect(srv.URL) + defer c.Close() + + require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) + assert.Equal(t, connectionAckMsg, readOp(c).Type) + assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) + + require.NoError(t, c.WriteJSON(&operationMessage{ + Type: startMsg, + ID: "test_1", + Payload: json.RawMessage(`{"query": "subscription { name }"}`), + })) + + // keepalive + msg := readOp(c) + assert.Equal(t, connectionKeepAliveMsg, msg.Type) + + // server message + h.SendNextSubscriptionMessage() + msg = readOp(c) + assert.Equal(t, dataMsg, msg.Type) + + // keepalive + msg = readOp(c) + assert.Equal(t, connectionKeepAliveMsg, msg.Type) +} + func TestWebsocketInitFunc(t *testing.T) { t.Run("accept connection if WebsocketInitFunc is NOT provided", func(t *testing.T) { h := testserver.New()