@@ -102,32 +102,58 @@ func (p *Provider) KeyFunc(ctx context.Context) (interface{}, error) {
102
102
// in the background and the existing cached JWKS will be returned until the
103
103
// JWKS cache is updated, or if the request errors then it will be evicted from
104
104
// the cache.
105
+ // The cache is keyed by the issuer's hostname. The synchronousRefresh
106
+ // field determines whether the refresh is done synchronously or asynchronously.
107
+ // This can be set using the WithSynchronousRefresh option.
105
108
type CachingProvider struct {
106
109
* Provider
107
- CacheTTL time.Duration
108
- mu sync.RWMutex
109
- cache map [string ]cachedJWKS
110
- sem semaphore.Weighted
110
+ CacheTTL time.Duration
111
+ mu sync.RWMutex
112
+ cache map [string ]cachedJWKS
113
+ sem * semaphore.Weighted
114
+ synchronousRefresh bool
111
115
}
112
116
113
117
type cachedJWKS struct {
114
118
jwks * jose.JSONWebKeySet
115
119
expiresAt time.Time
116
120
}
117
121
122
+ type CachingProviderOption func (* CachingProvider )
123
+
118
124
// NewCachingProvider builds and returns a new CachingProvider.
119
125
// If cacheTTL is zero then a default value of 1 minute will be used.
120
- func NewCachingProvider (issuerURL * url.URL , cacheTTL time.Duration , opts ... ProviderOption ) * CachingProvider {
126
+ func NewCachingProvider (issuerURL * url.URL , cacheTTL time.Duration , opts ... interface {} ) * CachingProvider {
121
127
if cacheTTL == 0 {
122
128
cacheTTL = 1 * time .Minute
123
129
}
124
130
125
- return & CachingProvider {
126
- Provider : NewProvider (issuerURL , opts ... ),
127
- CacheTTL : cacheTTL ,
128
- cache : map [string ]cachedJWKS {},
129
- sem : * semaphore .NewWeighted (1 ),
131
+ var providerOpts []ProviderOption
132
+ var cachingOpts []CachingProviderOption
133
+
134
+ for _ , opt := range opts {
135
+ switch o := opt .(type ) {
136
+ case ProviderOption :
137
+ providerOpts = append (providerOpts , o )
138
+ case CachingProviderOption :
139
+ cachingOpts = append (cachingOpts , o )
140
+ default :
141
+ panic (fmt .Sprintf ("invalid option type: %T" , o ))
142
+ }
143
+ }
144
+ cp := & CachingProvider {
145
+ Provider : NewProvider (issuerURL , providerOpts ... ),
146
+ CacheTTL : cacheTTL ,
147
+ cache : map [string ]cachedJWKS {},
148
+ sem : semaphore .NewWeighted (1 ),
149
+ synchronousRefresh : false ,
130
150
}
151
+
152
+ for _ , opt := range cachingOpts {
153
+ opt (cp )
154
+ }
155
+
156
+ return cp
131
157
}
132
158
133
159
// KeyFunc adheres to the keyFunc signature that the Validator requires.
@@ -140,18 +166,26 @@ func (c *CachingProvider) KeyFunc(ctx context.Context) (interface{}, error) {
140
166
141
167
if cached , ok := c .cache [issuer ]; ok {
142
168
if time .Now ().After (cached .expiresAt ) && c .sem .TryAcquire (1 ) {
143
- go func () {
169
+ if ! c .synchronousRefresh {
170
+ go func () {
171
+ defer c .sem .Release (1 )
172
+ refreshCtx , cancel := context .WithTimeout (context .Background (), 15 * time .Second )
173
+ defer cancel ()
174
+ _ , err := c .refreshKey (refreshCtx , issuer )
175
+
176
+ if err != nil {
177
+ c .mu .Lock ()
178
+ delete (c .cache , issuer )
179
+ c .mu .Unlock ()
180
+ }
181
+ }()
182
+ c .mu .RUnlock ()
183
+ return cached .jwks , nil
184
+ } else {
185
+ c .mu .RUnlock ()
144
186
defer c .sem .Release (1 )
145
- refreshCtx , cancel := context .WithTimeout (context .Background (), 15 * time .Second )
146
- defer cancel ()
147
- _ , err := c .refreshKey (refreshCtx , issuer )
148
-
149
- if err != nil {
150
- c .mu .Lock ()
151
- delete (c .cache , issuer )
152
- c .mu .Unlock ()
153
- }
154
- }()
187
+ return c .refreshKey (ctx , issuer )
188
+ }
155
189
}
156
190
c .mu .RUnlock ()
157
191
return cached .jwks , nil
@@ -161,6 +195,15 @@ func (c *CachingProvider) KeyFunc(ctx context.Context) (interface{}, error) {
161
195
return c .refreshKey (ctx , issuer )
162
196
}
163
197
198
+ // WithSynchronousRefresh sets whether the CachingProvider blocks on refresh.
199
+ // If set to true, it will block and wait for the refresh to complete.
200
+ // If set to false (default), it will return the cached JWKS and trigger a background refresh.
201
+ func WithSynchronousRefresh (blocking bool ) CachingProviderOption {
202
+ return func (cp * CachingProvider ) {
203
+ cp .synchronousRefresh = blocking
204
+ }
205
+ }
206
+
164
207
func (c * CachingProvider ) refreshKey (ctx context.Context , issuer string ) (interface {}, error ) {
165
208
c .mu .Lock ()
166
209
defer c .mu .Unlock ()
0 commit comments