Skip to content

Commit 29b3f5a

Browse files
committed
Internal refactoring
1 parent a0bc992 commit 29b3f5a

File tree

7 files changed

+66
-83
lines changed

7 files changed

+66
-83
lines changed

flask_jwt_extended/config.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ class _Config(object):
2323
should be done with flasks ```app.config```.
2424
2525
Default values for the configuration options are set in the jwt_manager
26-
object. All of these values are read only.
26+
object. All of these values are read only. This is simply a loose wrapper
27+
with some helper functionality for flasks `app.config`.
2728
"""
2829

2930
@property
@@ -224,11 +225,11 @@ def cookie_max_age(self):
224225
return None if self.session_cookie else 2147483647 # 2^31
225226

226227
@property
227-
def identity_claim(self):
228+
def identity_claim_key(self):
228229
return current_app.config['JWT_IDENTITY_CLAIM']
229230

230231
@property
231-
def user_claims(self):
232+
def user_claims_key(self):
232233
return current_app.config['JWT_USER_CLAIMS']
233234

234235
config = _Config()

flask_jwt_extended/jwt_manager.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ def _create_refresh_token(self, identity, expires_delta=None):
381381
algorithm=config.algorithm,
382382
expires_delta=expires_delta,
383383
csrf=config.csrf_protect,
384-
identity_claim=config.identity_claim
384+
identity_claim_key=config.identity_claim_key,
385385
)
386386
return refresh_token
387387

@@ -397,7 +397,8 @@ def _create_access_token(self, identity, fresh=False, expires_delta=None):
397397
fresh=fresh,
398398
user_claims=self._user_claims_callback(identity),
399399
csrf=config.csrf_protect,
400-
identity_claim=config.identity_claim
400+
identity_claim_key=config.identity_claim_key,
401+
user_claims_key=config.user_claims_key
401402
)
402403
return access_token
403404

flask_jwt_extended/tokens.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
import uuid
33

44
import jwt
5+
from werkzeug.security import safe_str_cmp
56

6-
from flask_jwt_extended.exceptions import JWTDecodeError
7-
from flask_jwt_extended.config import config
7+
from flask_jwt_extended.exceptions import JWTDecodeError, CSRFError
88

99

1010
def _create_csrf_token():
@@ -26,7 +26,7 @@ def _encode_jwt(additional_token_data, expires_delta, secret, algorithm):
2626

2727

2828
def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
29-
user_claims, csrf, identity_claim):
29+
user_claims, csrf, identity_claim_key, user_claims_key):
3030
"""
3131
Creates a new encoded (utf-8) access token.
3232
@@ -41,26 +41,27 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
4141
be json serializable
4242
:param csrf: Whether to include a csrf double submit claim in this token
4343
(boolean)
44-
:param identity_claim: Which claim should be used to store the identity in
44+
:param identity_claim_key: Which key should be used to store the identity
45+
:param user_claims_key: Which key should be used to store the user claims
4546
:return: Encoded access token
4647
"""
47-
# Create the jwt
4848
token_data = {
49-
identity_claim: identity,
49+
identity_claim_key: identity,
5050
'fresh': fresh,
5151
'type': 'access',
5252
}
5353

54-
# Add `user_claims` only is not empty or None.
54+
# Don't add extra data to the token if user_claims is empty.
5555
if user_claims:
56-
token_data[config.user_claims] = user_claims
56+
token_data[user_claims_key] = user_claims
5757

5858
if csrf:
5959
token_data['csrf'] = _create_csrf_token()
6060
return _encode_jwt(token_data, expires_delta, secret, algorithm)
6161

6262

63-
def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf, identity_claim):
63+
def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf,
64+
identity_claim_key):
6465
"""
6566
Creates a new encoded (utf-8) refresh token.
6667
@@ -71,28 +72,29 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf, ident
7172
(datetime.timedelta)
7273
:param csrf: Whether to include a csrf double submit claim in this token
7374
(boolean)
74-
:param identity_claim: Which claim should be used to store the identity in
75+
:param identity_claim_key: Which key should be used to store the identity
7576
:return: Encoded refresh token
7677
"""
7778
token_data = {
78-
identity_claim: identity,
79+
identity_claim_key: identity,
7980
'type': 'refresh',
8081
}
8182
if csrf:
8283
token_data['csrf'] = _create_csrf_token()
8384
return _encode_jwt(token_data, expires_delta, secret, algorithm)
8485

8586

86-
def decode_jwt(encoded_token, secret, algorithm, csrf, identity_claim):
87+
def decode_jwt(encoded_token, secret, algorithm, identity_claim_key,
88+
user_claims_key, csrf_value=None):
8789
"""
8890
Decodes an encoded JWT
8991
9092
:param encoded_token: The encoded JWT string to decode
9193
:param secret: Secret key used to encode the JWT
9294
:param algorithm: Algorithm used to encode the JWT
93-
:param csrf: If this token is expected to have a CSRF double submit
94-
value present (boolean)
95-
:param identity_claim: expected claim that is used to identify the subject
95+
:param identity_claim_key: expected key that contains the identity
96+
:param user_claims_key: expected key that contains the user claims
97+
:param csrf_value: Expected double submit csrf value
9698
:return: Dictionary containing contents of the JWT
9799
"""
98100
# This call verifies the ext, iat, and nbf claims
@@ -101,16 +103,18 @@ def decode_jwt(encoded_token, secret, algorithm, csrf, identity_claim):
101103
# Make sure that any custom claims we expect in the token are present
102104
if 'jti' not in data:
103105
raise JWTDecodeError("Missing claim: jti")
104-
if identity_claim not in data:
105-
raise JWTDecodeError("Missing claim: {}".format(identity_claim))
106+
if identity_claim_key not in data:
107+
raise JWTDecodeError("Missing claim: {}".format(identity_claim_key))
106108
if 'type' not in data or data['type'] not in ('refresh', 'access'):
107109
raise JWTDecodeError("Missing or invalid claim: type")
108110
if data['type'] == 'access':
109111
if 'fresh' not in data:
110112
raise JWTDecodeError("Missing claim: fresh")
111-
if config.user_claims not in data:
112-
data[config.user_claims] = {}
113-
if csrf:
113+
if user_claims_key not in data:
114+
data[user_claims_key] = {}
115+
if csrf_value:
114116
if 'csrf' not in data:
115117
raise JWTDecodeError("Missing claim: csrf")
118+
if not safe_str_cmp(data['csrf'], csrf_value):
119+
raise CSRFError("CSRF double submit tokens do not match")
116120
return data

flask_jwt_extended/utils.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def get_jwt_identity():
2828
In a protected endpoint, this will return the identity of the JWT that is
2929
accessing this endpoint. If no JWT is present,`None` is returned instead.
3030
"""
31-
return get_raw_jwt().get(config.identity_claim, None)
31+
return get_raw_jwt().get(config.identity_claim_key, None)
3232

3333

3434
def get_jwt_claims():
@@ -37,7 +37,7 @@ def get_jwt_claims():
3737
in the JWT that is accessing the endpoint. If no custom user claims are
3838
present, an empty dict is returned instead.
3939
"""
40-
return get_raw_jwt().get(config.user_claims, {})
40+
return get_raw_jwt().get(config.user_claims_key, {})
4141

4242

4343
def get_current_user():
@@ -60,19 +60,21 @@ def get_jti(encoded_token):
6060
return decode_token(encoded_token).get('jti')
6161

6262

63-
def decode_token(encoded_token):
63+
def decode_token(encoded_token, csrf_value=None):
6464
"""
6565
Returns the decoded token (python dict) from an encoded JWT. This does all
6666
the checks to insure that the decoded token is valid before returning it.
6767
6868
:param encoded_token: The encoded JWT to decode into a python dict.
69+
:param csrf_value: Expected CSRF double submit value (optional)
6970
"""
7071
return decode_jwt(
7172
encoded_token=encoded_token,
7273
secret=config.decode_key,
7374
algorithm=config.algorithm,
74-
csrf=config.csrf_protect,
75-
identity_claim=config.identity_claim
75+
identity_claim_key=config.identity_claim_key,
76+
user_claims_key=config.user_claims_key,
77+
csrf_value=csrf_value
7678
)
7779

7880

@@ -153,13 +155,7 @@ def verify_token_claims(*args, **kwargs):
153155

154156

155157
def get_csrf_token(encoded_token):
156-
token = decode_jwt(
157-
encoded_token,
158-
config.decode_key,
159-
config.algorithm,
160-
csrf=True,
161-
identity_claim=config.identity_claim
162-
)
158+
token = decode_token(encoded_token)
163159
return token['csrf']
164160

165161

flask_jwt_extended/view_decorators.py

Lines changed: 19 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from functools import wraps
22

33
from flask import request
4-
from werkzeug.security import safe_str_cmp
54
try:
65
from flask import _app_ctx_stack as ctx_stack
76
except ImportError: # pragma: no cover
@@ -13,9 +12,8 @@
1312
FreshTokenRequired, CSRFError, UserLoadError, RevokedTokenError,
1413
UserClaimsVerificationError
1514
)
16-
from flask_jwt_extended.tokens import decode_jwt
1715
from flask_jwt_extended.utils import (
18-
has_user_loader, user_loader, token_in_blacklist,
16+
has_user_loader, user_loader, token_in_blacklist, decode_token,
1917
has_token_in_blacklist_callback, verify_token_claims
2018
)
2119

@@ -34,9 +32,9 @@ def jwt_required(fn):
3432
def wrapper(*args, **kwargs):
3533
jwt_data = _decode_jwt_from_request(request_type='access')
3634
ctx_stack.top.jwt = jwt_data
37-
if not verify_token_claims(jwt_data[config.user_claims]):
35+
if not verify_token_claims(jwt_data[config.user_claims_key]):
3836
raise UserClaimsVerificationError('User claims verification failed')
39-
_load_user(jwt_data[config.identity_claim])
37+
_load_user(jwt_data[config.identity_claim_key])
4038
return fn(*args, **kwargs)
4139
return wrapper
4240

@@ -60,9 +58,9 @@ def wrapper(*args, **kwargs):
6058
try:
6159
jwt_data = _decode_jwt_from_request(request_type='access')
6260
ctx_stack.top.jwt = jwt_data
63-
if not verify_token_claims(jwt_data[config.user_claims]):
61+
if not verify_token_claims(jwt_data[config.user_claims_key]):
6462
raise UserClaimsVerificationError('User claims verification failed')
65-
_load_user(jwt_data[config.identity_claim])
63+
_load_user(jwt_data[config.identity_claim_key])
6664
except (NoAuthorizationError, InvalidHeaderError):
6765
pass
6866
return fn(*args, **kwargs)
@@ -85,9 +83,9 @@ def wrapper(*args, **kwargs):
8583
ctx_stack.top.jwt = jwt_data
8684
if not jwt_data['fresh']:
8785
raise FreshTokenRequired('Fresh token required')
88-
if not verify_token_claims(jwt_data[config.user_claims]):
86+
if not verify_token_claims(jwt_data[config.user_claims_key]):
8987
raise UserClaimsVerificationError('User claims verification failed')
90-
_load_user(jwt_data[config.identity_claim])
88+
_load_user(jwt_data[config.identity_claim_key])
9189
return fn(*args, **kwargs)
9290
return wrapper
9391

@@ -103,7 +101,7 @@ def jwt_refresh_token_required(fn):
103101
def wrapper(*args, **kwargs):
104102
jwt_data = _decode_jwt_from_request(request_type='refresh')
105103
ctx_stack.top.jwt = jwt_data
106-
_load_user(jwt_data[config.identity_claim])
104+
_load_user(jwt_data[config.identity_claim_key])
107105
return fn(*args, **kwargs)
108106
return wrapper
109107

@@ -148,20 +146,14 @@ def _decode_jwt_from_headers():
148146
if len(parts) != 1:
149147
msg = "Bad {} header. Expected value '<JWT>'".format(header_name)
150148
raise InvalidHeaderError(msg)
151-
token = parts[0]
149+
encoded_token = parts[0]
152150
else:
153151
if parts[0] != header_type or len(parts) != 2:
154152
msg = "Bad {} header. Expected value '{} <JWT>'".format(header_name, header_type)
155153
raise InvalidHeaderError(msg)
156-
token = parts[1]
154+
encoded_token = parts[1]
157155

158-
return decode_jwt(
159-
encoded_token=token,
160-
secret=config.decode_key,
161-
algorithm=config.algorithm,
162-
csrf=False,
163-
identity_claim=config.identity_claim
164-
)
156+
return decode_token(encoded_token)
165157

166158

167159
def _decode_jwt_from_cookies(request_type):
@@ -172,29 +164,18 @@ def _decode_jwt_from_cookies(request_type):
172164
cookie_key = config.refresh_cookie_name
173165
csrf_header_key = config.refresh_csrf_header_name
174166

167+
if config.csrf_protect and request.method in config.csrf_request_methods:
168+
csrf_value = request.headers.get(csrf_header_key, None)
169+
if not csrf_value:
170+
raise CSRFError("Missing CSRF token in headers")
171+
else:
172+
csrf_value = None
173+
175174
encoded_token = request.cookies.get(cookie_key)
176175
if not encoded_token:
177176
raise NoAuthorizationError('Missing cookie "{}"'.format(cookie_key))
178177

179-
decoded_token = decode_jwt(
180-
encoded_token=encoded_token,
181-
secret=config.decode_key,
182-
algorithm=config.algorithm,
183-
csrf=config.csrf_protect,
184-
identity_claim=config.identity_claim
185-
)
186-
187-
# Verify csrf double submit tokens match if required
188-
if config.csrf_protect and request.method in config.csrf_request_methods:
189-
csrf_token_in_token = decoded_token['csrf']
190-
csrf_token_in_header = request.headers.get(csrf_header_key, None)
191-
192-
if not csrf_token_in_header:
193-
raise CSRFError("Missing CSRF token in headers")
194-
if not safe_str_cmp(csrf_token_in_header, csrf_token_in_token):
195-
raise CSRFError("CSRF double submit tokens do not match")
196-
197-
return decoded_token
178+
return decode_token(encoded_token, csrf_value=csrf_value)
198179

199180

200181
def _decode_jwt_from_request(request_type):

tests/test_config.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ def test_default_configs(app):
5252

5353
assert config.cookie_max_age is None
5454

55-
assert config.identity_claim == 'identity'
56-
assert config.user_claims == 'user_claims'
55+
assert config.identity_claim_key == 'identity'
56+
assert config.user_claims_key == 'user_claims'
5757

5858

5959
def test_override_configs(app):
@@ -125,8 +125,8 @@ def test_override_configs(app):
125125

126126
assert config.cookie_max_age == 2147483647
127127

128-
assert config.identity_claim == 'foo'
129-
assert config.user_claims == 'bar'
128+
assert config.identity_claim_key == 'foo'
129+
assert config.user_claims_key == 'bar'
130130

131131

132132
# noinspection PyStatementEffect

tests/test_decode_tokens.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def default_access_token(app):
2929
with app.test_request_context():
3030
return {
3131
'jti': '1234',
32-
config.identity_claim: 'username',
32+
config.identity_claim_key: 'username',
3333
'type': 'access',
3434
'fresh': True,
3535
'csrf': 'abcd'
@@ -49,9 +49,9 @@ def empty_user_loader_return(identity):
4949
with app.test_request_context():
5050
token = create_access_token('username')
5151
pure_decoded = jwt.decode(token, config.decode_key, algorithms=[config.algorithm])
52-
assert config.user_claims not in pure_decoded
52+
assert config.user_claims_key not in pure_decoded
5353
extension_decoded = decode_token(token)
54-
assert config.user_claims in extension_decoded
54+
assert config.user_claims_key in extension_decoded
5555

5656

5757
@pytest.mark.parametrize("missing_claim", ['jti', 'type', 'identity', 'fresh', 'csrf'])
@@ -61,7 +61,7 @@ def test_missing_jti_claim(app, default_access_token, missing_claim):
6161

6262
with pytest.raises(JWTDecodeError):
6363
with app.test_request_context():
64-
decode_token(missing_jwt_token)
64+
decode_token(missing_jwt_token, csrf_value='abcd')
6565

6666

6767
def test_bad_token_type(app, default_access_token):

0 commit comments

Comments
 (0)