Skip to content

Commit 4701f93

Browse files
authored
Merge pull request #67 from psafont/identity_fix
Fix regressions introduced in 3.1.0
2 parents 81a4363 + f150fe0 commit 4701f93

File tree

6 files changed

+48
-35
lines changed

6 files changed

+48
-35
lines changed

examples/database_blacklist/app.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ def login():
5656
refresh_token = create_refresh_token(identity=username)
5757

5858
# Store the tokens in our store with a status of not currently revoked.
59-
add_token_to_database(access_token)
60-
add_token_to_database(refresh_token)
59+
add_token_to_database(access_token, app.config['JWT_IDENTITY_CLAIM'])
60+
add_token_to_database(refresh_token, app.config['JWT_IDENTITY_CLAIM'])
6161

6262
ret = {
6363
'access_token': access_token,
@@ -72,7 +72,7 @@ def refresh():
7272
# Do the same thing that we did in the login endpoint here
7373
current_user = get_jwt_identity()
7474
access_token = create_access_token(identity=current_user)
75-
add_token_to_database(access_token)
75+
add_token_to_database(access_token, app.config['JWT_IDENTITY_CLAIM'])
7676
return jsonify({'access_token': access_token}), 201
7777

7878
# Provide a way for a user to look at their tokens

examples/database_blacklist/blacklist_helpers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,15 @@ def _epoch_utc_to_datetime(epoch_utc):
1616
return datetime.fromtimestamp(epoch_utc)
1717

1818

19-
def add_token_to_database(encoded_token):
19+
def add_token_to_database(encoded_token, identity_claim):
2020
"""
2121
Adds a new token to the database. It is not revoked when it is added.
22+
:param identity_claim:
2223
"""
2324
decoded_token = decode_token(encoded_token)
2425
jti = decoded_token['jti']
2526
token_type = decoded_token['type']
26-
user_identity = decoded_token['identity']
27+
user_identity = decoded_token[identity_claim]
2728
expires = _epoch_utc_to_datetime(decoded_token['exp'])
2829
revoked = False
2930

flask_jwt_extended/view_decorators.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def jwt_required(fn):
3333
def wrapper(*args, **kwargs):
3434
jwt_data = _decode_jwt_from_request(request_type='access')
3535
ctx_stack.top.jwt = jwt_data
36-
_load_user(jwt_data['identity'])
36+
_load_user(jwt_data[config.identity_claim])
3737
return fn(*args, **kwargs)
3838
return wrapper
3939

@@ -53,7 +53,7 @@ def wrapper(*args, **kwargs):
5353
try:
5454
jwt_data = _decode_jwt_from_request(request_type='access')
5555
ctx_stack.top.jwt = jwt_data
56-
_load_user(jwt_data['identity'])
56+
_load_user(jwt_data[config.identity_claim])
5757
except NoAuthorizationError:
5858
pass
5959
return fn(*args, **kwargs)
@@ -77,7 +77,7 @@ def wrapper(*args, **kwargs):
7777
raise FreshTokenRequired('Fresh token required')
7878

7979
ctx_stack.top.jwt = jwt_data
80-
_load_user(jwt_data['identity'])
80+
_load_user(jwt_data[config.identity_claim])
8181
return fn(*args, **kwargs)
8282
return wrapper
8383

@@ -92,7 +92,7 @@ def jwt_refresh_token_required(fn):
9292
def wrapper(*args, **kwargs):
9393
jwt_data = _decode_jwt_from_request(request_type='refresh')
9494
ctx_stack.top.jwt = jwt_data
95-
_load_user(jwt_data['identity'])
95+
_load_user(jwt_data[config.identity_claim])
9696
return fn(*args, **kwargs)
9797
return wrapper
9898

tests/test_blacklist.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def setUp(self):
1616
self.app = Flask(__name__)
1717
self.app.secret_key = 'super=secret'
1818
self.app.config['JWT_BLACKLIST_ENABLED'] = True
19+
self.app.config['JWT_IDENTITY_CLAIM'] = 'sub'
1920
self.jwt_manager = JWTManager(self.app)
2021
self.client = self.app.test_client()
2122
self.blacklist = set()

tests/test_jwt_encode_decode.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -30,24 +30,25 @@ def test_encode_access_token(self):
3030
algorithm = 'HS256'
3131
token_expire_delta = timedelta(minutes=5)
3232
user_claims = {'foo': 'bar'}
33+
identity_claim = 'identity'
3334

3435
# Check with a fresh token
3536
with self.app.test_request_context():
3637
identity = 'user1'
3738
token = encode_access_token(identity, secret, algorithm, token_expire_delta,
3839
fresh=True, user_claims=user_claims, csrf=False,
39-
identity_claim='identity')
40+
identity_claim=identity_claim)
4041
data = jwt.decode(token, secret, algorithms=[algorithm])
4142
self.assertIn('exp', data)
4243
self.assertIn('iat', data)
4344
self.assertIn('nbf', data)
4445
self.assertIn('jti', data)
45-
self.assertIn('identity', data)
46+
self.assertIn(identity_claim, data)
4647
self.assertIn('fresh', data)
4748
self.assertIn('type', data)
4849
self.assertIn('user_claims', data)
4950
self.assertNotIn('csrf', data)
50-
self.assertEqual(data['identity'], identity)
51+
self.assertEqual(data[identity_claim], identity)
5152
self.assertEqual(data['fresh'], True)
5253
self.assertEqual(data['type'], 'access')
5354
self.assertEqual(data['user_claims'], user_claims)
@@ -61,18 +62,18 @@ def test_encode_access_token(self):
6162
identity = 12345 # identity can be anything json serializable
6263
token = encode_access_token(identity, secret, algorithm, token_expire_delta,
6364
fresh=False, user_claims=user_claims, csrf=True,
64-
identity_claim='identity')
65+
identity_claim=identity_claim)
6566
data = jwt.decode(token, secret, algorithms=[algorithm])
6667
self.assertIn('exp', data)
6768
self.assertIn('iat', data)
6869
self.assertIn('nbf', data)
6970
self.assertIn('jti', data)
70-
self.assertIn('identity', data)
71+
self.assertIn(identity_claim, data)
7172
self.assertIn('fresh', data)
7273
self.assertIn('type', data)
7374
self.assertIn('user_claims', data)
7475
self.assertIn('csrf', data)
75-
self.assertEqual(data['identity'], identity)
76+
self.assertEqual(data[identity_claim], identity)
7677
self.assertEqual(data['fresh'], False)
7778
self.assertEqual(data['type'], 'access')
7879
self.assertEqual(data['user_claims'], user_claims)
@@ -86,16 +87,17 @@ def test_encode_invalid_access_token(self):
8687
# Check with non-serializable json
8788
with self.app.test_request_context():
8889
user_claims = datetime
90+
identity_claim = 'identity'
8991
with self.assertRaises(Exception):
9092
encode_access_token('user1', 'secret', 'HS256',
9193
timedelta(hours=1), True, user_claims,
92-
csrf=True, identity_claim='identity')
94+
csrf=True, identity_claim=identity_claim)
9395

9496
user_claims = {'foo': timedelta(hours=4)}
9597
with self.assertRaises(Exception):
9698
encode_access_token('user1', 'secret', 'HS256',
9799
timedelta(hours=1), True, user_claims,
98-
csrf=True, identity_claim='identity')
100+
csrf=True, identity_claim=identity_claim)
99101

100102
def test_encode_refresh_token(self):
101103
secret = 'super-totally-secret-key'
@@ -212,25 +214,27 @@ def test_decode_jwt(self):
212214

213215
def test_decode_invalid_jwt(self):
214216
with self.app.test_request_context():
217+
identity_claim = 'identity'
215218
# Verify underlying pyjwt expires verification works
216219
with self.assertRaises(jwt.ExpiredSignatureError):
217220
token_data = {
218221
'exp': datetime.utcnow() - timedelta(minutes=5),
219222
}
220223
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
221224
decode_jwt(encoded_token, 'secret', 'HS256',
222-
csrf=False, identity_claim='identity')
225+
csrf=False, identity_claim=identity_claim)
223226

224227
# Missing jti
225228
with self.assertRaises(JWTDecodeError):
229+
226230
token_data = {
227231
'exp': datetime.utcnow() + timedelta(minutes=5),
228-
'identity': 'banana',
232+
identity_claim: 'banana',
229233
'type': 'refresh'
230234
}
231235
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
232236
decode_jwt(encoded_token, 'secret', 'HS256',
233-
csrf=False, identity_claim='identity')
237+
csrf=False, identity_claim=identity_claim)
234238

235239
# Missing identity
236240
with self.assertRaises(JWTDecodeError):
@@ -241,83 +245,85 @@ def test_decode_invalid_jwt(self):
241245
}
242246
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
243247
decode_jwt(encoded_token, 'secret', 'HS256',
244-
csrf=False, identity_claim='identity')
248+
csrf=False, identity_claim=identity_claim)
245249

246250
# Non-matching identity claim
247251
with self.assertRaises(JWTDecodeError):
248252
token_data = {
249253
'exp': datetime.utcnow() + timedelta(minutes=5),
250-
'identity': 'banana',
254+
identity_claim: 'banana',
251255
'type': 'refresh'
252256
}
257+
other_identity_claim = 'sub'
253258
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
259+
self.assertNotEqual(identity_claim, other_identity_claim)
254260
decode_jwt(encoded_token, 'secret', 'HS256',
255-
csrf=False, identity_claim='sub')
261+
csrf=False, identity_claim=other_identity_claim)
256262

257263
# Missing type
258264
with self.assertRaises(JWTDecodeError):
259265
token_data = {
260266
'jti': 'banana',
261-
'identity': 'banana',
267+
identity_claim: 'banana',
262268
'exp': datetime.utcnow() + timedelta(minutes=5),
263269
}
264270
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
265271
decode_jwt(encoded_token, 'secret', 'HS256',
266-
csrf=False, identity_claim='identity')
272+
csrf=False, identity_claim=identity_claim)
267273

268274
# Missing fresh in access token
269275
with self.assertRaises(JWTDecodeError):
270276
token_data = {
271277
'jti': 'banana',
272-
'identity': 'banana',
278+
identity_claim: 'banana',
273279
'exp': datetime.utcnow() + timedelta(minutes=5),
274280
'type': 'access',
275281
'user_claims': {}
276282
}
277283
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
278284
decode_jwt(encoded_token, 'secret', 'HS256',
279-
csrf=False, identity_claim='identity')
285+
csrf=False, identity_claim=identity_claim)
280286

281287
# Missing user claims in access token
282288
with self.assertRaises(JWTDecodeError):
283289
token_data = {
284290
'jti': 'banana',
285-
'identity': 'banana',
291+
identity_claim: 'banana',
286292
'exp': datetime.utcnow() + timedelta(minutes=5),
287293
'type': 'access',
288294
'fresh': True
289295
}
290296
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
291297
decode_jwt(encoded_token, 'secret', 'HS256',
292-
csrf=False, identity_claim='identity')
298+
csrf=False, identity_claim=identity_claim)
293299

294300
# Bad token type
295301
with self.assertRaises(JWTDecodeError):
296302
token_data = {
297303
'jti': 'banana',
298-
'identity': 'banana',
304+
identity_claim: 'banana',
299305
'exp': datetime.utcnow() + timedelta(minutes=5),
300306
'type': 'banana',
301307
'fresh': True,
302308
'user_claims': 'banana'
303309
}
304310
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
305311
decode_jwt(encoded_token, 'secret', 'HS256',
306-
csrf=False, identity_claim='identity')
312+
csrf=False, identity_claim=identity_claim)
307313

308314
# Missing csrf in csrf enabled token
309315
with self.assertRaises(JWTDecodeError):
310316
token_data = {
311317
'jti': 'banana',
312-
'identity': 'banana',
318+
identity_claim: 'banana',
313319
'exp': datetime.utcnow() + timedelta(minutes=5),
314320
'type': 'access',
315321
'fresh': True,
316322
'user_claims': 'banana'
317323
}
318324
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
319325
decode_jwt(encoded_token, 'secret', 'HS256', csrf=True,
320-
identity_claim='identity')
326+
identity_claim=identity_claim)
321327

322328
def test_create_jwt_with_object(self):
323329
# Complex object to test building a JWT from. Normally if you are using

tests/test_protected_endpoints.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def setUp(self):
2222
self.app.config['JWT_ALGORITHM'] = 'HS256'
2323
self.app.config['JWT_ACCESS_TOKEN_EXPIRES'] = timedelta(seconds=1)
2424
self.app.config['JWT_REFRESH_TOKEN_EXPIRES'] = timedelta(seconds=1)
25+
self.app.config['JWT_IDENTITY_CLAIM'] = 'sub'
2526
self.jwt_manager = JWTManager(self.app)
2627
self.client = self.app.test_client()
2728

@@ -454,6 +455,9 @@ def claims():
454455
claims_keys = [claim for claim in jwt]
455456
return jsonify(claims_keys), 200
456457

458+
# Grab custom identity claim
459+
identity_claim = self.app.config['JWT_IDENTITY_CLAIM']
460+
457461
# Login
458462
response = self.client.post('/auth/login')
459463
data = json.loads(response.get_data(as_text=True))
@@ -466,7 +470,7 @@ def claims():
466470
self.assertIn('iat', data)
467471
self.assertIn('nbf', data)
468472
self.assertIn('jti', data)
469-
self.assertIn('identity', data)
473+
self.assertIn(identity_claim, data)
470474
self.assertIn('fresh', data)
471475
self.assertIn('type', data)
472476
self.assertIn('user_claims', data)
@@ -836,12 +840,13 @@ def test_access_endpoints_with_cookie_missing_csrf_field(self):
836840

837841
def test_access_endpoints_with_cookie_csrf_claim_not_string(self):
838842
now = datetime.utcnow()
843+
identity_claim = self.app.config['JWT_IDENTITY_CLAIM']
839844
token_data = {
840845
'exp': now + timedelta(minutes=5),
841846
'iat': now,
842847
'nbf': now,
843848
'jti': 'banana',
844-
'identity': 'banana',
849+
identity_claim: 'banana',
845850
'type': 'refresh',
846851
'csrf': 404
847852
}

0 commit comments

Comments
 (0)