@@ -53,9 +53,15 @@ type (
53
53
pingPongTicker * time.Ticker
54
54
receivedPong bool
55
55
exec graphql.GraphExecutor
56
- closed bool
57
56
headers http.Header
58
57
58
+ closeTimeout time.Duration
59
+
60
+ serverClosed bool
61
+ clientClosed bool
62
+
63
+ clientCloseReceiver chan struct {}
64
+
59
65
initPayload InitPayload
60
66
}
61
67
@@ -115,13 +121,14 @@ func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.Graph
115
121
}
116
122
117
123
conn := wsConnection {
118
- active : map [string ]context.CancelFunc {},
119
- conn : ws ,
120
- ctx : r .Context (),
121
- exec : exec ,
122
- me : me ,
123
- headers : r .Header ,
124
- Websocket : t ,
124
+ active : map [string ]context.CancelFunc {},
125
+ conn : ws ,
126
+ ctx : r .Context (),
127
+ exec : exec ,
128
+ me : me ,
129
+ headers : r .Header ,
130
+ Websocket : t ,
131
+ clientCloseReceiver : make (chan struct {}),
125
132
}
126
133
127
134
if ! conn .init () {
@@ -231,6 +238,11 @@ func (c *wsConnection) init() bool {
231
238
232
239
func (c * wsConnection ) write (msg * message ) {
233
240
c .mu .Lock ()
241
+ // don't write anything to a closed / closing connection
242
+ if c .serverClosed {
243
+ c .mu .Unlock ()
244
+ return
245
+ }
234
246
c .handlePossibleError (c .me .Send (msg ), false )
235
247
c .mu .Unlock ()
236
248
}
@@ -283,6 +295,15 @@ func (c *wsConnection) run() {
283
295
go c .closeOnCancel (ctx )
284
296
285
297
for {
298
+ c .mu .Lock ()
299
+ // dont read any more messages if the server has already closed
300
+ // the exception is the close message -- see below
301
+ if c .serverClosed {
302
+ c .mu .Unlock ()
303
+ return
304
+ }
305
+ c .mu .Unlock ()
306
+
286
307
start := graphql .Now ()
287
308
m , err := c .me .NextMessage ()
288
309
if err != nil {
@@ -303,7 +324,18 @@ func (c *wsConnection) run() {
303
324
if closer != nil {
304
325
closer ()
305
326
}
327
+ // possible to receive this close message if connection was not marked as closed in time
328
+ // and this thread wins over the goroutine in closeOnCancel
329
+ // or for early termination
330
+ // notify the closeOnCancel loop
306
331
case connectionCloseMessageType :
332
+ c .mu .Lock ()
333
+ c .clientClosed = true
334
+ // normal termination
335
+ if c .serverClosed {
336
+ c .clientCloseReceiver <- struct {}{}
337
+ }
338
+ c .mu .Unlock ()
307
339
c .close (websocket .CloseNormalClosure , "terminated" )
308
340
return
309
341
case pingMessageType :
@@ -494,16 +526,54 @@ func (c *wsConnection) sendConnectionError(format string, args ...any) {
494
526
495
527
func (c * wsConnection ) close (closeCode int , message string ) {
496
528
c .mu .Lock ()
497
- if c .closed {
529
+ // we already sent our close message and are waiting for the client to send its close message
530
+ if c .serverClosed {
498
531
c .mu .Unlock ()
499
532
return
500
533
}
534
+
535
+ // initiate the graceful shutdown server side
501
536
_ = c .conn .WriteMessage (websocket .CloseMessage , websocket .FormatCloseMessage (closeCode , message ))
502
537
for _ , closer := range c .active {
503
538
closer ()
504
539
}
505
- c .closed = true
540
+
541
+ c .serverClosed = true
542
+
506
543
c .mu .Unlock ()
544
+
545
+ closeTimeout := c .closeTimeout
546
+ if closeTimeout == 0 {
547
+ closeTimeout = time .Second * 3
548
+ }
549
+ closeTimer := time .NewTimer (c .closeTimeout )
550
+
551
+ // start a new read loop that only processes close messages
552
+ go func () {
553
+ for {
554
+ m , err := c .me .NextMessage ()
555
+ if err == nil && m .t == connectionCloseMessageType {
556
+ c .clientCloseReceiver <- struct {}{}
557
+ return
558
+ }
559
+ // either we get net.ErrClosed or some other error
560
+ // either way, bail on the graceful shutdown as it's either
561
+ // impossible or very likely not happening
562
+ if err != nil {
563
+ return
564
+ }
565
+ }
566
+ }()
567
+
568
+ // wait for the client to send a close message or the timeout
569
+ select {
570
+ case <- c .clientCloseReceiver :
571
+ c .mu .Lock ()
572
+ c .clientClosed = true
573
+ c .mu .Unlock ()
574
+ case <- closeTimer .C :
575
+ }
576
+
507
577
_ = c .conn .Close ()
508
578
509
579
if c .CloseFunc != nil {
0 commit comments