Skip to content

Commit e182953

Browse files
iamajayvimalloc
authored andcommitted
added support to pass additional headers in JWT encoding the token (#276)
1 parent 7f39a44 commit e182953

9 files changed

+244
-34
lines changed

flask_jwt_extended/__init__.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
from .jwt_manager import JWTManager
2-
from .view_decorators import (
3-
fresh_jwt_required, jwt_optional, jwt_refresh_token_required, jwt_required,
4-
verify_fresh_jwt_in_request, verify_jwt_in_request,
5-
verify_jwt_in_request_optional, verify_jwt_refresh_token_in_request
6-
)
72
from .utils import (
83
create_access_token, create_refresh_token, current_user, decode_token,
94
get_csrf_token, get_current_user, get_jti, get_jwt_claims, get_jwt_identity,
105
get_raw_jwt, set_access_cookies, set_refresh_cookies, unset_access_cookies,
11-
unset_jwt_cookies, unset_refresh_cookies
6+
unset_jwt_cookies, unset_refresh_cookies, get_unverified_jwt_headers,
7+
get_raw_jwt_header
8+
)
9+
from .view_decorators import (
10+
fresh_jwt_required, jwt_optional, jwt_refresh_token_required, jwt_required,
11+
verify_fresh_jwt_in_request, verify_jwt_in_request,
12+
verify_jwt_in_request_optional, verify_jwt_refresh_token_in_request
1213
)
1314

1415
__version__ = '3.23.0'

flask_jwt_extended/default_callbacks.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,17 @@ def default_user_claims_callback(userdata):
2222
return {}
2323

2424

25+
def default_jwt_headers_callback(default_headers):
26+
"""
27+
By default header typically consists of two parts: the type of the token,
28+
which is JWT, and the signing algorithm being used, such as HMAC SHA256
29+
or RSA. But we don't set the default header here we set it as empty which
30+
further by default set while encoding the token
31+
:return: default we set None here
32+
"""
33+
return None
34+
35+
2536
def default_user_identity_callback(userdata):
2637
"""
2738
By default, we use the passed in object directly as the jwt identity.

flask_jwt_extended/jwt_manager.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
ExpiredSignatureError, InvalidTokenError, InvalidAudienceError,
66
InvalidIssuerError, DecodeError
77
)
8+
89
try:
910
from flask import _app_ctx_stack as ctx_stack
1011
except ImportError: # pragma: no cover
@@ -22,8 +23,8 @@
2223
default_unauthorized_callback, default_needs_fresh_token_callback,
2324
default_revoked_token_callback, default_user_loader_error_callback,
2425
default_claims_verification_callback, default_verify_claims_failed_callback,
25-
default_decode_key_callback, default_encode_key_callback
26-
)
26+
default_decode_key_callback, default_encode_key_callback,
27+
default_jwt_headers_callback)
2728
from flask_jwt_extended.tokens import (
2829
encode_refresh_token, encode_access_token
2930
)
@@ -64,6 +65,7 @@ def __init__(self, app=None):
6465
self._verify_claims_failed_callback = default_verify_claims_failed_callback
6566
self._decode_key_callback = default_decode_key_callback
6667
self._encode_key_callback = default_encode_key_callback
68+
self._jwt_additional_header_callback = default_jwt_headers_callback
6769

6870
# Register this extension with the flask app now (if it is provided)
6971
if app is not None:
@@ -454,13 +456,33 @@ def encode_key_loader(self, callback):
454456
self._encode_key_callback = callback
455457
return callback
456458

457-
def _create_refresh_token(self, identity, expires_delta=None, user_claims=None):
459+
def additional_headers_loader(self, callback):
460+
"""
461+
This decorator sets the callback function for adding custom headers to an
462+
access token when :func:`~flask_jwt_extended.create_access_token` is
463+
called. By default, two headers will be added the type of the token, which is JWT,
464+
and the signing algorithm being used, such as HMAC SHA256 or RSA.
465+
466+
*HINT*: The callback function must be a function that takes **no** argument,
467+
which is the object passed into
468+
:func:`~flask_jwt_extended.create_access_token`, and returns the custom
469+
claims you want included in the access tokens. This returned claims
470+
must be *JSON serializable*.
471+
"""
472+
self._jwt_additional_header_callback = callback
473+
return callback
474+
475+
def _create_refresh_token(self, identity, expires_delta=None, user_claims=None,
476+
headers=None):
458477
if expires_delta is None:
459478
expires_delta = config.refresh_expires
460479

461480
if user_claims is None and config.user_claims_in_refresh_token:
462481
user_claims = self._user_claims_callback(identity)
463482

483+
if headers is None:
484+
headers = self._jwt_additional_header_callback(identity)
485+
464486
refresh_token = encode_refresh_token(
465487
identity=self._user_identity_callback(identity),
466488
secret=self._encode_key_callback(identity),
@@ -470,17 +492,22 @@ def _create_refresh_token(self, identity, expires_delta=None, user_claims=None):
470492
csrf=config.csrf_protect,
471493
identity_claim_key=config.identity_claim_key,
472494
user_claims_key=config.user_claims_key,
473-
json_encoder=config.json_encoder
495+
json_encoder=config.json_encoder,
496+
headers=headers
474497
)
475498
return refresh_token
476499

477-
def _create_access_token(self, identity, fresh=False, expires_delta=None, user_claims=None):
500+
def _create_access_token(self, identity, fresh=False, expires_delta=None,
501+
user_claims=None, headers=None):
478502
if expires_delta is None:
479503
expires_delta = config.access_expires
480504

481505
if user_claims is None:
482506
user_claims = self._user_claims_callback(identity)
483507

508+
if headers is None:
509+
headers = self._jwt_additional_header_callback(identity)
510+
484511
access_token = encode_access_token(
485512
identity=self._user_identity_callback(identity),
486513
secret=self._encode_key_callback(identity),
@@ -491,6 +518,7 @@ def _create_access_token(self, identity, fresh=False, expires_delta=None, user_c
491518
csrf=config.csrf_protect,
492519
identity_claim_key=config.identity_claim_key,
493520
user_claims_key=config.user_claims_key,
494-
json_encoder=config.json_encoder
521+
json_encoder=config.json_encoder,
522+
headers=headers
495523
)
496524
return access_token

flask_jwt_extended/tokens.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import datetime
22
import uuid
3-
43
from calendar import timegm
54

65
import jwt
@@ -14,7 +13,7 @@ def _create_csrf_token():
1413

1514

1615
def _encode_jwt(additional_token_data, expires_delta, secret, algorithm,
17-
json_encoder=None):
16+
json_encoder=None, headers=None):
1817
uid = _create_csrf_token()
1918
now = datetime.datetime.utcnow()
2019
token_data = {
@@ -28,13 +27,13 @@ def _encode_jwt(additional_token_data, expires_delta, secret, algorithm,
2827
token_data['exp'] = now + expires_delta
2928
token_data.update(additional_token_data)
3029
encoded_token = jwt.encode(token_data, secret, algorithm,
31-
json_encoder=json_encoder).decode('utf-8')
30+
json_encoder=json_encoder, headers=headers).decode('utf-8')
3231
return encoded_token
3332

3433

3534
def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
3635
user_claims, csrf, identity_claim_key, user_claims_key,
37-
json_encoder=None):
36+
json_encoder=None, headers=None):
3837
"""
3938
Creates a new encoded (utf-8) access token.
4039
@@ -54,6 +53,7 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
5453
(boolean)
5554
:param identity_claim_key: Which key should be used to store the identity
5655
:param user_claims_key: Which key should be used to store the user claims
56+
:param headers: valid dict for specifying additional headers in JWT header section
5757
:return: Encoded access token
5858
"""
5959

@@ -74,12 +74,12 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
7474
if csrf:
7575
token_data['csrf'] = _create_csrf_token()
7676
return _encode_jwt(token_data, expires_delta, secret, algorithm,
77-
json_encoder=json_encoder)
77+
json_encoder=json_encoder, headers=headers)
7878

7979

8080
def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims,
8181
csrf, identity_claim_key, user_claims_key,
82-
json_encoder=None):
82+
json_encoder=None, headers=None):
8383
"""
8484
Creates a new encoded (utf-8) refresh token.
8585
@@ -95,6 +95,7 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims
9595
(boolean)
9696
:param identity_claim_key: Which key should be used to store the identity
9797
:param user_claims_key: Which key should be used to store the user claims
98+
:param headers: valid dict for specifying additional headers in JWT header section
9899
:return: Encoded refresh token
99100
"""
100101
token_data = {
@@ -109,7 +110,7 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims
109110
if csrf:
110111
token_data['csrf'] = _create_csrf_token()
111112
return _encode_jwt(token_data, expires_delta, secret, algorithm,
112-
json_encoder=json_encoder)
113+
json_encoder=json_encoder, headers=headers)
113114

114115

115116
def decode_jwt(encoded_token, secret, algorithms, identity_claim_key,

flask_jwt_extended/utils.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
from warnings import warn
2+
13
from flask import current_app
2-
from werkzeug.local import LocalProxy
34
from jwt import ExpiredSignatureError
4-
from warnings import warn
5+
from werkzeug.local import LocalProxy
56

67
try:
78
from flask import _app_ctx_stack as ctx_stack
@@ -29,6 +30,15 @@ def get_raw_jwt():
2930
return getattr(ctx_stack.top, 'jwt', {})
3031

3132

33+
def get_raw_jwt_header():
34+
"""
35+
In a protected endpoint, this will return the python dictionary which has
36+
the JWT headers values. If no
37+
JWT is currently present, an empty dict is returned instead.
38+
"""
39+
return getattr(ctx_stack.top, 'jwt_header', {})
40+
41+
3242
def get_jwt_identity():
3343
"""
3444
In a protected endpoint, this will return the identity of the JWT that is
@@ -132,7 +142,8 @@ def _get_jwt_manager():
132142
"application before using this method")
133143

134144

135-
def create_access_token(identity, fresh=False, expires_delta=None, user_claims=None):
145+
def create_access_token(identity, fresh=False, expires_delta=None, user_claims=None,
146+
headers=None):
136147
"""
137148
Create a new access token.
138149
@@ -153,13 +164,17 @@ def create_access_token(identity, fresh=False, expires_delta=None, user_claims=N
153164
'JWT_ACCESS_TOKEN_EXPIRES` config value
154165
(see :ref:`Configuration Options`)
155166
:param user_claims: Optional JSON serializable to override user claims.
167+
:param headers: Optional, valid dict for specifying additional headers in JWT
168+
header section
156169
:return: An encoded access token
157170
"""
158171
jwt_manager = _get_jwt_manager()
159-
return jwt_manager._create_access_token(identity, fresh, expires_delta, user_claims)
172+
return jwt_manager._create_access_token(identity, fresh, expires_delta, user_claims,
173+
headers=headers)
160174

161175

162-
def create_refresh_token(identity, expires_delta=None, user_claims=None):
176+
def create_refresh_token(identity, expires_delta=None, user_claims=None,
177+
headers=None):
163178
"""
164179
Creates a new refresh token.
165180
@@ -175,10 +190,13 @@ def create_refresh_token(identity, expires_delta=None, user_claims=None):
175190
'JWT_REFRESH_TOKEN_EXPIRES` config value
176191
(see :ref:`Configuration Options`)
177192
:param user_claims: Optional JSON serializable to override user claims.
193+
:param headers: Optional, valid dict for specifying additional headers in JWT
194+
header section
178195
:return: An encoded refresh token
179196
"""
180197
jwt_manager = _get_jwt_manager()
181-
return jwt_manager._create_refresh_token(identity, expires_delta, user_claims)
198+
return jwt_manager._create_refresh_token(identity, expires_delta, user_claims,
199+
headers=headers)
182200

183201

184202
def has_user_loader():
@@ -396,3 +414,15 @@ def unset_refresh_cookies(response):
396414
domain=config.cookie_domain,
397415
path=config.refresh_csrf_cookie_path,
398416
samesite=config.cookie_samesite)
417+
418+
419+
def get_unverified_jwt_headers(encoded_token):
420+
"""
421+
Returns the Headers of an encoded JWT without verifying the actual signature of JWT.
422+
Note: The signature is not verified so the header parameters
423+
should not be fully trusted until signature verification is complete
424+
425+
:param encoded_token: The encoded JWT to get the Header from.
426+
:return: JWT header parameters as python dict()
427+
"""
428+
return jwt.get_unverified_header(encoded_token)

flask_jwt_extended/view_decorators.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
)
1919
from flask_jwt_extended.utils import (
2020
decode_token, has_user_loader, user_loader, verify_token_claims,
21-
verify_token_not_blacklisted, verify_token_type
21+
verify_token_not_blacklisted, verify_token_type, get_unverified_jwt_headers
2222
)
2323

2424

@@ -29,8 +29,9 @@ def verify_jwt_in_request():
2929
no token or if the token is invalid.
3030
"""
3131
if request.method not in config.exempt_methods:
32-
jwt_data = _decode_jwt_from_request(request_type='access')
32+
jwt_data, jwt_header = _decode_jwt_from_request(request_type='access')
3333
ctx_stack.top.jwt = jwt_data
34+
ctx_stack.top.jwt_header = jwt_header
3435
verify_token_claims(jwt_data)
3536
_load_user(jwt_data[config.identity_claim_key])
3637

@@ -48,8 +49,9 @@ def verify_jwt_in_request_optional():
4849
"""
4950
try:
5051
if request.method not in config.exempt_methods:
51-
jwt_data = _decode_jwt_from_request(request_type='access')
52+
jwt_data, jwt_header = _decode_jwt_from_request(request_type='access')
5253
ctx_stack.top.jwt = jwt_data
54+
ctx_stack.top.jwt_header = jwt_header
5355
verify_token_claims(jwt_data)
5456
_load_user(jwt_data[config.identity_claim_key])
5557
except (NoAuthorizationError, InvalidHeaderError):
@@ -63,8 +65,9 @@ def verify_fresh_jwt_in_request():
6365
token is not marked as fresh.
6466
"""
6567
if request.method not in config.exempt_methods:
66-
jwt_data = _decode_jwt_from_request(request_type='access')
68+
jwt_data, jwt_header = _decode_jwt_from_request(request_type='access')
6769
ctx_stack.top.jwt = jwt_data
70+
ctx_stack.top.jwt_header = jwt_header
6871
fresh = jwt_data['fresh']
6972
if isinstance(fresh, bool):
7073
if not fresh:
@@ -83,8 +86,9 @@ def verify_jwt_refresh_token_in_request():
8386
exception if there is no token or the token is invalid.
8487
"""
8588
if request.method not in config.exempt_methods:
86-
jwt_data = _decode_jwt_from_request(request_type='refresh')
89+
jwt_data, jwt_header = _decode_jwt_from_request(request_type='refresh')
8790
ctx_stack.top.jwt = jwt_data
91+
ctx_stack.top.jwt_header = jwt_header
8892
_load_user(jwt_data[config.identity_claim_key])
8993

9094

@@ -283,10 +287,12 @@ def _decode_jwt_from_request(request_type):
283287
# in one place to be valid (not every location).
284288
errors = []
285289
decoded_token = None
290+
jwt_header = None
286291
for get_encoded_token_function in get_encoded_token_functions:
287292
try:
288293
encoded_token, csrf_token = get_encoded_token_function()
289294
decoded_token = decode_token(encoded_token, csrf_token)
295+
jwt_header = get_unverified_jwt_headers(encoded_token)
290296
break
291297
except NoAuthorizationError as e:
292298
errors.append(str(e))
@@ -309,4 +315,4 @@ def _decode_jwt_from_request(request_type):
309315

310316
verify_token_type(decoded_token, expected_type=request_type)
311317
verify_token_not_blacklisted(decoded_token, request_type)
312-
return decoded_token
318+
return decoded_token, jwt_header

tests/test_decode_tokens.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from flask_jwt_extended import (
1515
JWTManager, create_access_token, decode_token, create_refresh_token,
16-
get_jti
16+
get_jti, get_unverified_jwt_headers
1717
)
1818
from flask_jwt_extended.config import config
1919
from flask_jwt_extended.exceptions import JWTDecodeError
@@ -286,3 +286,12 @@ def test_malformed_token(app):
286286
with pytest.raises(DecodeError):
287287
with app.test_request_context():
288288
decode_token(invalid_token)
289+
290+
291+
def test_jwt_headers(app):
292+
jwt_header = {"foo": "bar"}
293+
with app.test_request_context():
294+
access_token = create_access_token('username', headers=jwt_header)
295+
refresh_token = create_refresh_token('username', headers=jwt_header)
296+
assert get_unverified_jwt_headers(access_token)["foo"] == "bar"
297+
assert get_unverified_jwt_headers(refresh_token)["foo"] == "bar"

0 commit comments

Comments
 (0)