Gracefully close websocket connections #3638

Draft · wants to merge 5 commits into base: master
97 changes: 86 additions & 11 deletions graphql/handler/transport/websocket.go
@@ -53,9 +53,15 @@
 	pingPongTicker *time.Ticker
 	receivedPong   bool
 	exec           graphql.GraphExecutor
-	closed         bool
 	headers        http.Header
 
+	closeTimeout time.Duration
+
+	serverClosed bool
+	clientClosed bool
+
+	clientCloseReceiver chan struct{}
+
 	initPayload InitPayload
 }

@@ -115,13 +121,14 @@
 	}
 
 	conn := wsConnection{
-		active:    map[string]context.CancelFunc{},
-		conn:      ws,
-		ctx:       r.Context(),
-		exec:      exec,
-		me:        me,
-		headers:   r.Header,
-		Websocket: t,
+		active:              map[string]context.CancelFunc{},
+		conn:                ws,
+		ctx:                 r.Context(),
+		exec:                exec,
+		me:                  me,
+		headers:             r.Header,
+		Websocket:           t,
+		clientCloseReceiver: make(chan struct{}),
 	}
 
 	if !conn.init() {
@@ -231,8 +238,12 @@

 func (c *wsConnection) write(msg *message) {
 	c.mu.Lock()
+	defer c.mu.Unlock()
+	// don't write anything to a closed / closing connection
+	if c.serverClosed {
+		return
+	}
 	c.handlePossibleError(c.me.Send(msg), false)
-	c.mu.Unlock()
 }
 
 func (c *wsConnection) run() {
@@ -283,6 +294,16 @@
 	go c.closeOnCancel(ctx)
 
 	for {
+		c.mu.Lock()
[Review thread on the line above]

Contributor: This will have performance implications

Author: Yeah totally. I'm open to suggestions as to how we should measure how much this impacts performance, as this is a concern for me and my team as well.

Overall though, if the thesis of this PR is correct (that we are not gracefully handling shutdown), I'm inclined to trade perf off to get us closer to WS spec compliance.

Author: Ideally we have a solution that does both: doesn't impact perf and gracefully shuts down the connection. Again, open to suggestions on how to refactor the implementation to do both.

Contributor: There are multiple ways to do that that would not require a mutex in the loop. Also, I would check for serverClosed after the c.me.NextMessage() call, because the loop blocks there until the next message comes, so you probably want to handle it as soon as you get any message and not on the next iteration.
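A minimal sketch of the kind of refactor being suggested: swap the mutex-guarded bool for an atomic flag and check it right after NextMessage returns, so the loop takes no lock per iteration and reacts to a server-initiated close on the current message rather than the next one. Illustrative only, not code from this PR:

	// sketch: assumes "sync/atomic" is imported and serverClosed becomes
	// an atomic.Bool instead of a mutex-guarded bool
	type wsConnection struct {
		serverClosed atomic.Bool
		// ... remaining fields as in the PR ...
	}

	func (c *wsConnection) run() {
		// ... setup as in the PR ...
		for {
			m, err := c.me.NextMessage()
			if err != nil {
				return
			}
			// check immediately after the blocking read: a close initiated by
			// the server is handled on this message, not on the next iteration
			if c.serverClosed.Load() && m.t != connectionCloseMessageType {
				return
			}
			// ... dispatch m as before ...
		}
	}

The writer side of this sketch would use c.serverClosed.Store(true) in close().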

+		// don't read any more messages if the server has already closed;
+		// the exception is the client close message for graceful shutdown
+		// or an early termination message from the client
+		if c.serverClosed {
+			c.mu.Unlock()
+			return
+		}
+		c.mu.Unlock()
+
 		start := graphql.Now()
 		m, err := c.me.NextMessage()
 		if err != nil {
@@ -303,7 +324,21 @@
 			if closer != nil {
 				closer()
 			}
+		// it's possible to receive this close message here if the connection
+		// was not marked as closed in time and this read loop wins over the
+		// goroutine in closeOnCancel, or on early termination by the client;
+		// notify the closeOnCancel loop
 		case connectionCloseMessageType:
+			c.mu.Lock()
+			c.clientClosed = true
+			// the server already initiated the graceful shutdown, so we
+			// don't need to send another close message
+			if c.serverClosed {
+				c.clientCloseReceiver <- struct{}{}
+				c.mu.Unlock()
+				return
+			}
+			c.mu.Unlock()
 			c.close(websocket.CloseNormalClosure, "terminated")
 			return
 		case pingMessageType:
@@ -494,16 +529,56 @@

 func (c *wsConnection) close(closeCode int, message string) {
 	c.mu.Lock()
-	if c.closed {
+	// we already sent our close message and are waiting for the client to
+	// send its close message back
+	if c.serverClosed {
 		c.mu.Unlock()
 		return
 	}
 
+	// initiate the graceful shutdown server side
 	_ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
 	for _, closer := range c.active {
 		closer()
 	}
-	c.closed = true
+
+	c.serverClosed = true
+
 	c.mu.Unlock()
+
+	closeTimeout := c.closeTimeout
+	if closeTimeout == 0 {
+		closeTimeout = time.Second * 3
[Check failure on line 550 in graphql/handler/transport/websocket.go, GitHub Actions / golangci-lint (1.23): ineffectual assignment to closeTimeout (ineffassign)]
+	}
+	closeTimer := time.NewTimer(c.closeTimeout)
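The ineffassign failure flagged above is presumably because this line constructs the timer from c.closeTimeout rather than the defaulted local closeTimeout; assuming the local value was intended, the likely fix is:

	closeTimer := time.NewTimer(closeTimeout) // use the defaulted local so the 3-second fallback takes effect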

+	// start a new read loop that only processes close messages
+	go func() {
+		for {
+			m, err := c.me.NextMessage()
+			if err == nil && m.t == connectionCloseMessageType {
+				c.clientCloseReceiver <- struct{}{}
+				return
+			}
+			// either we get net.ErrClosed or some other error; either way,
+			// bail on the graceful shutdown, as it's either impossible or
+			// very likely not happening
+			// TODO: optimize this to bypass the select statement to avoid
+			// waiting the entire closeTimeout
+			if err != nil {
+				return
+			}
+		}
+	}()
+
+	// wait for the client to send a close message or the timeout
+	select {
+	case <-c.clientCloseReceiver:
+		c.mu.Lock()
+		c.clientClosed = true
+		c.mu.Unlock()
+	case <-closeTimer.C:
+	}

 	_ = c.conn.Close()
 
 	if c.CloseFunc != nil {
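For context, the shutdown sequence this PR implements is the standard RFC 6455 closing handshake: send a close frame, wait a bounded time for the peer to echo one, then drop the underlying connection. A self-contained sketch of that pattern with gorilla/websocket (the package and helper names and the deadline handling are assumptions for illustration, not this PR's API):

	package wsutil

	import (
		"time"

		"github.com/gorilla/websocket"
	)

	// gracefulClose performs an RFC 6455-style closing handshake from our side.
	func gracefulClose(conn *websocket.Conn, code int, reason string, timeout time.Duration) error {
		deadline := time.Now().Add(timeout)
		// 1. Announce the close to the peer.
		if err := conn.WriteControl(websocket.CloseMessage,
			websocket.FormatCloseMessage(code, reason), deadline); err != nil {
			return conn.Close()
		}
		// 2. Drain frames until the peer's close frame arrives (gorilla surfaces
		//    it as a *websocket.CloseError from NextReader), the read deadline
		//    expires, or the transport errors out.
		_ = conn.SetReadDeadline(deadline)
		for {
			if _, _, err := conn.NextReader(); err != nil {
				break
			}
		}
		// 3. Tear down the underlying network connection.
		return conn.Close()
	}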
92 changes: 92 additions & 0 deletions graphql/handler/transport/websocket_test.go
@@ -835,6 +835,98 @@ func TestWebsocketWithPingPongInterval(t *testing.T) {
 	})
 }

+func TestWebsocketGracefulShutdown(t *testing.T) {
+	t.Run("server gracefully closes connection", func(t *testing.T) {
+		// Create a channel to track server-side events
+		serverEvents := make(chan string, 10)
+
+		// Create a context we can cancel to trigger shutdown
+		ctx, cancel := context.WithCancel(context.Background())
+
+		// Create a custom close handler to track when the connection is fully closed
+		h := testserver.New()
+		h.AddTransport(transport.Websocket{
+			CloseFunc: func(ctx context.Context, closeCode int) {
+				serverEvents <- "connection_fully_closed"
+			},
+		})
+
+		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			// Intercept the handler to track when the server starts handling the request
+			serverEvents <- "server_handling_request"
+
+			// Use our cancellable context
+			r = r.WithContext(ctx)
+
+			h.ServeHTTP(w, r)
+		}))
+		defer srv.Close()
+
+		// Connect client
+		c := wsConnect(srv.URL)
+		defer c.Close()
+
+		// Initialize connection
+		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
+		assert.Equal(t, connectionAckMsg, readOp(c).Type)
+		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
+
+		// Start a subscription to verify it gets terminated
+		require.NoError(t, c.WriteJSON(&operationMessage{
+			Type:    startMsg,
+			ID:      "test_sub",
+			Payload: json.RawMessage(`{"query": "subscription { name }"}`),
+		}))
+
+		// Trigger server shutdown by canceling the context
+		serverEvents <- "server_initiating_close"
+		cancel()
+
+		// Server should send a close message
+		_, _, err := c.ReadMessage()
+		assert.Equal(t, websocket.CloseNormalClosure, err.(*websocket.CloseError).Code)
+
+		// Try to send another operation - the server should ignore it
+		assert.Equal(t, websocket.ErrCloseSent, c.WriteJSON(&operationMessage{
+			Type:    startMsg,
+			ID:      "ignored_operation",
+			Payload: json.RawMessage(`{"query": "query { name }"}`),
+		}))
+
+		// Client acknowledges close
+		closeCode := websocket.CloseNormalClosure
+		closeText := "client acknowledging close"
+		// This returns websocket.ErrCloseSent because gorilla's default close
+		// handler already echoed a close frame when we read the server's close
+		// above (so a close message has still been sent under the hood)
+		assert.Equal(t, websocket.ErrCloseSent, c.WriteControl(
+			websocket.CloseMessage,
+			websocket.FormatCloseMessage(closeCode, closeText),
+			time.Now().Add(time.Second),
+		))
+
+		// Verify server events happened in the correct order
+		require.Equal(t, "server_handling_request", <-serverEvents)
+		require.Equal(t, "server_initiating_close", <-serverEvents)
+		require.Equal(t, "connection_fully_closed", <-serverEvents)
+
+		// Verify no more events
+		select {
+		case event := <-serverEvents:
+			assert.Fail(t, "Unexpected server event", event)
+		case <-time.After(50 * time.Millisecond):
+			// This is expected - no more events
+		}
+
+		// Verify the underlying connection is actually closed by attempting
+		// to read from it; this should fail with a websocket.CloseError
+		_, _, err = c.ReadMessage()
+		require.Error(t, err)
+		closeErr, ok := err.(*websocket.CloseError)
+		require.True(t, ok, "Expected websocket.CloseError, got %T: %v", err, err)
+		assert.Equal(t, websocket.CloseNormalClosure, closeErr.Code)
+	})
+}

 func wsConnect(url string) *websocket.Conn {
 	return wsConnectWithSubprotocol(url, "")
 }