@@ -35,7 +35,8 @@ def test_encode_access_token(self):
35
35
with self .app .test_request_context ():
36
36
identity = 'user1'
37
37
token = encode_access_token (identity , secret , algorithm , token_expire_delta ,
38
- fresh = True , user_claims = user_claims , csrf = False )
38
+ fresh = True , user_claims = user_claims , csrf = False ,
39
+ identity_claim = 'identity' )
39
40
data = jwt .decode (token , secret , algorithms = [algorithm ])
40
41
self .assertIn ('exp' , data )
41
42
self .assertIn ('iat' , data )
@@ -59,7 +60,8 @@ def test_encode_access_token(self):
59
60
# Check with a non-fresh token
60
61
identity = 12345 # identity can be anything json serializable
61
62
token = encode_access_token (identity , secret , algorithm , token_expire_delta ,
62
- fresh = False , user_claims = user_claims , csrf = True )
63
+ fresh = False , user_claims = user_claims , csrf = True ,
64
+ identity_claim = 'identity' )
63
65
data = jwt .decode (token , secret , algorithms = [algorithm ])
64
66
self .assertIn ('exp' , data )
65
67
self .assertIn ('iat' , data )
@@ -87,33 +89,35 @@ def test_encode_invalid_access_token(self):
87
89
with self .assertRaises (Exception ):
88
90
encode_access_token ('user1' , 'secret' , 'HS256' ,
89
91
timedelta (hours = 1 ), True , user_claims ,
90
- csrf = True )
92
+ csrf = True , identity_claim = 'identity' )
91
93
92
94
user_claims = {'foo' : timedelta (hours = 4 )}
93
95
with self .assertRaises (Exception ):
94
96
encode_access_token ('user1' , 'secret' , 'HS256' ,
95
97
timedelta (hours = 1 ), True , user_claims ,
96
- csrf = True )
98
+ csrf = True , identity_claim = 'identity' )
97
99
98
100
def test_encode_refresh_token (self ):
99
101
secret = 'super-totally-secret-key'
100
102
algorithm = 'HS256'
101
103
token_expire_delta = timedelta (minutes = 5 )
104
+ identity_claim = 'sub'
102
105
103
106
# Check with a fresh token
104
107
with self .app .test_request_context ():
105
108
identity = 'user1'
106
109
token = encode_refresh_token (identity , secret , algorithm ,
107
- token_expire_delta , csrf = False )
110
+ token_expire_delta , csrf = False ,
111
+ identity_claim = identity_claim )
108
112
data = jwt .decode (token , secret , algorithms = [algorithm ])
109
113
self .assertIn ('exp' , data )
110
114
self .assertIn ('iat' , data )
111
115
self .assertIn ('nbf' , data )
112
116
self .assertIn ('jti' , data )
113
117
self .assertIn ('type' , data )
114
- self .assertIn ('identity' , data )
118
+ self .assertIn (identity_claim , data )
115
119
self .assertNotIn ('csrf' , data )
116
- self .assertEqual (data ['identity' ], identity )
120
+ self .assertEqual (data [identity_claim ], identity )
117
121
self .assertEqual (data ['type' ], 'refresh' )
118
122
self .assertEqual (data ['iat' ], data ['nbf' ])
119
123
now_ts = calendar .timegm (datetime .utcnow ().utctimetuple ())
@@ -124,16 +128,17 @@ def test_encode_refresh_token(self):
124
128
# Check with a csrf token
125
129
identity = 12345 # identity can be anything json serializable
126
130
token = encode_refresh_token (identity , secret , algorithm ,
127
- token_expire_delta , csrf = True )
131
+ token_expire_delta , csrf = True ,
132
+ identity_claim = identity_claim )
128
133
data = jwt .decode (token , secret , algorithms = [algorithm ])
129
134
self .assertIn ('exp' , data )
130
135
self .assertIn ('iat' , data )
131
136
self .assertIn ('nbf' , data )
132
137
self .assertIn ('jti' , data )
133
138
self .assertIn ('type' , data )
134
139
self .assertIn ('csrf' , data )
135
- self .assertIn ('identity' , data )
136
- self .assertEqual (data ['identity' ], identity )
140
+ self .assertIn (identity_claim , data )
141
+ self .assertEqual (data [identity_claim ], identity )
137
142
self .assertEqual (data ['type' ], 'refresh' )
138
143
self .assertEqual (data ['iat' ], data ['nbf' ])
139
144
now_ts = calendar .timegm (datetime .utcnow ().utctimetuple ())
@@ -142,6 +147,7 @@ def test_encode_refresh_token(self):
142
147
self .assertGreater (exp_seconds , 60 * 4 )
143
148
144
149
def test_decode_jwt (self ):
150
+ identity_claim = 'sub'
145
151
# Test decoding a valid access token
146
152
with self .app .test_request_context ():
147
153
now = datetime .utcnow ()
@@ -151,26 +157,27 @@ def test_decode_jwt(self):
151
157
'iat' : now ,
152
158
'nbf' : now ,
153
159
'jti' : 'banana' ,
154
- 'identity' : 'banana' ,
160
+ identity_claim : 'banana' ,
155
161
'fresh' : True ,
156
162
'type' : 'access' ,
157
163
'user_claims' : {'foo' : 'bar' },
158
164
}
159
165
encoded_token = jwt .encode (token_data , 'secret' , 'HS256' ).decode ('utf-8' )
160
- data = decode_jwt (encoded_token , 'secret' , 'HS256' , csrf = False , identity_claim = 'identity' )
166
+ data = decode_jwt (encoded_token , 'secret' , 'HS256' ,
167
+ csrf = False , identity_claim = identity_claim )
161
168
self .assertIn ('exp' , data )
162
169
self .assertIn ('iat' , data )
163
170
self .assertIn ('nbf' , data )
164
171
self .assertIn ('jti' , data )
165
- self .assertIn ('identity' , data )
172
+ self .assertIn (identity_claim , data )
166
173
self .assertIn ('fresh' , data )
167
174
self .assertIn ('type' , data )
168
175
self .assertIn ('user_claims' , data )
169
176
self .assertEqual (data ['exp' ], now_ts + (5 * 60 ))
170
177
self .assertEqual (data ['iat' ], now_ts )
171
178
self .assertEqual (data ['nbf' ], now_ts )
172
179
self .assertEqual (data ['jti' ], 'banana' )
173
- self .assertEqual (data ['identity' ], 'banana' )
180
+ self .assertEqual (data [identity_claim ], 'banana' )
174
181
self .assertEqual (data ['fresh' ], True )
175
182
self .assertEqual (data ['type' ], 'access' )
176
183
self .assertEqual (data ['user_claims' ], {'foo' : 'bar' })
@@ -184,22 +191,23 @@ def test_decode_jwt(self):
184
191
'iat' : now ,
185
192
'nbf' : now ,
186
193
'jti' : 'banana' ,
187
- 'identity' : 'banana' ,
194
+ identity_claim : 'banana' ,
188
195
'type' : 'refresh' ,
189
196
}
190
197
encoded_token = jwt .encode (token_data , 'secret' , 'HS256' ).decode ('utf-8' )
191
- data = decode_jwt (encoded_token , 'secret' , 'HS256' , csrf = False , identity_claim = 'identity' )
198
+ data = decode_jwt (encoded_token , 'secret' , 'HS256' ,
199
+ csrf = False , identity_claim = identity_claim )
192
200
self .assertIn ('exp' , data )
193
201
self .assertIn ('iat' , data )
194
202
self .assertIn ('nbf' , data )
195
203
self .assertIn ('jti' , data )
196
- self .assertIn ('identity' , data )
204
+ self .assertIn (identity_claim , data )
197
205
self .assertIn ('type' , data )
198
206
self .assertEqual (data ['exp' ], now_ts + (5 * 60 ))
199
207
self .assertEqual (data ['iat' ], now_ts )
200
208
self .assertEqual (data ['nbf' ], now_ts )
201
209
self .assertEqual (data ['jti' ], 'banana' )
202
- self .assertEqual (data ['identity' ], 'banana' )
210
+ self .assertEqual (data [identity_claim ], 'banana' )
203
211
self .assertEqual (data ['type' ], 'refresh' )
204
212
205
213
def test_decode_invalid_jwt (self ):
@@ -210,7 +218,8 @@ def test_decode_invalid_jwt(self):
210
218
'exp' : datetime .utcnow () - timedelta (minutes = 5 ),
211
219
}
212
220
encoded_token = jwt .encode (token_data , 'secret' , 'HS256' ).decode ('utf-8' )
213
- decode_jwt (encoded_token , 'secret' , 'HS256' , csrf = False , identity_claim = 'identity' )
221
+ decode_jwt (encoded_token , 'secret' , 'HS256' ,
222
+ csrf = False , identity_claim = 'identity' )
214
223
215
224
# Missing jti
216
225
with self .assertRaises (JWTDecodeError ):
@@ -220,7 +229,8 @@ def test_decode_invalid_jwt(self):
220
229
'type' : 'refresh'
221
230
}
222
231
encoded_token = jwt .encode (token_data , 'secret' , 'HS256' ).decode ('utf-8' )
223
- decode_jwt (encoded_token , 'secret' , 'HS256' , csrf = False , identity_claim = 'identity' )
232
+ decode_jwt (encoded_token , 'secret' , 'HS256' ,
233
+ csrf = False , identity_claim = 'identity' )
224
234
225
235
# Missing identity
226
236
with self .assertRaises (JWTDecodeError ):
@@ -230,7 +240,8 @@ def test_decode_invalid_jwt(self):
230
240
'type' : 'refresh'
231
241
}
232
242
encoded_token = jwt .encode (token_data , 'secret' , 'HS256' ).decode ('utf-8' )
233
- decode_jwt (encoded_token , 'secret' , 'HS256' , csrf = False , identity_claim = 'identity' )
243
+ decode_jwt (encoded_token , 'secret' , 'HS256' ,
244
+ csrf = False , identity_claim = 'identity' )
234
245
235
246
# Non-matching identity claim
236
247
with self .assertRaises (JWTDecodeError ):
@@ -240,7 +251,8 @@ def test_decode_invalid_jwt(self):
240
251
'type' : 'refresh'
241
252
}
242
253
encoded_token = jwt .encode (token_data , 'secret' , 'HS256' ).decode ('utf-8' )
243
- decode_jwt (encoded_token , 'secret' , 'HS256' , csrf = False , identity_claim = 'sub' )
254
+ decode_jwt (encoded_token , 'secret' , 'HS256' ,
255
+ csrf = False , identity_claim = 'sub' )
244
256
245
257
# Missing type
246
258
with self .assertRaises (JWTDecodeError ):
@@ -250,7 +262,8 @@ def test_decode_invalid_jwt(self):
250
262
'exp' : datetime .utcnow () + timedelta (minutes = 5 ),
251
263
}
252
264
encoded_token = jwt .encode (token_data , 'secret' , 'HS256' ).decode ('utf-8' )
253
- decode_jwt (encoded_token , 'secret' , 'HS256' , csrf = False , identity_claim = 'identity' )
265
+ decode_jwt (encoded_token , 'secret' , 'HS256' ,
266
+ csrf = False , identity_claim = 'identity' )
254
267
255
268
# Missing fresh in access token
256
269
with self .assertRaises (JWTDecodeError ):
@@ -262,7 +275,8 @@ def test_decode_invalid_jwt(self):
262
275
'user_claims' : {}
263
276
}
264
277
encoded_token = jwt .encode (token_data , 'secret' , 'HS256' ).decode ('utf-8' )
265
- decode_jwt (encoded_token , 'secret' , 'HS256' , csrf = False , identity_claim = 'identity' )
278
+ decode_jwt (encoded_token , 'secret' , 'HS256' ,
279
+ csrf = False , identity_claim = 'identity' )
266
280
267
281
# Missing user claims in access token
268
282
with self .assertRaises (JWTDecodeError ):
@@ -274,7 +288,8 @@ def test_decode_invalid_jwt(self):
274
288
'fresh' : True
275
289
}
276
290
encoded_token = jwt .encode (token_data , 'secret' , 'HS256' ).decode ('utf-8' )
277
- decode_jwt (encoded_token , 'secret' , 'HS256' , csrf = False , identity_claim = 'identity' )
291
+ decode_jwt (encoded_token , 'secret' , 'HS256' ,
292
+ csrf = False , identity_claim = 'identity' )
278
293
279
294
# Bad token type
280
295
with self .assertRaises (JWTDecodeError ):
@@ -287,7 +302,8 @@ def test_decode_invalid_jwt(self):
287
302
'user_claims' : 'banana'
288
303
}
289
304
encoded_token = jwt .encode (token_data , 'secret' , 'HS256' ).decode ('utf-8' )
290
- decode_jwt (encoded_token , 'secret' , 'HS256' , csrf = False , identity_claim = 'identity' )
305
+ decode_jwt (encoded_token , 'secret' , 'HS256' ,
306
+ csrf = False , identity_claim = 'identity' )
291
307
292
308
# Missing csrf in csrf enabled token
293
309
with self .assertRaises (JWTDecodeError ):
@@ -300,7 +316,8 @@ def test_decode_invalid_jwt(self):
300
316
'user_claims' : 'banana'
301
317
}
302
318
encoded_token = jwt .encode (token_data , 'secret' , 'HS256' ).decode ('utf-8' )
303
- decode_jwt (encoded_token , 'secret' , 'HS256' , csrf = True , identity_claim = 'identity' )
319
+ decode_jwt (encoded_token , 'secret' , 'HS256' , csrf = True ,
320
+ identity_claim = 'identity' )
304
321
305
322
def test_create_jwt_with_object (self ):
306
323
# Complex object to test building a JWT from. Normally if you are using
@@ -329,18 +346,19 @@ def user_identity_lookup(user):
329
346
330
347
# Create the token using the complex object
331
348
with app .test_request_context ():
349
+ identity_claim = 'sub'
350
+ app .config ['JWT_IDENTITY_CLAIM' ] = identity_claim
332
351
user = TestUser (username = 'foo' , roles = ['bar' , 'baz' ])
333
352
access_token = create_access_token (identity = user )
334
353
refresh_token = create_refresh_token (identity = user )
335
- identity = 'identity'
336
354
337
355
# Decode the tokens and make sure the values are set properly
338
356
access_token_data = decode_jwt (access_token , app .secret_key ,
339
357
app .config ['JWT_ALGORITHM' ], csrf = False ,
340
- identity_claim = identity )
358
+ identity_claim = identity_claim )
341
359
refresh_token_data = decode_jwt (refresh_token , app .secret_key ,
342
360
app .config ['JWT_ALGORITHM' ], csrf = False ,
343
- identity_claim = identity )
344
- self .assertEqual (access_token_data [identity ], 'foo' )
361
+ identity_claim = identity_claim )
362
+ self .assertEqual (access_token_data [identity_claim ], 'foo' )
345
363
self .assertEqual (access_token_data ['user_claims' ]['roles' ], ['bar' , 'baz' ])
346
- self .assertEqual (refresh_token_data [identity ], 'foo' )
364
+ self .assertEqual (refresh_token_data [identity_claim ], 'foo' )
0 commit comments