Skip to content

Commit 1e1182a

Browse files
authored
Add callback function for customizing encode key (#193)
Refs #91
1 parent 03aedc3 commit 1e1182a

File tree

5 files changed

+78
-27
lines changed

5 files changed

+78
-27
lines changed

docs/api.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,18 @@ Configuring Flask-JWT-Extended
1414
.. automethod:: init_app
1515
.. automethod:: claims_verification_loader
1616
.. automethod:: claims_verification_failed_loader
17+
.. automethod:: decode_key_loader
18+
.. automethod:: encode_key_loader
1719
.. automethod:: expired_token_loader
1820
.. automethod:: invalid_token_loader
1921
.. automethod:: needs_fresh_token_loader
2022
.. automethod:: revoked_token_loader
2123
.. automethod:: token_in_blacklist_loader
24+
.. automethod:: unauthorized_loader
2225
.. automethod:: user_claims_loader
2326
.. automethod:: user_identity_loader
2427
.. automethod:: user_loader_callback_loader
2528
.. automethod:: user_loader_error_loader
26-
.. automethod:: unauthorized_loader
27-
.. automethod:: decode_key_loader
2829

2930

3031
Protected endpoint decorators

docs/changing_default_behavior.rst

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,28 +23,30 @@ and what the return values of your callback functions need to be.
2323

2424
* - Loader Decorator
2525
- Description
26+
* - :meth:`~flask_jwt_extended.JWTManager.claims_verification_loader`
27+
- Function that is called to verify the user_claims data. Must return True or False
28+
* - :meth:`~flask_jwt_extended.JWTManager.claims_verification_failed_loader`
29+
- Function that is called when the user claims verification callback returns False
30+
* - :meth:`~flask_jwt_extended.JWTManager.decode_key_loader`
31+
- Function that is called to get the decode key before verifying a token
32+
* - :meth:`~flask_jwt_extended.JWTManager.encode_key_loader`
33+
- Function that is called to get the encode key before creating a token
2634
* - :meth:`~flask_jwt_extended.JWTManager.expired_token_loader`
2735
- Function to call when an expired token accesses a protected endpoint
2836
* - :meth:`~flask_jwt_extended.JWTManager.invalid_token_loader`
2937
- Function to call when an invalid token accesses a protected endpoint
30-
* - :meth:`~flask_jwt_extended.JWTManager.unauthorized_loader`
31-
- Function to call when a request with no JWT accesses a protected endpoint
3238
* - :meth:`~flask_jwt_extended.JWTManager.needs_fresh_token_loader`
3339
- Function to call when a non-fresh token accesses a :func:`~flask_jwt_extended.fresh_jwt_required` endpoint
3440
* - :meth:`~flask_jwt_extended.JWTManager.revoked_token_loader`
3541
- Function to call when a revoked token accesses a protected endpoint
42+
* - :meth:`~flask_jwt_extended.JWTManager.token_in_blacklist_loader`
43+
- Function that is called to check if a token has been revoked
44+
* - :meth:`~flask_jwt_extended.JWTManager.unauthorized_loader`
45+
- Function to call when a request with no JWT accesses a protected endpoint
3646
* - :meth:`~flask_jwt_extended.JWTManager.user_loader_callback_loader`
3747
- Function to call to load a user object when token accesses a protected endpoint
3848
* - :meth:`~flask_jwt_extended.JWTManager.user_loader_error_loader`
3949
- Function that is called when the user_loader callback function returns `None`
40-
* - :meth:`~flask_jwt_extended.JWTManager.token_in_blacklist_loader`
41-
- Function that is called to check if a token has been revoked
42-
* - :meth:`~flask_jwt_extended.JWTManager.claims_verification_loader`
43-
- Function that is called to verify the user_claims data. Must return True or False
44-
* - :meth:`~flask_jwt_extended.JWTManager.claims_verification_failed_loader`
45-
- Function that is called when the user claims verification callback returns False
46-
* - :meth:`~flask_jwt_extended.JWTManager.decode_key_loader`
47-
- Function that is called to load the decode/secret key before verifying a token
4850

4951
Dynamic token expires time
5052
~~~~~~~~~~~~~~~~~~~~~~~~~~

flask_jwt_extended/default_callbacks.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,4 +104,16 @@ def default_verify_claims_failed_callback():
104104

105105

106106
def default_decode_key_callback(claims):
107+
"""
108+
By default, the decode key specified via the JWT_SECRET_KEY or
109+
JWT_PUBLIC_KEY settings will be used to decode all tokens
110+
"""
107111
return config.decode_key
112+
113+
114+
def default_encode_key_callback(identity):
115+
"""
116+
By default, the encode key specified via the JWT_SECRET_KEY or
117+
JWT_PRIVATE_KEY settings will be used to encode all tokens
118+
"""
119+
return config.encode_key

flask_jwt_extended/jwt_manager.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
default_unauthorized_callback, default_needs_fresh_token_callback,
1515
default_revoked_token_callback, default_user_loader_error_callback,
1616
default_claims_verification_callback, default_verify_claims_failed_callback,
17-
default_decode_key_callback
17+
default_decode_key_callback, default_encode_key_callback
1818
)
1919
from flask_jwt_extended.tokens import (
2020
encode_refresh_token, encode_access_token
@@ -55,6 +55,7 @@ def __init__(self, app=None):
5555
self._claims_verification_callback = default_claims_verification_callback
5656
self._verify_claims_failed_callback = default_verify_claims_failed_callback
5757
self._decode_key_callback = default_decode_key_callback
58+
self._encode_key_callback = default_encode_key_callback
5859

5960
# Register this extension with the flask app now (if it is provided)
6061
if app is not None:
@@ -385,16 +386,34 @@ def decode_key_loader(self, callback):
385386
This decorator sets the callback function for getting the JWT decode key and
386387
can be used to dynamically choose the appropriate decode key based on token
387388
contents.
388-
The default implementation returns the decode key from config (either
389-
`JWT_SECRET_KEY` or `JWT_PUBLIC_KEY` depending on signing algorithm).
389+
390+
The default implementation returns the decode key specified by
391+
`JWT_SECRET_KEY` or `JWT_PUBLIC_KEY`, depending on the signing algorithm.
390392
391393
*HINT*: The callback function must be a function that takes only **one** argument,
392-
which is a dictionary of the claims encoded in the JWT and must return a *string*
394+
which is the unverified claims of the jwt (dictionary) and must return a *string*
393395
which is the decode key to verify the token.
394396
"""
395397
self._decode_key_callback = callback
396398
return callback
397399

400+
def encode_key_loader(self, callback):
401+
"""
402+
This decorator sets the callback function for getting the JWT encode key and
403+
can be used to dynamically choose the appropriate encode key based on the
404+
token identity.
405+
406+
The default implementation returns the encode key specified by
407+
`JWT_SECRET_KEY` or `JWT_PRIVATE_KEY`, depending on the signing algorithm.
408+
409+
*HINT*: The callback function must be a function that takes only **one**
410+
argument, which is the identity as passed into the create_access_token
411+
or create_refresh_token functions, and must return a *string* which is
412+
the decode key to verify the token.
413+
"""
414+
self._encode_key_callback = callback
415+
return callback
416+
398417
def _create_refresh_token(self, identity, expires_delta=None):
399418
if expires_delta is None:
400419
expires_delta = config.refresh_expires
@@ -406,7 +425,7 @@ def _create_refresh_token(self, identity, expires_delta=None):
406425

407426
refresh_token = encode_refresh_token(
408427
identity=self._user_identity_callback(identity),
409-
secret=config.encode_key,
428+
secret=self._encode_key_callback(identity),
410429
algorithm=config.algorithm,
411430
expires_delta=expires_delta,
412431
user_claims=user_claims,
@@ -423,7 +442,7 @@ def _create_access_token(self, identity, fresh=False, expires_delta=None):
423442

424443
access_token = encode_access_token(
425444
identity=self._user_identity_callback(identity),
426-
secret=config.encode_key,
445+
secret=self._encode_key_callback(identity),
427446
algorithm=config.algorithm,
428447
expires_delta=expires_delta,
429448
fresh=fresh,

tests/test_decode_tokens.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,7 @@ def empty_user_loader_return(identity):
4848
# returned via the decode_token call
4949
with app.test_request_context():
5050
token = create_access_token('username')
51-
unverfied_claims = jwt.decode(token, verify=False, algorithms=[config.algorithm])
52-
decode_key = jwtM._decode_key_callback(unverfied_claims)
53-
pure_decoded = jwt.decode(token, decode_key, algorithms=[config.algorithm])
51+
pure_decoded = jwt.decode(token, config.decode_key, algorithms=[config.algorithm])
5452
assert config.user_claims_key not in pure_decoded
5553
extension_decoded = decode_token(token)
5654
assert config.user_claims_key in extension_decoded
@@ -121,25 +119,44 @@ def test_get_jti(app, default_access_token):
121119
assert default_access_token['jti'] == get_jti(token)
122120

123121

124-
def test_decode_key_callback(app, default_access_token):
122+
def test_encode_decode_callback_values(app, default_access_token):
125123
jwtM = get_jwt_manager(app)
126-
app.config['JWT_SECRET_KEY'] = 'correct secret'
124+
app.config['JWT_SECRET_KEY'] = 'foobarbaz'
125+
with app.test_request_context():
126+
assert jwtM._decode_key_callback({}) == 'foobarbaz'
127+
assert jwtM._encode_key_callback({}) == 'foobarbaz'
127128

128129
@jwtM.decode_key_loader
129130
def get_decode_key_1(claims):
130131
return 'different secret'
131132

133+
@jwtM.encode_key_loader
134+
def get_decode_key_2(identity):
135+
return 'different secret'
136+
132137
assert jwtM._decode_key_callback({}) == 'different secret'
138+
assert jwtM._encode_key_callback('') == 'different secret'
139+
140+
141+
def test_custom_encode_decode_key_callbacks(app, default_access_token):
142+
jwtM = get_jwt_manager(app)
143+
app.config['JWT_SECRET_KEY'] = 'foobarbaz'
144+
145+
@jwtM.encode_key_loader
146+
def get_encode_key_1(identity):
147+
assert identity == 'username'
148+
return 'different secret'
133149

134150
with pytest.raises(InvalidSignatureError):
135151
with app.test_request_context():
136-
token = encode_token(app, default_access_token)
152+
token = create_access_token('username')
137153
decode_token(token)
138154

139155
@jwtM.decode_key_loader
140-
def get_decode_key_2(claims):
141-
return 'correct secret'
156+
def get_decode_key_1(claims):
157+
assert claims['identity'] == 'username'
158+
return 'different secret'
142159

143160
with app.test_request_context():
144-
token = encode_token(app, default_access_token)
161+
token = create_access_token('username')
145162
decode_token(token)

0 commit comments

Comments
 (0)