Skip to content

Commit e8dec25

Browse files
committed
ensure websockets persists until done on drain
add e2e for ws beyond queue drain; move sleep to appropriate loc add ref to go issue
1 parent 6265a8e commit e8dec25

File tree

3 files changed

+38
-1
lines changed

3 files changed

+38
-1
lines changed

Diff for: pkg/queue/sharedmain/handlers.go

+13
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"context"
2121
"net"
2222
"net/http"
23+
"sync/atomic"
2324
"time"
2425

2526
"go.uber.org/zap"
@@ -43,6 +44,7 @@ func mainHandler(
4344
prober func() bool,
4445
stats *netstats.RequestStats,
4546
logger *zap.SugaredLogger,
47+
pendingRequests *atomic.Int32,
4648
) (http.Handler, *pkghandler.Drainer) {
4749
target := net.JoinHostPort("127.0.0.1", env.UserPort)
4850

@@ -86,6 +88,8 @@ func mainHandler(
8688

8789
composedHandler = withFullDuplex(composedHandler, env.EnableHTTPFullDuplex, logger)
8890

91+
composedHandler = withRequestCounter(composedHandler, pendingRequests)
92+
8993
drainer := &pkghandler.Drainer{
9094
QuietPeriod: drainSleepDuration,
9195
// Add Activator probe header to the drainer so it can handle probes directly from activator
@@ -100,6 +104,7 @@ func mainHandler(
100104
// Hence we need to have RequestLogHandler be the first one.
101105
composedHandler = requestLogHandler(logger, composedHandler, env)
102106
}
107+
103108
return composedHandler, drainer
104109
}
105110

@@ -139,3 +144,11 @@ func withFullDuplex(h http.Handler, enableFullDuplex bool, logger *zap.SugaredLo
139144
h.ServeHTTP(w, r)
140145
})
141146
}
147+
148+
func withRequestCounter(h http.Handler, pendingRequests *atomic.Int32) http.Handler {
149+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
150+
pendingRequests.Add(1)
151+
defer pendingRequests.Add(-1)
152+
h.ServeHTTP(w, r)
153+
})
154+
}

Diff for: pkg/queue/sharedmain/main.go

+20-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"net/http"
2525
"os"
2626
"strconv"
27+
"sync/atomic"
2728
"time"
2829

2930
"github.com/kelseyhightower/envconfig"
@@ -169,6 +170,8 @@ func Main(opts ...Option) error {
169170
d := Defaults{
170171
Ctx: signals.NewContext(),
171172
}
173+
pendingRequests := atomic.Int32{}
174+
pendingRequests.Store(0)
172175

173176
// Parse the environment.
174177
var env config
@@ -234,7 +237,7 @@ func Main(opts ...Option) error {
234237
// Enable TLS when certificate is mounted.
235238
tlsEnabled := exists(logger, certPath) && exists(logger, keyPath)
236239

237-
mainHandler, drainer := mainHandler(d.Ctx, env, d.Transport, probe, stats, logger)
240+
mainHandler, drainer := mainHandler(d.Ctx, env, d.Transport, probe, stats, logger, &pendingRequests)
238241
adminHandler := adminHandler(d.Ctx, logger, drainer)
239242

240243
// Enable TLS server when activator server certs are mounted.
@@ -304,8 +307,24 @@ func Main(opts ...Option) error {
304307
case <-d.Ctx.Done():
305308
logger.Info("Received TERM signal, attempting to gracefully shutdown servers.")
306309
logger.Infof("Sleeping %v to allow K8s propagation of non-ready state", drainSleepDuration)
310+
time.Sleep(drainSleepDuration)
307311
drainer.Drain()
308312

313+
// Wait on active requests to complete. This is done explictly
314+
// to avoid closing any connections which have been highjacked,
315+
// as in net/http `.Shutdown` would do so ungracefully.
316+
// See https://github.com/golang/go/issues/17721
317+
ticker := time.NewTicker(1 * time.Second)
318+
defer ticker.Stop()
319+
logger.Infof("Drain: waiting for %d pending requests to complete", pendingRequests.Load())
320+
WaitOnPendingRequests:
321+
for range ticker.C {
322+
if pendingRequests.Load() <= 0 {
323+
logger.Infof("Drain: all pending requests completed")
324+
break WaitOnPendingRequests
325+
}
326+
}
327+
309328
for name, srv := range httpServers {
310329
logger.Info("Shutting down server: ", name)
311330
if err := srv.Shutdown(context.Background()); err != nil {

Diff for: test/e2e/websocket_test.go

+5
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,11 @@ func TestWebSocketWithTimeout(t *testing.T) {
322322
idleTimeoutSeconds: 10,
323323
delay: "20",
324324
expectError: true,
325+
}, {
326+
name: "websocket does not drop after queue drain is called at 30s",
327+
timeoutSeconds: 60,
328+
delay: "45",
329+
expectError: false,
325330
}}
326331
for _, tc := range testCases {
327332
t.Run(tc.name, func(t *testing.T) {

0 commit comments

Comments
 (0)