Skip to content

Commit 88a628e

Browse files
authored
Compatibility changes for upcoming flask version 2.3 (#493)
* Switched from _request_ctx_stack.top to flask.g * Handle JSONEncoder changes
1 parent 37634ed commit 88a628e

File tree

7 files changed

+93
-32
lines changed

7 files changed

+93
-32
lines changed

flask_jwt_extended/config.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from datetime import datetime
22
from datetime import timedelta
33
from datetime import timezone
4+
from json import JSONEncoder
45
from typing import Iterable
56
from typing import List
67
from typing import Optional
@@ -9,9 +10,9 @@
910
from typing import Union
1011

1112
from flask import current_app
12-
from flask.json import JSONEncoder
1313
from jwt.algorithms import requires_cryptography
1414

15+
from flask_jwt_extended.internal_utils import get_json_encoder
1516
from flask_jwt_extended.typing import ExpiresDelta
1617

1718

@@ -284,7 +285,7 @@ def error_msg_key(self) -> str:
284285

285286
@property
286287
def json_encoder(self) -> Type[JSONEncoder]:
287-
return current_app.json_encoder
288+
return get_json_encoder(current_app)
288289

289290
@property
290291
def decode_audience(self) -> Union[str, Iterable[str]]:

flask_jwt_extended/internal_utils.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,25 @@
1+
import json
12
from typing import Any
3+
from typing import Type
24
from typing import TYPE_CHECKING
35

46
from flask import current_app
7+
from flask import Flask
58

69
from flask_jwt_extended.exceptions import RevokedTokenError
710
from flask_jwt_extended.exceptions import UserClaimsVerificationError
811
from flask_jwt_extended.exceptions import WrongTokenError
912

13+
try:
14+
from flask.json.provider import DefaultJSONProvider
15+
16+
HAS_JSON_PROVIDER = True
17+
except ModuleNotFoundError: # pragma: no cover
18+
# The flask.json.provider module was added in Flask 2.2.
19+
# Further details are handled in get_json_encoder.
20+
HAS_JSON_PROVIDER = False
21+
22+
1023
if TYPE_CHECKING: # pragma: no cover
1124
from flask_jwt_extended import JWTManager
1225

@@ -51,3 +64,35 @@ def custom_verification_for_token(jwt_header: dict, jwt_data: dict) -> None:
5164
if not jwt_manager._token_verification_callback(jwt_header, jwt_data):
5265
error_msg = "User claims verification failed"
5366
raise UserClaimsVerificationError(error_msg, jwt_header, jwt_data)
67+
68+
69+
class JSONEncoder(json.JSONEncoder):
70+
"""A JSON encoder which uses the app.json_provider_class for the default"""
71+
72+
def default(self, o: Any) -> Any:
73+
# If the registered JSON provider does not implement a default classmethod
74+
# use the method defined by the DefaultJSONProvider
75+
default = getattr(
76+
current_app.json_provider_class, "default", DefaultJSONProvider.default
77+
)
78+
return default(o)
79+
80+
81+
def get_json_encoder(app: Flask) -> Type[json.JSONEncoder]:
82+
"""Get the JSON Encoder for the provided flask app
83+
84+
Starting with flask version 2.2 the flask application provides a
85+
interface to register a custom JSON Encoder/Decoder under the json_provider_class.
86+
As this interface is not compatible with the standard JSONEncoder, the `default`
87+
method of the class is wrapped.
88+
89+
Lookup Order:
90+
- app.json_encoder - For Flask < 2.2
91+
- app.json_provider_class.default
92+
- flask.json.provider.DefaultJSONProvider.default
93+
94+
"""
95+
if not HAS_JSON_PROVIDER: # pragma: no cover
96+
return app.json_encoder
97+
98+
return JSONEncoder

flask_jwt_extended/tokens.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
from datetime import timedelta
44
from datetime import timezone
55
from hmac import compare_digest
6+
from json import JSONEncoder
67
from typing import Any
78
from typing import Iterable
89
from typing import List
910
from typing import Type
1011
from typing import Union
1112

1213
import jwt
13-
from flask.json import JSONEncoder
1414

1515
from flask_jwt_extended.exceptions import CSRFError
1616
from flask_jwt_extended.exceptions import JWTDecodeError

flask_jwt_extended/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Optional
44

55
import jwt
6-
from flask import _request_ctx_stack
6+
from flask import g
77
from flask import Response
88
from werkzeug.local import LocalProxy
99

@@ -23,7 +23,7 @@ def get_jwt() -> dict:
2323
:return:
2424
The payload (claims) of the JWT in the current request
2525
"""
26-
decoded_jwt = getattr(_request_ctx_stack.top, "jwt", None)
26+
decoded_jwt = g.get("_jwt_extended_jwt", None)
2727
if decoded_jwt is None:
2828
raise RuntimeError(
2929
"You must call `@jwt_required()` or `verify_jwt_in_request()` "
@@ -41,7 +41,7 @@ def get_jwt_header() -> dict:
4141
:return:
4242
The headers of the JWT in the current request
4343
"""
44-
decoded_header = getattr(_request_ctx_stack.top, "jwt_header", None)
44+
decoded_header = g.get("_jwt_extended_jwt_header", None)
4545
if decoded_header is None:
4646
raise RuntimeError(
4747
"You must call `@jwt_required()` or `verify_jwt_in_request()` "
@@ -73,7 +73,7 @@ def get_jwt_request_location() -> Optional[str]:
7373
The location of the JWT in the current request; e.g., "cookies",
7474
"query-string", "headers", or "json"
7575
"""
76-
return getattr(_request_ctx_stack.top, "jwt_location", None)
76+
return g.get("_jwt_extended_jwt_location", None)
7777

7878

7979
def get_current_user() -> Any:
@@ -91,7 +91,7 @@ def get_current_user() -> Any:
9191
The current user object for the JWT in the current request
9292
"""
9393
get_jwt() # Raise an error if not in a decorated context
94-
jwt_user_dict = getattr(_request_ctx_stack.top, "jwt_user", None)
94+
jwt_user_dict = g.get("_jwt_extended_jwt_user", None)
9595
if jwt_user_dict is None:
9696
raise RuntimeError(
9797
"You must provide a `@jwt.user_lookup_loader` callback to use "

flask_jwt_extended/view_decorators.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from typing import Tuple
99
from typing import Union
1010

11-
from flask import _request_ctx_stack
1211
from flask import current_app
12+
from flask import g
1313
from flask import request
1414
from werkzeug.exceptions import BadRequest
1515

@@ -85,10 +85,6 @@ def verify_jwt_in_request(
8585
if request.method in config.exempt_methods:
8686
return None
8787

88-
# Should be impossible to hit, this makes mypy checks happy
89-
if not _request_ctx_stack.top: # pragma: no cover
90-
raise RuntimeError("No _request_ctx_stack.top present, aborting")
91-
9288
try:
9389
jwt_data, jwt_header, jwt_location = _decode_jwt_from_request(
9490
locations, fresh, refresh=refresh, verify_type=verify_type
@@ -97,18 +93,18 @@ def verify_jwt_in_request(
9793
except NoAuthorizationError:
9894
if not optional:
9995
raise
100-
_request_ctx_stack.top.jwt = {}
101-
_request_ctx_stack.top.jwt_header = {}
102-
_request_ctx_stack.top.jwt_user = {"loaded_user": None}
103-
_request_ctx_stack.top.jwt_location = None
96+
g._jwt_extended_jwt = {}
97+
g._jwt_extended_jwt_header = {}
98+
g._jwt_extended_jwt_user = {"loaded_user": None}
99+
g._jwt_extended_jwt_location = None
104100
return None
105101

106102
# Save these at the very end so that they are only saved in the requet
107103
# context if the token is valid and all callbacks succeed
108-
_request_ctx_stack.top.jwt_user = _load_user(jwt_header, jwt_data)
109-
_request_ctx_stack.top.jwt_header = jwt_header
110-
_request_ctx_stack.top.jwt = jwt_data
111-
_request_ctx_stack.top.jwt_location = jwt_location
104+
g._jwt_extended_jwt_user = _load_user(jwt_header, jwt_data)
105+
g._jwt_extended_jwt_header = jwt_header
106+
g._jwt_extended_jwt = jwt_data
107+
g._jwt_extended_jwt_location = jwt_location
112108

113109
return jwt_header, jwt_data
114110

tests/test_config.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
1+
import json
2+
from datetime import date
13
from datetime import timedelta
24

35
import pytest
46
from dateutil.relativedelta import relativedelta
7+
from flask import __version__ as flask_version
58
from flask import Flask
6-
from flask.json import JSONEncoder
79

810
from flask_jwt_extended import JWTManager
911
from flask_jwt_extended.config import config
12+
from flask_jwt_extended.internal_utils import JSONEncoder
13+
14+
15+
flask_version_tuple = tuple(map(int, flask_version.split(".")))
1016

1117

1218
@pytest.fixture(scope="function")
@@ -65,8 +71,6 @@ def test_default_configs(app):
6571

6672
assert config.identity_claim_key == "sub"
6773

68-
assert config.json_encoder is app.json_encoder
69-
7074
assert config.error_msg_key == "msg"
7175

7276

@@ -112,11 +116,6 @@ def test_override_configs(app, delta_func):
112116

113117
app.config["JWT_ERROR_MESSAGE_KEY"] = "message"
114118

115-
class CustomJSONEncoder(JSONEncoder):
116-
pass
117-
118-
app.json_encoder = CustomJSONEncoder
119-
120119
with app.test_request_context():
121120
assert config.token_location == ["cookies", "query_string", "json"]
122121
assert config.jwt_in_query_string is True
@@ -162,11 +161,29 @@ class CustomJSONEncoder(JSONEncoder):
162161

163162
assert config.identity_claim_key == "foo"
164163

165-
assert config.json_encoder is CustomJSONEncoder
166-
167164
assert config.error_msg_key == "message"
168165

169166

167+
@pytest.mark.skipif(
168+
flask_version_tuple >= (2, 2, 0), reason="Only applies to Flask <= 2.2.0"
169+
)
170+
def test_config_json_encoder_flask21(app):
171+
with app.test_request_context():
172+
assert config.json_encoder == app.json_encoder
173+
dump = json.dumps({"d": date(2022, 8, 12)}, cls=config.json_encoder)
174+
assert dump == '{"d": "Fri, 12 Aug 2022 00:00:00 GMT"}'
175+
176+
177+
@pytest.mark.skipif(
178+
flask_version_tuple < (2, 2, 0), reason="Only applies to Flask > 2.2.0"
179+
)
180+
def test_config_json_encoder_flask(app):
181+
with app.test_request_context():
182+
assert config.json_encoder == JSONEncoder
183+
dump = json.dumps({"d": date(2022, 8, 12)}, cls=config.json_encoder)
184+
assert dump == '{"d": "Fri, 12 Aug 2022 00:00:00 GMT"}'
185+
186+
170187
def test_tokens_never_expire(app):
171188
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = False
172189
app.config["JWT_REFRESH_TOKEN_EXPIRES"] = False

tox.ini

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# and then run "tox" from this directory.
55

66
[tox]
7-
envlist = py37,py38,py39,py310,pypy3.9,mypy,coverage,style,docs
7+
envlist = py{37,38,39,310}-{flask21,flask},pypy3.9,mypy,coverage,style,docs
88

99
[testenv]
1010
commands =
@@ -13,6 +13,8 @@ deps =
1313
pytest
1414
cryptography
1515
python-dateutil
16+
flask21: Flask>=2.1,<2.2
17+
flask: Flask>=2.2
1618

1719
[testenv:mypy]
1820
commands =

0 commit comments

Comments
 (0)