Skip to content

Commit 03aedc3

Browse files
steinitzuvimalloc
authored andcommitted
@decode_key_loader callback added for dynamic decode keys (#191)
1 parent 877e522 commit 03aedc3

File tree

6 files changed

+60
-4
lines changed

6 files changed

+60
-4
lines changed

docs/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ Configuring Flask-JWT-Extended
2424
.. automethod:: user_loader_callback_loader
2525
.. automethod:: user_loader_error_loader
2626
.. automethod:: unauthorized_loader
27+
.. automethod:: decode_key_loader
2728

2829

2930
Protected endpoint decorators

docs/changing_default_behavior.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ and what the return values of your callback functions need to be.
4343
- Function that is called to verify the user_claims data. Must return True or False
4444
* - :meth:`~flask_jwt_extended.JWTManager.claims_verification_failed_loader`
4545
- Function that is called when the user claims verification callback returns False
46+
* - :meth:`~flask_jwt_extended.JWTManager.decode_key_loader`
47+
- Function that is called to load the decode/secret key before verifying a token
4648

4749
Dynamic token expires time
4850
~~~~~~~~~~~~~~~~~~~~~~~~~~

flask_jwt_extended/default_callbacks.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,7 @@ def default_verify_claims_failed_callback():
101101
error message with a 400 status code
102102
"""
103103
return jsonify({config.error_msg_key: 'User claims verification failed'}), 400
104+
105+
106+
def default_decode_key_callback(claims):
107+
return config.decode_key

flask_jwt_extended/jwt_manager.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
default_user_identity_callback, default_invalid_token_callback,
1414
default_unauthorized_callback, default_needs_fresh_token_callback,
1515
default_revoked_token_callback, default_user_loader_error_callback,
16-
default_claims_verification_callback, default_verify_claims_failed_callback
16+
default_claims_verification_callback, default_verify_claims_failed_callback,
17+
default_decode_key_callback
1718
)
1819
from flask_jwt_extended.tokens import (
1920
encode_refresh_token, encode_access_token
@@ -53,6 +54,7 @@ def __init__(self, app=None):
5354
self._token_in_blacklist_callback = None
5455
self._claims_verification_callback = default_claims_verification_callback
5556
self._verify_claims_failed_callback = default_verify_claims_failed_callback
57+
self._decode_key_callback = default_decode_key_callback
5658

5759
# Register this extension with the flask app now (if it is provided)
5860
if app is not None:
@@ -378,6 +380,21 @@ def claims_verification_failed_loader(self, callback):
378380
self._verify_claims_failed_callback = callback
379381
return callback
380382

383+
def decode_key_loader(self, callback):
384+
"""
385+
This decorator sets the callback function for getting the JWT decode key and
386+
can be used to dynamically choose the appropriate decode key based on token
387+
contents.
388+
The default implementation returns the decode key from config (either
389+
`JWT_SECRET_KEY` or `JWT_PUBLIC_KEY` depending on signing algorithm).
390+
391+
*HINT*: The callback function must be a function that takes only **one** argument,
392+
which is a dictionary of the claims encoded in the JWT and must return a *string*
393+
which is the decode key to verify the token.
394+
"""
395+
self._decode_key_callback = callback
396+
return callback
397+
381398
def _create_refresh_token(self, identity, expires_delta=None):
382399
if expires_delta is None:
383400
expires_delta = config.refresh_expires

flask_jwt_extended/utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
RevokedTokenError, UserClaimsVerificationError, WrongTokenError
1212
)
1313
from flask_jwt_extended.tokens import decode_jwt
14+
import jwt
1415

1516

1617
# Proxy to access the current user
@@ -71,9 +72,14 @@ def decode_token(encoded_token, csrf_value=None):
7172
:param encoded_token: The encoded JWT to decode into a python dict.
7273
:param csrf_value: Expected CSRF double submit value (optional)
7374
"""
75+
jwt_manager = _get_jwt_manager()
76+
unverified_claims = jwt.decode(
77+
encoded_token, verify=False, algorithms=config.algorithm
78+
)
79+
secret = jwt_manager._decode_key_callback(unverified_claims)
7480
return decode_jwt(
7581
encoded_token=encoded_token,
76-
secret=config.decode_key,
82+
secret=secret,
7783
algorithm=config.algorithm,
7884
identity_claim_key=config.identity_claim_key,
7985
user_claims_key=config.user_claims_key,

tests/test_decode_tokens.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from datetime import timedelta
44

55
from flask import Flask
6-
from jwt import ExpiredSignatureError
6+
from jwt import ExpiredSignatureError, InvalidSignatureError
77

88
from flask_jwt_extended import (
99
JWTManager, create_access_token, decode_token, create_refresh_token,
@@ -48,7 +48,9 @@ def empty_user_loader_return(identity):
4848
# returned via the decode_token call
4949
with app.test_request_context():
5050
token = create_access_token('username')
51-
pure_decoded = jwt.decode(token, config.decode_key, algorithms=[config.algorithm])
51+
unverfied_claims = jwt.decode(token, verify=False, algorithms=[config.algorithm])
52+
decode_key = jwtM._decode_key_callback(unverfied_claims)
53+
pure_decoded = jwt.decode(token, decode_key, algorithms=[config.algorithm])
5254
assert config.user_claims_key not in pure_decoded
5355
extension_decoded = decode_token(token)
5456
assert config.user_claims_key in extension_decoded
@@ -117,3 +119,27 @@ def test_get_jti(app, default_access_token):
117119

118120
with app.test_request_context():
119121
assert default_access_token['jti'] == get_jti(token)
122+
123+
124+
def test_decode_key_callback(app, default_access_token):
125+
jwtM = get_jwt_manager(app)
126+
app.config['JWT_SECRET_KEY'] = 'correct secret'
127+
128+
@jwtM.decode_key_loader
129+
def get_decode_key_1(claims):
130+
return 'different secret'
131+
132+
assert jwtM._decode_key_callback({}) == 'different secret'
133+
134+
with pytest.raises(InvalidSignatureError):
135+
with app.test_request_context():
136+
token = encode_token(app, default_access_token)
137+
decode_token(token)
138+
139+
@jwtM.decode_key_loader
140+
def get_decode_key_2(claims):
141+
return 'correct secret'
142+
143+
with app.test_request_context():
144+
token = encode_token(app, default_access_token)
145+
decode_token(token)

0 commit comments

Comments
 (0)