Skip to content

Commit 1b4461d

Browse files
[GH-288] Add Support for WithSynchronousRefresh Option in CachingProvider for Blocking/Non-Blocking Key Refresh (#314)
1 parent f0aafb9 commit 1b4461d

File tree

3 files changed

+139
-22
lines changed

3 files changed

+139
-22
lines changed

error_handler.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ type ErrorHandler func(w http.ResponseWriter, r *http.Request, err error)
2828
// DefaultErrorHandler is the default error handler implementation for the
2929
// JWTMiddleware. If an error handler is not provided via the WithErrorHandler
3030
// option this will be used.
31-
func DefaultErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
31+
func DefaultErrorHandler(w http.ResponseWriter, _ *http.Request, err error) {
3232
w.Header().Set("Content-Type", "application/json")
3333

3434
switch {

jwks/provider.go

+64-21
Original file line numberDiff line numberDiff line change
@@ -102,32 +102,58 @@ func (p *Provider) KeyFunc(ctx context.Context) (interface{}, error) {
102102
// in the background and the existing cached JWKS will be returned until the
103103
// JWKS cache is updated, or if the request errors then it will be evicted from
104104
// the cache.
105+
// The cache is keyed by the issuer's hostname. The synchronousRefresh
106+
// field determines whether the refresh is done synchronously or asynchronously.
107+
// This can be set using the WithSynchronousRefresh option.
105108
type CachingProvider struct {
106109
*Provider
107-
CacheTTL time.Duration
108-
mu sync.RWMutex
109-
cache map[string]cachedJWKS
110-
sem semaphore.Weighted
110+
CacheTTL time.Duration
111+
mu sync.RWMutex
112+
cache map[string]cachedJWKS
113+
sem *semaphore.Weighted
114+
synchronousRefresh bool
111115
}
112116

113117
type cachedJWKS struct {
114118
jwks *jose.JSONWebKeySet
115119
expiresAt time.Time
116120
}
117121

122+
type CachingProviderOption func(*CachingProvider)
123+
118124
// NewCachingProvider builds and returns a new CachingProvider.
119125
// If cacheTTL is zero then a default value of 1 minute will be used.
120-
func NewCachingProvider(issuerURL *url.URL, cacheTTL time.Duration, opts ...ProviderOption) *CachingProvider {
126+
func NewCachingProvider(issuerURL *url.URL, cacheTTL time.Duration, opts ...interface{}) *CachingProvider {
121127
if cacheTTL == 0 {
122128
cacheTTL = 1 * time.Minute
123129
}
124130

125-
return &CachingProvider{
126-
Provider: NewProvider(issuerURL, opts...),
127-
CacheTTL: cacheTTL,
128-
cache: map[string]cachedJWKS{},
129-
sem: *semaphore.NewWeighted(1),
131+
var providerOpts []ProviderOption
132+
var cachingOpts []CachingProviderOption
133+
134+
for _, opt := range opts {
135+
switch o := opt.(type) {
136+
case ProviderOption:
137+
providerOpts = append(providerOpts, o)
138+
case CachingProviderOption:
139+
cachingOpts = append(cachingOpts, o)
140+
default:
141+
panic(fmt.Sprintf("invalid option type: %T", o))
142+
}
143+
}
144+
cp := &CachingProvider{
145+
Provider: NewProvider(issuerURL, providerOpts...),
146+
CacheTTL: cacheTTL,
147+
cache: map[string]cachedJWKS{},
148+
sem: semaphore.NewWeighted(1),
149+
synchronousRefresh: false,
130150
}
151+
152+
for _, opt := range cachingOpts {
153+
opt(cp)
154+
}
155+
156+
return cp
131157
}
132158

133159
// KeyFunc adheres to the keyFunc signature that the Validator requires.
@@ -140,18 +166,26 @@ func (c *CachingProvider) KeyFunc(ctx context.Context) (interface{}, error) {
140166

141167
if cached, ok := c.cache[issuer]; ok {
142168
if time.Now().After(cached.expiresAt) && c.sem.TryAcquire(1) {
143-
go func() {
169+
if !c.synchronousRefresh {
170+
go func() {
171+
defer c.sem.Release(1)
172+
refreshCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
173+
defer cancel()
174+
_, err := c.refreshKey(refreshCtx, issuer)
175+
176+
if err != nil {
177+
c.mu.Lock()
178+
delete(c.cache, issuer)
179+
c.mu.Unlock()
180+
}
181+
}()
182+
c.mu.RUnlock()
183+
return cached.jwks, nil
184+
} else {
185+
c.mu.RUnlock()
144186
defer c.sem.Release(1)
145-
refreshCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
146-
defer cancel()
147-
_, err := c.refreshKey(refreshCtx, issuer)
148-
149-
if err != nil {
150-
c.mu.Lock()
151-
delete(c.cache, issuer)
152-
c.mu.Unlock()
153-
}
154-
}()
187+
return c.refreshKey(ctx, issuer)
188+
}
155189
}
156190
c.mu.RUnlock()
157191
return cached.jwks, nil
@@ -161,6 +195,15 @@ func (c *CachingProvider) KeyFunc(ctx context.Context) (interface{}, error) {
161195
return c.refreshKey(ctx, issuer)
162196
}
163197

198+
// WithSynchronousRefresh sets whether the CachingProvider blocks on refresh.
199+
// If set to true, it will block and wait for the refresh to complete.
200+
// If set to false (default), it will return the cached JWKS and trigger a background refresh.
201+
func WithSynchronousRefresh(blocking bool) CachingProviderOption {
202+
return func(cp *CachingProvider) {
203+
cp.synchronousRefresh = blocking
204+
}
205+
}
206+
164207
func (c *CachingProvider) refreshKey(ctx context.Context, issuer string) (interface{}, error) {
165208
c.mu.Lock()
166209
defer c.mu.Unlock()

jwks/provider_test.go

+74
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,80 @@ func Test_JWKSProvider(t *testing.T) {
240240
assert.Nil(t, cachedJWKS)
241241
}, 1*time.Second, 250*time.Millisecond, "JWKS did not get uncached")
242242
})
243+
t.Run("It only calls the API once when multiple requests come in when using the CachingProvider with expired cache (WithSynchronousRefresh)", func(t *testing.T) {
244+
initialJWKS, err := generateJWKS()
245+
require.NoError(t, err)
246+
atomic.StoreInt32(&requestCount, 0)
247+
248+
provider := NewCachingProvider(testServerURL, 5*time.Minute, WithSynchronousRefresh(true))
249+
provider.cache[testServerURL.Hostname()] = cachedJWKS{
250+
jwks: initialJWKS,
251+
expiresAt: time.Now(),
252+
}
253+
254+
var wg sync.WaitGroup
255+
for i := 0; i < 50; i++ {
256+
wg.Add(1)
257+
go func() {
258+
_, _ = provider.KeyFunc(context.Background())
259+
wg.Done()
260+
}()
261+
}
262+
wg.Wait()
263+
time.Sleep(2 * time.Second)
264+
// No need for Eventually since we're not blocking on refresh.
265+
returnedJWKS, err := provider.KeyFunc(context.Background())
266+
require.NoError(t, err)
267+
assert.True(t, cmp.Equal(expectedJWKS, returnedJWKS))
268+
269+
// Non-blocking behavior may allow extra API calls before the cache updates.
270+
assert.Equal(t, int32(2), atomic.LoadInt32(&requestCount), "only wanted 2 requests (well known and jwks), but we got %d requests", atomic.LoadInt32(&requestCount))
271+
})
272+
273+
t.Run("It only calls the API once when multiple requests come in when using the CachingProvider with no cache (WithSynchronousRefresh)", func(t *testing.T) {
274+
provider := NewCachingProvider(testServerURL, 5*time.Minute, WithSynchronousRefresh(true))
275+
atomic.StoreInt32(&requestCount, 0)
276+
277+
var wg sync.WaitGroup
278+
for i := 0; i < 50; i++ {
279+
wg.Add(1)
280+
go func() {
281+
_, _ = provider.KeyFunc(context.Background())
282+
wg.Done()
283+
}()
284+
}
285+
wg.Wait()
286+
287+
assert.Equal(t, int32(2), atomic.LoadInt32(&requestCount), "only wanted 2 requests (well known and jwks), but we got %d requests")
288+
})
289+
t.Run("It correctly applies both ProviderOptions and CachingProviderOptions when using the CachingProvider without breaking", func(t *testing.T) {
290+
issuerURL, _ := url.Parse("https://example.com")
291+
jwksURL, _ := url.Parse("https://example.com/jwks")
292+
customClient := &http.Client{Timeout: 10 * time.Second}
293+
294+
provider := NewCachingProvider(
295+
issuerURL,
296+
30*time.Second,
297+
WithCustomJWKSURI(jwksURL),
298+
WithCustomClient(customClient),
299+
WithSynchronousRefresh(true),
300+
)
301+
302+
assert.Equal(t, jwksURL, provider.CustomJWKSURI, "CustomJWKSURI should be set correctly")
303+
assert.Equal(t, customClient, provider.Client, "Custom HTTP client should be set correctly")
304+
assert.True(t, provider.synchronousRefresh, "Synchronous refresh should be enabled")
305+
})
306+
t.Run("It panics when an invalid option type is provided when using the CachingProvider", func(t *testing.T) {
307+
issuerURL, _ := url.Parse("https://example.com")
308+
309+
assert.Panics(t, func() {
310+
NewCachingProvider(
311+
issuerURL,
312+
30*time.Second,
313+
"invalid_option",
314+
)
315+
}, "Expected panic when passing an invalid option type")
316+
})
243317
}
244318

245319
func generateJWKS() (*jose.JSONWebKeySet, error) {

0 commit comments

Comments
 (0)