Skip to content

Commit aca3d28

Browse files
committed
add jwt_optional view decorator and corresponding tests
1 parent 9bfa900 commit aca3d28

File tree

3 files changed

+205
-2
lines changed

3 files changed

+205
-2
lines changed

flask_jwt_extended/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .jwt_manager import JWTManager
22
from .view_decorators import (
3-
jwt_required, fresh_jwt_required, jwt_refresh_token_required
3+
jwt_required, fresh_jwt_required, jwt_refresh_token_required,
4+
jwt_optional
45
)
56
from .utils import (
67
create_refresh_token, create_access_token, get_jwt_identity,

flask_jwt_extended/view_decorators.py

+26
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,32 @@ def wrapper(*args, **kwargs):
3636
return wrapper
3737

3838

39+
def jwt_optional(fn):
40+
"""
41+
If you decorate a view with this, it will check the request for a valid
42+
JWT and put it into the Flask application context before calling the view.
43+
If no authorization header is present, the view will be called without the
44+
application context being changed. Other authentication errors are not
45+
affected.
46+
47+
:param fn: The view function to decorate
48+
"""
49+
@wraps(fn)
50+
def wrapper(*args, **kwargs):
51+
try:
52+
# If an acceptable JWT is found in the request, put it into
53+
# the application context
54+
jwt_data = _decode_jwt_from_request(request_type='access')
55+
ctx_stack.top.jwt = jwt_data
56+
except NoAuthorizationError:
57+
# Allow request to proceed if no authorization header is present
58+
# in the request, but don't modify application context
59+
pass
60+
# Return the decorated function in either case
61+
return fn(*args, **kwargs)
62+
return wrapper
63+
64+
3965
def fresh_jwt_required(fn):
4066
"""
4167
If you decorate a vew with this, it will ensure that the requester has a

tests/test_protected_endpoints.py

+177-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
get_jwt_identity, set_refresh_cookies, set_access_cookies, unset_jwt_cookies
1212
from flask_jwt_extended import JWTManager, create_refresh_token, \
1313
jwt_refresh_token_required, create_access_token, fresh_jwt_required, \
14-
jwt_required, get_raw_jwt
14+
jwt_optional, jwt_required, get_raw_jwt
1515

1616

1717
class TestEndpoints(unittest.TestCase):
@@ -55,6 +55,14 @@ def protected():
5555
def fresh_protected():
5656
return jsonify({'msg': "fresh hello world"})
5757

58+
@self.app.route('/partially-protected')
59+
@jwt_optional
60+
def partially_protected():
61+
if get_jwt_identity():
62+
return jsonify({'msg': "protected hello world"})
63+
return jsonify({'msg': "unprotected hello world"})
64+
65+
5866
def _jwt_post(self, url, jwt):
5967
response = self.client.post(url, content_type='application/json',
6068
headers={'Authorization': 'Bearer {}'.format(jwt)})
@@ -124,6 +132,32 @@ def test_jwt_required(self):
124132
self.assertEqual(status, 200)
125133
self.assertEqual(data, {'msg': 'hello world'})
126134

135+
def test_jwt_optional_no_jwt(self):
136+
response = self.client.get('/partially-protected')
137+
data = json.loads(response.get_data(as_text=True))
138+
status = response.status_code
139+
self.assertEqual(status, 200)
140+
self.assertEqual(data, {'msg': 'unprotected hello world'})
141+
142+
def test_jwt_optional_with_jwt(self):
143+
response = self.client.post('/auth/login')
144+
data = json.loads(response.get_data(as_text=True))
145+
fresh_access_token = data['access_token']
146+
refresh_token = data['refresh_token']
147+
148+
# Test it works with a fresh token
149+
status, data = self._jwt_get('/partially-protected',
150+
fresh_access_token)
151+
self.assertEqual(data, {'msg': 'protected hello world'})
152+
self.assertEqual(status, 200)
153+
154+
# Test it works with a non-fresh access token
155+
_, data = self._jwt_post('/auth/refresh', refresh_token)
156+
non_fresh_token = data['access_token']
157+
status, data = self._jwt_get('/partially-protected', non_fresh_token)
158+
self.assertEqual(status, 200)
159+
self.assertEqual(data, {'msg': 'protected hello world'})
160+
127161
def test_jwt_required_wrong_token(self):
128162
response = self.client.post('/auth/login')
129163
data = json.loads(response.get_data(as_text=True))
@@ -133,6 +167,15 @@ def test_jwt_required_wrong_token(self):
133167
status, text = self._jwt_get('/protected', refresh_token)
134168
self.assertEqual(status, 422)
135169

170+
def test_jwt_optional_wrong_token(self):
171+
response = self.client.post('/auth/login')
172+
data = json.loads(response.get_data(as_text=True))
173+
refresh_token = data['refresh_token']
174+
175+
# Shouldn't work with a refresh token
176+
status, text = self._jwt_get('/partially-protected', refresh_token)
177+
self.assertEqual(status, 422)
178+
136179
def test_fresh_jwt_required(self):
137180
response = self.client.post('/auth/login')
138181
data = json.loads(response.get_data(as_text=True))
@@ -209,6 +252,38 @@ def test_bad_jwt_requests(self):
209252
self.assertEqual(status_code, 422)
210253
self.assertIn('msg', data)
211254

255+
def test_optional_bad_jwt_requests(self):
256+
response = self.client.post('/auth/login')
257+
data = json.loads(response.get_data(as_text=True))
258+
access_token = data['access_token']
259+
260+
# Test with missing type in authorization header
261+
auth_header = access_token
262+
response = self.client.get('/partially-protected',
263+
headers={'Authorization': auth_header})
264+
data = json.loads(response.get_data(as_text=True))
265+
status_code = response.status_code
266+
self.assertEqual(status_code, 422)
267+
self.assertIn('msg', data)
268+
269+
# Test with type not being Bearer in authorization header
270+
auth_header = "BANANA {}".format(access_token)
271+
response = self.client.get('/partially-protected',
272+
headers={'Authorization': auth_header})
273+
data = json.loads(response.get_data(as_text=True))
274+
status_code = response.status_code
275+
self.assertEqual(status_code, 422)
276+
self.assertIn('msg', data)
277+
278+
# Test with too many items in auth header
279+
auth_header = "Bearer {} BANANA".format(access_token)
280+
response = self.client.get('/partially-protected',
281+
headers={'Authorization': auth_header})
282+
data = json.loads(response.get_data(as_text=True))
283+
status_code = response.status_code
284+
self.assertEqual(status_code, 422)
285+
self.assertIn('msg', data)
286+
212287
def test_bad_tokens(self):
213288
# Test expired access token
214289
response = self.client.post('/auth/login')
@@ -267,6 +342,54 @@ def test_bad_tokens(self):
267342
self.assertEqual(status_code, 422)
268343
self.assertIn('msg', data)
269344

345+
def test_optional_jwt_bad_tokens(self):
346+
# Test expired access token
347+
response = self.client.post('/auth/login')
348+
data = json.loads(response.get_data(as_text=True))
349+
access_token = data['access_token']
350+
status_code, data = self._jwt_get('/partially-protected', access_token)
351+
self.assertEqual(status_code, 200)
352+
self.assertEqual(data, {'msg': 'protected hello world'})
353+
time.sleep(2)
354+
status_code, data = self._jwt_get('/partially-protected', access_token)
355+
self.assertEqual(status_code, 401)
356+
self.assertIn('msg', data)
357+
358+
# Test Bogus token
359+
auth_header = "Bearer {}".format('this_is_totally_an_access_token')
360+
response = self.client.get('/partially-protected',
361+
headers={'Authorization': auth_header})
362+
data = json.loads(response.get_data(as_text=True))
363+
status_code = response.status_code
364+
self.assertEqual(status_code, 422)
365+
self.assertIn('msg', data)
366+
367+
# Test token that was signed with a different key
368+
with self.app.test_request_context():
369+
token = encode_access_token('foo', 'newsecret', 'HS256',
370+
timedelta(minutes=5), True, {},
371+
csrf=False)
372+
auth_header = "Bearer {}".format(token)
373+
response = self.client.get('/partially-protected',
374+
headers={'Authorization': auth_header})
375+
data = json.loads(response.get_data(as_text=True))
376+
status_code = response.status_code
377+
self.assertEqual(status_code, 422)
378+
self.assertIn('msg', data)
379+
380+
# Test with valid token that is missing required claims
381+
now = datetime.utcnow()
382+
token_data = {'exp': now + timedelta(minutes=5)}
383+
encoded_token = jwt.encode(token_data, self.app.config['SECRET_KEY'],
384+
self.app.config['JWT_ALGORITHM']).decode('utf-8')
385+
auth_header = "Bearer {}".format(encoded_token)
386+
response = self.client.get('/partially-protected',
387+
headers={'Authorization': auth_header})
388+
data = json.loads(response.get_data(as_text=True))
389+
status_code = response.status_code
390+
self.assertEqual(status_code, 422)
391+
self.assertIn('msg', data)
392+
270393
def test_jwt_identity_claims(self):
271394
# Setup custom claims
272395
@self.jwt_manager.user_claims_loader
@@ -349,6 +472,43 @@ def test_different_headers(self):
349472
header_type='Bearer')
350473
self.assertIn('msg', data)
351474
self.assertEqual(status, 401)
475+
self.assertEqual(data, {'msg': 'Missing Auth Header'})
476+
477+
def test_different_headers_jwt_optional(self):
478+
response = self.client.post('/auth/login')
479+
data = json.loads(response.get_data(as_text=True))
480+
access_token = data['access_token']
481+
482+
self.app.config['JWT_HEADER_TYPE'] = 'JWT'
483+
status, data = self._jwt_get('/partially-protected', access_token,
484+
header_type='JWT')
485+
self.assertEqual(data, {'msg': 'protected hello world'})
486+
self.assertEqual(status, 200)
487+
488+
self.app.config['JWT_HEADER_TYPE'] = ''
489+
status, data = self._jwt_get('/partially-protected', access_token,
490+
header_type='')
491+
self.assertEqual(data, {'msg': 'protected hello world'})
492+
self.assertEqual(status, 200)
493+
494+
self.app.config['JWT_HEADER_TYPE'] = ''
495+
status, data = self._jwt_get('/partially-protected', access_token,
496+
header_type='Bearer')
497+
self.assertIn('msg', data)
498+
self.assertEqual(status, 422)
499+
500+
self.app.config['JWT_HEADER_TYPE'] = 'Bearer'
501+
self.app.config['JWT_HEADER_NAME'] = 'Auth'
502+
status, data = self._jwt_get('/partially-protected', access_token,
503+
header_name='Auth', header_type='Bearer')
504+
self.assertEqual(data, {'msg': 'protected hello world'})
505+
self.assertEqual(status, 200)
506+
507+
status, data = self._jwt_get('/partially-protected', access_token,
508+
header_name='Authorization',
509+
header_type='Bearer')
510+
self.assertEqual(status, 200)
511+
self.assertEqual(data, {'msg': 'unprotected hello world'})
352512

353513
def test_cookie_methods_fail_with_headers_configured(self):
354514
app = Flask(__name__)
@@ -401,6 +561,22 @@ def test_jwt_with_different_algorithm(self):
401561
self.assertEqual(status, 422)
402562
self.assertIn('msg', data)
403563

564+
def test_optional_jwt_with_different_algorithm(self):
565+
self.app.config['JWT_ALGORITHM'] = 'HS256'
566+
self.app.secret_key = 'test_secret'
567+
access_token = encode_access_token(
568+
identity='bobdobbs',
569+
secret='test_secret',
570+
algorithm='HS512',
571+
expires_delta=timedelta(minutes=5),
572+
fresh=True,
573+
user_claims={},
574+
csrf=False
575+
)
576+
status, data = self._jwt_get('/partially-protected', access_token)
577+
self.assertEqual(status, 422)
578+
self.assertIn('msg', data)
579+
404580

405581
class TestEndpointsWithCookies(unittest.TestCase):
406582

0 commit comments

Comments
 (0)