Skip to content

Commit 234aa6e

Browse files
acrossenvimalloc
authored andcommitted
Make token decoding more flexible
Fixes #208
1 parent 6fe88c7 commit 234aa6e

File tree

8 files changed

+128
-28
lines changed

8 files changed

+128
-28
lines changed

docs/options.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ General Options:
4444
``JWT_ERROR_MESSAGE_KEY`` The key of the error message in a JSON error response when using
4545
the default error handlers.
4646
Defaults to ``'msg'``.
47+
``JWT_DECODE_AUDIENCE`` The audience you expect in a JWT when decoding it.
48+
If this option differs from the 'aud' claim in a JWT, the ``'invalid_token_callback'`` is invoked.
49+
Defaults to ``'None'``.
4750
================================= =========================================
4851

4952

flask_jwt_extended/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,5 +278,9 @@ def error_msg_key(self):
278278
def json_encoder(self):
279279
return current_app.json_encoder
280280

281+
@property
282+
def audience(self):
283+
return current_app.config['JWT_DECODE_AUDIENCE']
284+
281285

282286
config = _Config()

flask_jwt_extended/default_callbacks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def default_verify_claims_failed_callback():
103103
return jsonify({config.error_msg_key: 'User claims verification failed'}), 400
104104

105105

106-
def default_decode_key_callback(claims):
106+
def default_decode_key_callback(claims, headers):
107107
"""
108108
By default, the decode key specified via the JWT_SECRET_KEY or
109109
JWT_PUBLIC_KEY settings will be used to decode all tokens

flask_jwt_extended/jwt_manager.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import datetime
22

3-
from jwt import ExpiredSignatureError, InvalidTokenError
3+
from jwt import ExpiredSignatureError, InvalidTokenError, InvalidAudienceError
44

55
from flask_jwt_extended.config import config
66
from flask_jwt_extended.exceptions import (
@@ -108,6 +108,10 @@ def handle_jwt_decode_error(e):
108108
def handle_wrong_token_error(e):
109109
return self._invalid_token_callback(str(e))
110110

111+
@app.errorhandler(InvalidAudienceError)
112+
def handle_invalid_audience_error(e):
113+
return self._invalid_token_callback(str(e))
114+
111115
@app.errorhandler(RevokedTokenError)
112116
def handle_revoked_token_error(e):
113117
return self._revoked_token_callback()
@@ -192,6 +196,7 @@ def _set_default_configuration_options(app):
192196

193197
app.config.setdefault('JWT_IDENTITY_CLAIM', 'identity')
194198
app.config.setdefault('JWT_USER_CLAIMS', 'user_claims')
199+
app.config.setdefault('JWT_DECODE_AUDIENCE', None)
195200

196201
app.config.setdefault('JWT_CLAIMS_IN_REFRESH_TOKEN', False)
197202

@@ -390,9 +395,10 @@ def decode_key_loader(self, callback):
390395
The default implementation returns the decode key specified by
391396
`JWT_SECRET_KEY` or `JWT_PUBLIC_KEY`, depending on the signing algorithm.
392397
393-
*HINT*: The callback function must be a function that takes only **one** argument,
394-
which is the unverified claims of the jwt (dictionary) and must return a *string*
395-
which is the decode key to verify the token.
398+
*HINT*: The callback function should be a function that takes
399+
**two** arguments, which are the unverified claims and headers of the jwt
400+
(dictionaries). The function must return a *string* which is the decode key
401+
in PEM format to verify the token.
396402
"""
397403
self._decode_key_callback = callback
398404
return callback

flask_jwt_extended/tokens.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def _create_csrf_token():
1515

1616
def _encode_jwt(additional_token_data, expires_delta, secret, algorithm,
1717
json_encoder=None):
18-
uid = str(uuid.uuid4())
18+
uid = _create_csrf_token()
1919
now = datetime.datetime.utcnow()
2020
token_data = {
2121
'iat': now,
@@ -113,7 +113,7 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims
113113

114114

115115
def decode_jwt(encoded_token, secret, algorithm, identity_claim_key,
116-
user_claims_key, csrf_value=None):
116+
user_claims_key, csrf_value=None, audience=None):
117117
"""
118118
Decodes an encoded JWT
119119
@@ -123,21 +123,24 @@ def decode_jwt(encoded_token, secret, algorithm, identity_claim_key,
123123
:param identity_claim_key: expected key that contains the identity
124124
:param user_claims_key: expected key that contains the user claims
125125
:param csrf_value: Expected double submit csrf value
126+
:param audience: expected audience in the JWT
126127
:return: Dictionary containing contents of the JWT
127128
"""
128-
# This call verifies the ext, iat, and nbf claims
129-
data = jwt.decode(encoded_token, secret, algorithms=[algorithm])
129+
# This call verifies the ext, iat, nbf, and aud claims
130+
data = jwt.decode(encoded_token, secret, algorithms=[algorithm], audience=audience)
130131

131132
# Make sure that any custom claims we expect in the token are present
132133
if 'jti' not in data:
133-
raise JWTDecodeError("Missing claim: jti")
134+
data['jti'] = None
134135
if identity_claim_key not in data:
135136
raise JWTDecodeError("Missing claim: {}".format(identity_claim_key))
136-
if 'type' not in data or data['type'] not in ('refresh', 'access'):
137+
if 'type' not in data:
138+
data['type'] = 'access'
139+
if data['type'] not in ('refresh', 'access'):
137140
raise JWTDecodeError("Missing or invalid claim: type")
138141
if data['type'] == 'access':
139142
if 'fresh' not in data:
140-
raise JWTDecodeError("Missing claim: fresh")
143+
data['fresh'] = False
141144
if user_claims_key not in data:
142145
data[user_claims_key] = {}
143146
if csrf_value:

flask_jwt_extended/utils.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from flask import current_app
22
from werkzeug.local import LocalProxy
3+
from warnings import warn
34

45
try:
56
from flask import _app_ctx_stack as ctx_stack
@@ -76,14 +77,26 @@ def decode_token(encoded_token, csrf_value=None):
7677
unverified_claims = jwt.decode(
7778
encoded_token, verify=False, algorithms=config.algorithm
7879
)
79-
secret = jwt_manager._decode_key_callback(unverified_claims)
80+
unverified_headers = jwt.get_unverified_header(encoded_token)
81+
# Attempt to call callback with both claims and headers, but fallback to just claims
82+
# for backwards compatibility
83+
try:
84+
secret = jwt_manager._decode_key_callback(unverified_claims, unverified_headers)
85+
except TypeError:
86+
msg = (
87+
"The single-argument (unverified_claims) form of decode_key_callback is deprecated. "
88+
"Update your code to use the two-argument form (unverified_claims, unverified_headers)."
89+
)
90+
warn(msg, DeprecationWarning)
91+
secret = jwt_manager._decode_key_callback(unverified_claims)
8092
return decode_jwt(
8193
encoded_token=encoded_token,
8294
secret=secret,
8395
algorithm=config.algorithm,
8496
identity_claim_key=config.identity_claim_key,
8597
user_claims_key=config.user_claims_key,
86-
csrf_value=csrf_value
98+
csrf_value=csrf_value,
99+
audience=config.audience
87100
)
88101

89102

tests/test_decode_tokens.py

Lines changed: 60 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import jwt
22
import pytest
3-
from datetime import timedelta
3+
from datetime import datetime, timedelta
4+
import warnings
45

56
from flask import Flask
6-
from jwt import ExpiredSignatureError, InvalidSignatureError
7+
from jwt import ExpiredSignatureError, InvalidSignatureError, InvalidAudienceError
78

89
from flask_jwt_extended import (
910
JWTManager, create_access_token, decode_token, create_refresh_token,
@@ -54,16 +55,29 @@ def empty_user_loader_return(identity):
5455
assert config.user_claims_key in extension_decoded
5556

5657

57-
@pytest.mark.parametrize("missing_claim", ['jti', 'type', 'identity', 'fresh', 'csrf'])
58-
def test_missing_jti_claim(app, default_access_token, missing_claim):
59-
del default_access_token[missing_claim]
58+
@pytest.mark.parametrize("missing_claims", ['identity', 'csrf'])
59+
def test_missing_claims(app, default_access_token, missing_claims):
60+
del default_access_token[missing_claims]
6061
missing_jwt_token = encode_token(app, default_access_token)
6162

6263
with pytest.raises(JWTDecodeError):
6364
with app.test_request_context():
6465
decode_token(missing_jwt_token, csrf_value='abcd')
6566

6667

68+
def test_default_decode_token_values(app, default_access_token):
69+
del default_access_token['type']
70+
del default_access_token['jti']
71+
del default_access_token['fresh']
72+
token = encode_token(app, default_access_token)
73+
74+
with app.test_request_context():
75+
decoded = decode_token(token)
76+
assert decoded['type'] == 'access'
77+
assert decoded['jti'] is None
78+
assert decoded['fresh'] is False
79+
80+
6781
def test_bad_token_type(app, default_access_token):
6882
default_access_token['type'] = 'banana'
6983
bad_type_token = encode_token(app, default_access_token)
@@ -123,19 +137,36 @@ def test_encode_decode_callback_values(app, default_access_token):
123137
jwtM = get_jwt_manager(app)
124138
app.config['JWT_SECRET_KEY'] = 'foobarbaz'
125139
with app.test_request_context():
126-
assert jwtM._decode_key_callback({}) == 'foobarbaz'
140+
assert jwtM._decode_key_callback({}, {}) == 'foobarbaz'
127141
assert jwtM._encode_key_callback({}) == 'foobarbaz'
128142

129-
@jwtM.decode_key_loader
130-
def get_decode_key_1(claims):
143+
@jwtM.encode_key_loader
144+
def get_encode_key_1(identity):
131145
return 'different secret'
146+
assert jwtM._encode_key_callback('') == 'different secret'
132147

133-
@jwtM.encode_key_loader
134-
def get_decode_key_2(identity):
148+
@jwtM.decode_key_loader
149+
def get_decode_key_1(claims, headers):
135150
return 'different secret'
151+
assert jwtM._decode_key_callback({}, {}) == 'different secret'
136152

137-
assert jwtM._decode_key_callback({}) == 'different secret'
138-
assert jwtM._encode_key_callback('') == 'different secret'
153+
154+
def test_legacy_decode_key_callback(app, default_access_token):
155+
jwtM = get_jwt_manager(app)
156+
app.config['JWT_SECRET_KEY'] = 'foobarbaz'
157+
158+
# test decode key callback with one argument (backwards compatibility)
159+
with warnings.catch_warnings(record=True) as w:
160+
warnings.simplefilter("always")
161+
162+
@jwtM.decode_key_loader
163+
def get_decode_key_legacy(claims):
164+
return 'foobarbaz'
165+
with app.test_request_context():
166+
token = encode_token(app, default_access_token)
167+
decode_token(token)
168+
assert len(w) == 1
169+
assert issubclass(w[-1].category, DeprecationWarning)
139170

140171

141172
def test_custom_encode_decode_key_callbacks(app, default_access_token):
@@ -157,7 +188,7 @@ def get_encode_key_1(identity):
157188
decode_token(token)
158189

159190
@jwtM.decode_key_loader
160-
def get_decode_key_1(claims):
191+
def get_decode_key_1(claims, headers):
161192
assert claims['identity'] == 'username'
162193
return 'different secret'
163194

@@ -166,3 +197,19 @@ def get_decode_key_1(claims):
166197
decode_token(token)
167198
token = create_refresh_token('username')
168199
decode_token(token)
200+
201+
202+
def test_valid_aud(app, default_access_token):
203+
app.config['JWT_DECODE_AUDIENCE'] = 'foo'
204+
205+
default_access_token['aud'] = 'bar'
206+
invalid_token = encode_token(app, default_access_token)
207+
with pytest.raises(InvalidAudienceError):
208+
with app.test_request_context():
209+
decode_token(invalid_token)
210+
211+
default_access_token['aud'] = 'foo'
212+
valid_token = encode_token(app, default_access_token)
213+
with app.test_request_context():
214+
decoded = decode_token(valid_token)
215+
assert decoded['aud'] == 'foo'

tests/test_view_decorators.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,31 @@ def test_jwt_missing_claims(app):
207207

208208
response = test_client.get(url, headers=make_headers(token))
209209
assert response.status_code == 422
210-
assert response.get_json() == {'msg': 'Missing claim: jti'}
210+
assert response.get_json() == {'msg': 'Missing claim: identity'}
211+
212+
213+
def test_jwt_invalid_audience(app):
214+
url = '/protected'
215+
jwtM = get_jwt_manager(app)
216+
test_client = app.test_client()
217+
218+
# No audience claim expected or provided - OK
219+
access_token = encode_token(app, {'identity': 'me'})
220+
response = test_client.get(url, headers=make_headers(access_token))
221+
assert response.status_code == 200
222+
223+
# Audience claim expected and not provided - not OK
224+
app.config['JWT_DECODE_AUDIENCE'] = 'my_audience'
225+
access_token = encode_token(app, {'identity': 'me'})
226+
response = test_client.get(url, headers=make_headers(access_token))
227+
assert response.status_code == 422
228+
assert response.get_json() == {'msg': 'Token is missing the "aud" claim'}
229+
230+
# Audience claim still expected and wrong one provided - not OK
231+
access_token = encode_token(app, {'aud': 'different_audience', 'identity': 'me'})
232+
response = test_client.get(url, headers=make_headers(access_token))
233+
assert response.status_code == 422
234+
assert response.get_json() == {'msg': 'Invalid audience'}
211235

212236

213237
def test_expired_token(app):

0 commit comments

Comments
 (0)