Skip to content

Commit df68d65

Browse files
committed
Allow changing subject claim when decoding
Related to issue #65
1 parent 17c3254 commit df68d65

File tree

8 files changed

+65
-22
lines changed

8 files changed

+65
-22
lines changed

docs/options.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ General Options:
3030
such as ``RS*`` or ``ES*``. PEM format expected.
3131
``JWT_PRIVATE_KEY`` The private key needed for asymmetric based signing algorithms,
3232
such as ``RS*`` or ``ES*``. PEM format expected.
33+
``JWT_IDENTITY_CLAIM`` Claim in the tokens that is used on decoding as source of identity.
34+
For interoperativity, the JWT RFC recommends using ``'sub'``.
35+
Defaults to ``'identity'``.
3336
================================= =========================================
3437

3538

flask_jwt_extended/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,10 @@ def cookie_max_age(self):
223223
# seconds a long ways in the future
224224
return None if self.session_cookie else 2147483647 # 2^31
225225

226+
@property
227+
def identity_claim(self):
228+
return current_app.config['JWT_IDENTITY_CLAIM']
229+
226230
config = _Config()
227231

228232

flask_jwt_extended/jwt_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@ def _set_default_configuration_options(app):
164164
app.config.setdefault('JWT_BLACKLIST_ENABLED', False)
165165
app.config.setdefault('JWT_BLACKLIST_TOKEN_CHECKS', ['access', 'refresh'])
166166

167+
app.config.setdefault('JWT_IDENTITY_CLAIM', 'identity')
168+
167169
def user_claims_loader(self, callback):
168170
"""
169171
This sets the callback method for adding custom user claims to a JWT.

flask_jwt_extended/tokens.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf):
7676
return _encode_jwt(token_data, expires_delta, secret, algorithm)
7777

7878

79-
def decode_jwt(encoded_token, secret, algorithm, csrf):
79+
def decode_jwt(encoded_token, secret, algorithm, csrf, identity_claim):
8080
"""
8181
Decodes an encoded JWT
8282
@@ -85,6 +85,7 @@ def decode_jwt(encoded_token, secret, algorithm, csrf):
8585
:param algorithm: Algorithm used to encode the JWT
8686
:param csrf: If this token is expected to have a CSRF double submit
8787
value present (boolean)
88+
:param identity_claim: expected claim that is used to identify the subject
8889
:return: Dictionary containing contents of the JWT
8990
"""
9091
# This call verifies the ext, iat, and nbf claims
@@ -93,8 +94,8 @@ def decode_jwt(encoded_token, secret, algorithm, csrf):
9394
# Make sure that any custom claims we expect in the token are present
9495
if 'jti' not in data:
9596
raise JWTDecodeError("Missing claim: jti")
96-
if 'identity' not in data:
97-
raise JWTDecodeError("Missing claim: identity")
97+
if identity_claim not in data:
98+
raise JWTDecodeError("Missing claim: {}".format(identity_claim))
9899
if 'type' not in data or data['type'] not in ('refresh', 'access'):
99100
raise JWTDecodeError("Missing or invalid claim: type")
100101
if data['type'] == 'access':

flask_jwt_extended/utils.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def get_jwt_identity():
2727
Returns the identity of the JWT in this context. If no JWT is present,
2828
None is returned.
2929
"""
30-
return get_raw_jwt().get('identity', None)
30+
return get_raw_jwt().get(config.identity_claim, None)
3131

3232

3333
def get_jwt_claims():
@@ -63,7 +63,8 @@ def decode_token(encoded_token):
6363
encoded_token=encoded_token,
6464
secret=config.decode_key,
6565
algorithm=config.algorithm,
66-
csrf=config.csrf_protect
66+
csrf=config.csrf_protect,
67+
identity_claim=config.identity_claim
6768
)
6869

6970

@@ -106,7 +107,13 @@ def token_in_blacklist(*args, **kwargs):
106107

107108

108109
def get_csrf_token(encoded_token):
109-
token = decode_jwt(encoded_token, config.decode_key, config.algorithm, csrf=True)
110+
token = decode_jwt(
111+
encoded_token,
112+
config.decode_key,
113+
config.algorithm,
114+
csrf=True,
115+
identity_claim=config.identity_claim
116+
)
110117
return token['csrf']
111118

112119

flask_jwt_extended/view_decorators.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,13 @@ def _decode_jwt_from_headers():
144144
raise InvalidHeaderError(msg)
145145
token = parts[1]
146146

147-
return decode_jwt(token, config.decode_key, config.algorithm, csrf=False)
147+
return decode_jwt(
148+
encoded_token=token,
149+
secret=config.decode_key,
150+
algorithm=config.algorithm,
151+
csrf=False,
152+
identity_claim=config.identity_claim
153+
)
148154

149155

150156
def _decode_jwt_from_cookies(request_type):
@@ -163,7 +169,8 @@ def _decode_jwt_from_cookies(request_type):
163169
encoded_token=encoded_token,
164170
secret=config.decode_key,
165171
algorithm=config.algorithm,
166-
csrf=config.csrf_protect
172+
csrf=config.csrf_protect,
173+
identity_claim=config.identity_claim
167174
)
168175

169176
# Verify csrf double submit tokens match if required

tests/test_config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def test_default_configs(self):
5454
self.assertEqual(config.decode_key, self.app.secret_key)
5555
self.assertEqual(config.cookie_max_age, None)
5656

57+
self.assertEqual(config.identity_claim, 'identity')
58+
5759
def test_override_configs(self):
5860
self.app.config['JWT_TOKEN_LOCATION'] = ['cookies']
5961
self.app.config['JWT_HEADER_NAME'] = 'TestHeader'
@@ -86,6 +88,8 @@ def test_override_configs(self):
8688

8789
self.app.secret_key = 'banana'
8890

91+
self.app.config['JWT_IDENTITY_CLAIM'] = 'foo'
92+
8993
with self.app.test_request_context():
9094
self.assertEqual(config.token_location, ['cookies'])
9195
self.assertEqual(config.jwt_in_cookies, True)
@@ -122,6 +126,8 @@ def test_override_configs(self):
122126

123127
self.assertEqual(config.cookie_max_age, 2147483647)
124128

129+
self.assertEqual(config.identity_claim, 'foo')
130+
125131
def test_invalid_config_options(self):
126132
with self.app.test_request_context():
127133
self.app.config['JWT_TOKEN_LOCATION'] = 'banana'

tests/test_jwt_encode_decode.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def test_decode_jwt(self):
157157
'user_claims': {'foo': 'bar'},
158158
}
159159
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
160-
data = decode_jwt(encoded_token, 'secret', 'HS256', csrf=False)
160+
data = decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
161161
self.assertIn('exp', data)
162162
self.assertIn('iat', data)
163163
self.assertIn('nbf', data)
@@ -188,7 +188,7 @@ def test_decode_jwt(self):
188188
'type': 'refresh',
189189
}
190190
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
191-
data = decode_jwt(encoded_token, 'secret', 'HS256', csrf=False)
191+
data = decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
192192
self.assertIn('exp', data)
193193
self.assertIn('iat', data)
194194
self.assertIn('nbf', data)
@@ -210,7 +210,7 @@ def test_decode_invalid_jwt(self):
210210
'exp': datetime.utcnow() - timedelta(minutes=5),
211211
}
212212
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
213-
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False)
213+
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
214214

215215
# Missing jti
216216
with self.assertRaises(JWTDecodeError):
@@ -220,7 +220,7 @@ def test_decode_invalid_jwt(self):
220220
'type': 'refresh'
221221
}
222222
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
223-
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False)
223+
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
224224

225225
# Missing identity
226226
with self.assertRaises(JWTDecodeError):
@@ -230,7 +230,17 @@ def test_decode_invalid_jwt(self):
230230
'type': 'refresh'
231231
}
232232
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
233-
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False)
233+
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
234+
235+
# Non-matching identity claim
236+
with self.assertRaises(JWTDecodeError):
237+
token_data = {
238+
'exp': datetime.utcnow() + timedelta(minutes=5),
239+
'identity': 'banana',
240+
'type': 'refresh'
241+
}
242+
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
243+
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='sub')
234244

235245
# Missing type
236246
with self.assertRaises(JWTDecodeError):
@@ -240,7 +250,7 @@ def test_decode_invalid_jwt(self):
240250
'exp': datetime.utcnow() + timedelta(minutes=5),
241251
}
242252
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
243-
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False)
253+
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
244254

245255
# Missing fresh in access token
246256
with self.assertRaises(JWTDecodeError):
@@ -252,7 +262,7 @@ def test_decode_invalid_jwt(self):
252262
'user_claims': {}
253263
}
254264
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
255-
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False)
265+
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
256266

257267
# Missing user claims in access token
258268
with self.assertRaises(JWTDecodeError):
@@ -264,7 +274,7 @@ def test_decode_invalid_jwt(self):
264274
'fresh': True
265275
}
266276
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
267-
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False)
277+
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
268278

269279
# Bad token type
270280
with self.assertRaises(JWTDecodeError):
@@ -277,7 +287,7 @@ def test_decode_invalid_jwt(self):
277287
'user_claims': 'banana'
278288
}
279289
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
280-
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False)
290+
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
281291

282292
# Missing csrf in csrf enabled token
283293
with self.assertRaises(JWTDecodeError):
@@ -290,7 +300,7 @@ def test_decode_invalid_jwt(self):
290300
'user_claims': 'banana'
291301
}
292302
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
293-
decode_jwt(encoded_token, 'secret', 'HS256', csrf=True)
303+
decode_jwt(encoded_token, 'secret', 'HS256', csrf=True, identity_claim='identity')
294304

295305
def test_create_jwt_with_object(self):
296306
# Complex object to test building a JWT from. Normally if you are using
@@ -322,12 +332,15 @@ def user_identity_lookup(user):
322332
user = TestUser(username='foo', roles=['bar', 'baz'])
323333
access_token = create_access_token(identity=user)
324334
refresh_token = create_refresh_token(identity=user)
335+
identity = 'identity'
325336

326337
# Decode the tokens and make sure the values are set properly
327338
access_token_data = decode_jwt(access_token, app.secret_key,
328-
app.config['JWT_ALGORITHM'], csrf=False)
339+
app.config['JWT_ALGORITHM'], csrf=False,
340+
identity_claim=identity)
329341
refresh_token_data = decode_jwt(refresh_token, app.secret_key,
330-
app.config['JWT_ALGORITHM'], csrf=False)
331-
self.assertEqual(access_token_data['identity'], 'foo')
342+
app.config['JWT_ALGORITHM'], csrf=False,
343+
identity_claim=identity)
344+
self.assertEqual(access_token_data[identity], 'foo')
332345
self.assertEqual(access_token_data['user_claims']['roles'], ['bar', 'baz'])
333-
self.assertEqual(refresh_token_data['identity'], 'foo')
346+
self.assertEqual(refresh_token_data[identity], 'foo')

0 commit comments

Comments
 (0)