Skip to content

Commit f8d83f2

Browse files
committed
Use JWT_IDENTITY_CLAIM for encoding too
1 parent df68d65 commit f8d83f2

File tree

5 files changed

+68
-43
lines changed

5 files changed

+68
-43
lines changed

docs/options.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ General Options:
3030
such as ``RS*`` or ``ES*``. PEM format expected.
3131
``JWT_PRIVATE_KEY`` The private key needed for asymmetric based signing algorithms,
3232
such as ``RS*`` or ``ES*``. PEM format expected.
33-
``JWT_IDENTITY_CLAIM`` Claim in the tokens that is used on decoding as source of identity.
33+
``JWT_IDENTITY_CLAIM`` Claim in the tokens that is used as source of identity.
3434
For interoperativity, the JWT RFC recommends using ``'sub'``.
3535
Defaults to ``'identity'``.
3636
================================= =========================================

flask_jwt_extended/jwt_manager.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,8 @@ def create_refresh_token(self, identity, expires_delta=None):
321321
secret=config.encode_key,
322322
algorithm=config.algorithm,
323323
expires_delta=expires_delta,
324-
csrf=config.csrf_protect
324+
csrf=config.csrf_protect,
325+
identity_claim=config.identity_claim
325326
)
326327
return refresh_token
327328

@@ -354,7 +355,8 @@ def create_access_token(self, identity, fresh=False, expires_delta=None):
354355
expires_delta=expires_delta,
355356
fresh=fresh,
356357
user_claims=self._user_claims_callback(identity),
357-
csrf=config.csrf_protect
358+
csrf=config.csrf_protect,
359+
identity_claim=config.identity_claim
358360
)
359361
return access_token
360362

flask_jwt_extended/tokens.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def _encode_jwt(additional_token_data, expires_delta, secret, algorithm):
2525

2626

2727
def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
28-
user_claims, csrf):
28+
user_claims, csrf, identity_claim):
2929
"""
3030
Creates a new encoded (utf-8) access token.
3131
@@ -40,11 +40,12 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
4040
be json serializable
4141
:param csrf: Whether to include a csrf double submit claim in this token
4242
(boolean)
43+
:param identity_claim: Which claim should be used to store the identity in
4344
:return: Encoded access token
4445
"""
4546
# Create the jwt
4647
token_data = {
47-
'identity': identity,
48+
identity_claim: identity,
4849
'fresh': fresh,
4950
'type': 'access',
5051
'user_claims': user_claims,
@@ -54,7 +55,7 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
5455
return _encode_jwt(token_data, expires_delta, secret, algorithm)
5556

5657

57-
def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf):
58+
def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf, identity_claim):
5859
"""
5960
Creates a new encoded (utf-8) refresh token.
6061
@@ -65,10 +66,11 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf):
6566
(datetime.timedelta)
6667
:param csrf: Whether to include a csrf double submit claim in this token
6768
(boolean)
69+
:param identity_claim: Which claim should be used to store the identity in
6870
:return: Encoded refresh token
6971
"""
7072
token_data = {
71-
'identity': identity,
73+
identity_claim: identity,
7274
'type': 'refresh',
7375
}
7476
if csrf:

tests/test_jwt_encode_decode.py

Lines changed: 50 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ def test_encode_access_token(self):
3535
with self.app.test_request_context():
3636
identity = 'user1'
3737
token = encode_access_token(identity, secret, algorithm, token_expire_delta,
38-
fresh=True, user_claims=user_claims, csrf=False)
38+
fresh=True, user_claims=user_claims, csrf=False,
39+
identity_claim='identity')
3940
data = jwt.decode(token, secret, algorithms=[algorithm])
4041
self.assertIn('exp', data)
4142
self.assertIn('iat', data)
@@ -59,7 +60,8 @@ def test_encode_access_token(self):
5960
# Check with a non-fresh token
6061
identity = 12345 # identity can be anything json serializable
6162
token = encode_access_token(identity, secret, algorithm, token_expire_delta,
62-
fresh=False, user_claims=user_claims, csrf=True)
63+
fresh=False, user_claims=user_claims, csrf=True,
64+
identity_claim='identity')
6365
data = jwt.decode(token, secret, algorithms=[algorithm])
6466
self.assertIn('exp', data)
6567
self.assertIn('iat', data)
@@ -87,33 +89,35 @@ def test_encode_invalid_access_token(self):
8789
with self.assertRaises(Exception):
8890
encode_access_token('user1', 'secret', 'HS256',
8991
timedelta(hours=1), True, user_claims,
90-
csrf=True)
92+
csrf=True, identity_claim='identity')
9193

9294
user_claims = {'foo': timedelta(hours=4)}
9395
with self.assertRaises(Exception):
9496
encode_access_token('user1', 'secret', 'HS256',
9597
timedelta(hours=1), True, user_claims,
96-
csrf=True)
98+
csrf=True, identity_claim='identity')
9799

98100
def test_encode_refresh_token(self):
99101
secret = 'super-totally-secret-key'
100102
algorithm = 'HS256'
101103
token_expire_delta = timedelta(minutes=5)
104+
identity_claim = 'sub'
102105

103106
# Check with a fresh token
104107
with self.app.test_request_context():
105108
identity = 'user1'
106109
token = encode_refresh_token(identity, secret, algorithm,
107-
token_expire_delta, csrf=False)
110+
token_expire_delta, csrf=False,
111+
identity_claim=identity_claim)
108112
data = jwt.decode(token, secret, algorithms=[algorithm])
109113
self.assertIn('exp', data)
110114
self.assertIn('iat', data)
111115
self.assertIn('nbf', data)
112116
self.assertIn('jti', data)
113117
self.assertIn('type', data)
114-
self.assertIn('identity', data)
118+
self.assertIn(identity_claim, data)
115119
self.assertNotIn('csrf', data)
116-
self.assertEqual(data['identity'], identity)
120+
self.assertEqual(data[identity_claim], identity)
117121
self.assertEqual(data['type'], 'refresh')
118122
self.assertEqual(data['iat'], data['nbf'])
119123
now_ts = calendar.timegm(datetime.utcnow().utctimetuple())
@@ -124,16 +128,17 @@ def test_encode_refresh_token(self):
124128
# Check with a csrf token
125129
identity = 12345 # identity can be anything json serializable
126130
token = encode_refresh_token(identity, secret, algorithm,
127-
token_expire_delta, csrf=True)
131+
token_expire_delta, csrf=True,
132+
identity_claim=identity_claim)
128133
data = jwt.decode(token, secret, algorithms=[algorithm])
129134
self.assertIn('exp', data)
130135
self.assertIn('iat', data)
131136
self.assertIn('nbf', data)
132137
self.assertIn('jti', data)
133138
self.assertIn('type', data)
134139
self.assertIn('csrf', data)
135-
self.assertIn('identity', data)
136-
self.assertEqual(data['identity'], identity)
140+
self.assertIn(identity_claim, data)
141+
self.assertEqual(data[identity_claim], identity)
137142
self.assertEqual(data['type'], 'refresh')
138143
self.assertEqual(data['iat'], data['nbf'])
139144
now_ts = calendar.timegm(datetime.utcnow().utctimetuple())
@@ -142,6 +147,7 @@ def test_encode_refresh_token(self):
142147
self.assertGreater(exp_seconds, 60 * 4)
143148

144149
def test_decode_jwt(self):
150+
identity_claim = 'sub'
145151
# Test decoding a valid access token
146152
with self.app.test_request_context():
147153
now = datetime.utcnow()
@@ -151,26 +157,27 @@ def test_decode_jwt(self):
151157
'iat': now,
152158
'nbf': now,
153159
'jti': 'banana',
154-
'identity': 'banana',
160+
identity_claim: 'banana',
155161
'fresh': True,
156162
'type': 'access',
157163
'user_claims': {'foo': 'bar'},
158164
}
159165
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
160-
data = decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
166+
data = decode_jwt(encoded_token, 'secret', 'HS256',
167+
csrf=False, identity_claim=identity_claim)
161168
self.assertIn('exp', data)
162169
self.assertIn('iat', data)
163170
self.assertIn('nbf', data)
164171
self.assertIn('jti', data)
165-
self.assertIn('identity', data)
172+
self.assertIn(identity_claim, data)
166173
self.assertIn('fresh', data)
167174
self.assertIn('type', data)
168175
self.assertIn('user_claims', data)
169176
self.assertEqual(data['exp'], now_ts + (5 * 60))
170177
self.assertEqual(data['iat'], now_ts)
171178
self.assertEqual(data['nbf'], now_ts)
172179
self.assertEqual(data['jti'], 'banana')
173-
self.assertEqual(data['identity'], 'banana')
180+
self.assertEqual(data[identity_claim], 'banana')
174181
self.assertEqual(data['fresh'], True)
175182
self.assertEqual(data['type'], 'access')
176183
self.assertEqual(data['user_claims'], {'foo': 'bar'})
@@ -184,22 +191,23 @@ def test_decode_jwt(self):
184191
'iat': now,
185192
'nbf': now,
186193
'jti': 'banana',
187-
'identity': 'banana',
194+
identity_claim: 'banana',
188195
'type': 'refresh',
189196
}
190197
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
191-
data = decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
198+
data = decode_jwt(encoded_token, 'secret', 'HS256',
199+
csrf=False, identity_claim=identity_claim)
192200
self.assertIn('exp', data)
193201
self.assertIn('iat', data)
194202
self.assertIn('nbf', data)
195203
self.assertIn('jti', data)
196-
self.assertIn('identity', data)
204+
self.assertIn(identity_claim, data)
197205
self.assertIn('type', data)
198206
self.assertEqual(data['exp'], now_ts + (5 * 60))
199207
self.assertEqual(data['iat'], now_ts)
200208
self.assertEqual(data['nbf'], now_ts)
201209
self.assertEqual(data['jti'], 'banana')
202-
self.assertEqual(data['identity'], 'banana')
210+
self.assertEqual(data[identity_claim], 'banana')
203211
self.assertEqual(data['type'], 'refresh')
204212

205213
def test_decode_invalid_jwt(self):
@@ -210,7 +218,8 @@ def test_decode_invalid_jwt(self):
210218
'exp': datetime.utcnow() - timedelta(minutes=5),
211219
}
212220
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
213-
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
221+
decode_jwt(encoded_token, 'secret', 'HS256',
222+
csrf=False, identity_claim='identity')
214223

215224
# Missing jti
216225
with self.assertRaises(JWTDecodeError):
@@ -220,7 +229,8 @@ def test_decode_invalid_jwt(self):
220229
'type': 'refresh'
221230
}
222231
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
223-
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
232+
decode_jwt(encoded_token, 'secret', 'HS256',
233+
csrf=False, identity_claim='identity')
224234

225235
# Missing identity
226236
with self.assertRaises(JWTDecodeError):
@@ -230,7 +240,8 @@ def test_decode_invalid_jwt(self):
230240
'type': 'refresh'
231241
}
232242
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
233-
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
243+
decode_jwt(encoded_token, 'secret', 'HS256',
244+
csrf=False, identity_claim='identity')
234245

235246
# Non-matching identity claim
236247
with self.assertRaises(JWTDecodeError):
@@ -240,7 +251,8 @@ def test_decode_invalid_jwt(self):
240251
'type': 'refresh'
241252
}
242253
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
243-
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='sub')
254+
decode_jwt(encoded_token, 'secret', 'HS256',
255+
csrf=False, identity_claim='sub')
244256

245257
# Missing type
246258
with self.assertRaises(JWTDecodeError):
@@ -250,7 +262,8 @@ def test_decode_invalid_jwt(self):
250262
'exp': datetime.utcnow() + timedelta(minutes=5),
251263
}
252264
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
253-
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
265+
decode_jwt(encoded_token, 'secret', 'HS256',
266+
csrf=False, identity_claim='identity')
254267

255268
# Missing fresh in access token
256269
with self.assertRaises(JWTDecodeError):
@@ -262,7 +275,8 @@ def test_decode_invalid_jwt(self):
262275
'user_claims': {}
263276
}
264277
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
265-
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
278+
decode_jwt(encoded_token, 'secret', 'HS256',
279+
csrf=False, identity_claim='identity')
266280

267281
# Missing user claims in access token
268282
with self.assertRaises(JWTDecodeError):
@@ -274,7 +288,8 @@ def test_decode_invalid_jwt(self):
274288
'fresh': True
275289
}
276290
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
277-
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
291+
decode_jwt(encoded_token, 'secret', 'HS256',
292+
csrf=False, identity_claim='identity')
278293

279294
# Bad token type
280295
with self.assertRaises(JWTDecodeError):
@@ -287,7 +302,8 @@ def test_decode_invalid_jwt(self):
287302
'user_claims': 'banana'
288303
}
289304
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
290-
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
305+
decode_jwt(encoded_token, 'secret', 'HS256',
306+
csrf=False, identity_claim='identity')
291307

292308
# Missing csrf in csrf enabled token
293309
with self.assertRaises(JWTDecodeError):
@@ -300,7 +316,8 @@ def test_decode_invalid_jwt(self):
300316
'user_claims': 'banana'
301317
}
302318
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
303-
decode_jwt(encoded_token, 'secret', 'HS256', csrf=True, identity_claim='identity')
319+
decode_jwt(encoded_token, 'secret', 'HS256', csrf=True,
320+
identity_claim='identity')
304321

305322
def test_create_jwt_with_object(self):
306323
# Complex object to test building a JWT from. Normally if you are using
@@ -329,18 +346,19 @@ def user_identity_lookup(user):
329346

330347
# Create the token using the complex object
331348
with app.test_request_context():
349+
identity_claim = 'sub'
350+
app.config['JWT_IDENTITY_CLAIM'] = identity_claim
332351
user = TestUser(username='foo', roles=['bar', 'baz'])
333352
access_token = create_access_token(identity=user)
334353
refresh_token = create_refresh_token(identity=user)
335-
identity = 'identity'
336354

337355
# Decode the tokens and make sure the values are set properly
338356
access_token_data = decode_jwt(access_token, app.secret_key,
339357
app.config['JWT_ALGORITHM'], csrf=False,
340-
identity_claim=identity)
358+
identity_claim=identity_claim)
341359
refresh_token_data = decode_jwt(refresh_token, app.secret_key,
342360
app.config['JWT_ALGORITHM'], csrf=False,
343-
identity_claim=identity)
344-
self.assertEqual(access_token_data[identity], 'foo')
361+
identity_claim=identity_claim)
362+
self.assertEqual(access_token_data[identity_claim], 'foo')
345363
self.assertEqual(access_token_data['user_claims']['roles'], ['bar', 'baz'])
346-
self.assertEqual(refresh_token_data[identity], 'foo')
364+
self.assertEqual(refresh_token_data[identity_claim], 'foo')

tests/test_protected_endpoints.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,8 @@ def test_bad_tokens(self):
331331
# Test token that was signed with a different key
332332
with self.app.test_request_context():
333333
token = encode_access_token('foo', 'newsecret', 'HS256',
334-
timedelta(minutes=5), True, {}, csrf=False)
334+
timedelta(minutes=5), True, {}, csrf=False,
335+
identity_claim='identity')
335336
auth_header = "Bearer {}".format(token)
336337
response = self.client.get('/protected', headers={'Authorization': auth_header})
337338
data = json.loads(response.get_data(as_text=True))
@@ -397,7 +398,7 @@ def test_optional_jwt_bad_tokens(self):
397398
with self.app.test_request_context():
398399
token = encode_access_token('foo', 'newsecret', 'HS256',
399400
timedelta(minutes=5), True, {},
400-
csrf=False)
401+
csrf=False, identity_claim='identity')
401402
auth_header = "Bearer {}".format(token)
402403
response = self.client.get('/partially-protected',
403404
headers={'Authorization': auth_header})
@@ -584,7 +585,8 @@ def test_jwt_with_different_algorithm(self):
584585
expires_delta=timedelta(minutes=5),
585586
fresh=True,
586587
user_claims={},
587-
csrf=False
588+
csrf=False,
589+
identity_claim='identity'
588590
)
589591
status, data = self._jwt_get('/protected', access_token)
590592
self.assertEqual(status, 422)
@@ -600,7 +602,8 @@ def test_optional_jwt_with_different_algorithm(self):
600602
expires_delta=timedelta(minutes=5),
601603
fresh=True,
602604
user_claims={},
603-
csrf=False
605+
csrf=False,
606+
identity_claim='identity'
604607
)
605608
status, data = self._jwt_get('/partially-protected', access_token)
606609
self.assertEqual(status, 422)

0 commit comments

Comments
 (0)