Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 744d00d

Browse files
committedApr 4, 2025
Gracefully close websocket connections
1 parent 4ca0dd6 commit 744d00d

File tree

1 file changed

+80
-10
lines changed

1 file changed

+80
-10
lines changed
 

‎graphql/handler/transport/websocket.go

+80-10
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,15 @@ type (
5353
pingPongTicker *time.Ticker
5454
receivedPong bool
5555
exec graphql.GraphExecutor
56-
closed bool
5756
headers http.Header
5857

58+
closeTimeout time.Duration
59+
60+
serverClosed bool
61+
clientClosed bool
62+
63+
clientCloseReceiver chan struct{}
64+
5965
initPayload InitPayload
6066
}
6167

@@ -115,13 +121,14 @@ func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.Graph
115121
}
116122

117123
conn := wsConnection{
118-
active: map[string]context.CancelFunc{},
119-
conn: ws,
120-
ctx: r.Context(),
121-
exec: exec,
122-
me: me,
123-
headers: r.Header,
124-
Websocket: t,
124+
active: map[string]context.CancelFunc{},
125+
conn: ws,
126+
ctx: r.Context(),
127+
exec: exec,
128+
me: me,
129+
headers: r.Header,
130+
Websocket: t,
131+
clientCloseReceiver: make(chan struct{}),
125132
}
126133

127134
if !conn.init() {
@@ -231,6 +238,11 @@ func (c *wsConnection) init() bool {
231238

232239
func (c *wsConnection) write(msg *message) {
233240
c.mu.Lock()
241+
// don't write anything to a closed / closing connection
242+
if c.serverClosed {
243+
c.mu.Unlock()
244+
return
245+
}
234246
c.handlePossibleError(c.me.Send(msg), false)
235247
c.mu.Unlock()
236248
}
@@ -283,6 +295,15 @@ func (c *wsConnection) run() {
283295
go c.closeOnCancel(ctx)
284296

285297
for {
298+
c.mu.Lock()
299+
// dont read any more messages if the server has already closed
300+
// the exception is the close message -- see below
301+
if c.serverClosed {
302+
c.mu.Unlock()
303+
return
304+
}
305+
c.mu.Unlock()
306+
286307
start := graphql.Now()
287308
m, err := c.me.NextMessage()
288309
if err != nil {
@@ -303,7 +324,18 @@ func (c *wsConnection) run() {
303324
if closer != nil {
304325
closer()
305326
}
327+
// possible to receive this close message if connection was not marked as closed in time
328+
// and this thread wins over the goroutine in closeOnCancel
329+
// or for early termination
330+
// notify the closeOnCancel loop
306331
case connectionCloseMessageType:
332+
c.mu.Lock()
333+
c.clientClosed = true
334+
// normal termination
335+
if c.serverClosed {
336+
c.clientCloseReceiver <- struct{}{}
337+
}
338+
c.mu.Unlock()
307339
c.close(websocket.CloseNormalClosure, "terminated")
308340
return
309341
case pingMessageType:
@@ -494,16 +526,54 @@ func (c *wsConnection) sendConnectionError(format string, args ...any) {
494526

495527
func (c *wsConnection) close(closeCode int, message string) {
496528
c.mu.Lock()
497-
if c.closed {
529+
// we already sent our close message and are waiting for the client to send its close message
530+
if c.serverClosed {
498531
c.mu.Unlock()
499532
return
500533
}
534+
535+
// initiate the graceful shutdown server side
501536
_ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
502537
for _, closer := range c.active {
503538
closer()
504539
}
505-
c.closed = true
540+
541+
c.serverClosed = true
542+
506543
c.mu.Unlock()
544+
545+
closeTimeout := c.closeTimeout
546+
if closeTimeout == 0 {
547+
closeTimeout = time.Second * 3
548+
}
549+
closeTimer := time.NewTimer(c.closeTimeout)
550+
551+
// start a new read loop that only processes close messages
552+
go func() {
553+
for {
554+
m, err := c.me.NextMessage()
555+
if err == nil && m.t == connectionCloseMessageType {
556+
c.clientCloseReceiver <- struct{}{}
557+
return
558+
}
559+
// either we get net.ErrClosed or some other error
560+
// either way, bail on the graceful shutdown as it's either
561+
// impossible or very likely not happening
562+
if err != nil {
563+
return
564+
}
565+
}
566+
}()
567+
568+
// wait for the client to send a close message or the timeout
569+
select {
570+
case <-c.clientCloseReceiver:
571+
c.mu.Lock()
572+
c.clientClosed = true
573+
c.mu.Unlock()
574+
case <-closeTimer.C:
575+
}
576+
507577
_ = c.conn.Close()
508578

509579
if c.CloseFunc != nil {

0 commit comments

Comments
 (0)
Please sign in to comment.