Skip to content

Commit 13b4c47

Browse files
committed
Improve set_rollback() behaviour
\## Description Fixes #6921. Added tests that fail before and pass afterwards. Remove the check for `connection.in_atomic_block` to determine if the current request is under a `transaction.atomic` from `ATOMIC_REQUESTS`. Instead, duplicate the method that Django itself uses [in BaseHandler](https://github.com/django/django/blob/964dd4f4f208722d8993a35c1ff047d353cea1ea/django/core/handlers/base.py#L64). This requires fetching the actual view function from `as_view()`, as seen by the URL resolver / BaseHandler. Since this requires `request`, I've also changed the accesses in `get_exception_handler_context` to be direct attribute accesses rather than `getattr()`. It seems the `getattr` defaults not accessible since `self.request`, `self.args`, and `self.kwargs` are always set in `dispatch()` before `handle_exception()` can ever be called. This is useful since `request` is always needed for the new `set_rollback` logic.
1 parent 89ac0a1 commit 13b4c47

File tree

2 files changed

+48
-18
lines changed

2 files changed

+48
-18
lines changed

rest_framework/views.py

+18-9
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44
from django.conf import settings
55
from django.core.exceptions import PermissionDenied
6-
from django.db import connection, models, transaction
6+
from django.db import connections, models, transaction
77
from django.http import Http404
88
from django.http.response import HttpResponseBase
99
from django.utils.cache import cc_delim_re, patch_vary_headers
@@ -62,10 +62,19 @@ def get_view_description(view, html=False):
6262
return description
6363

6464

65-
def set_rollback():
66-
atomic_requests = connection.settings_dict.get('ATOMIC_REQUESTS', False)
67-
if atomic_requests and connection.in_atomic_block:
68-
transaction.set_rollback(True)
65+
def set_rollback(request):
66+
# We need the actual view func returned by the URL resolver which gets used
67+
# by Django's BaseHandler to determine `non_atomic_requests`. Be cautious
68+
# when fetching it though as it won't be set when views are tested with
69+
# requessts from a RequestFactory.
70+
try:
71+
non_atomic_requests = request.resolver_match.func._non_atomic_requests
72+
except AttributeError:
73+
non_atomic_requests = set()
74+
75+
for db in connections.all():
76+
if db.settings_dict['ATOMIC_REQUESTS'] and db.alias not in non_atomic_requests:
77+
transaction.set_rollback(True, using=db.alias)
6978

7079

7180
def exception_handler(exc, context):
@@ -95,7 +104,7 @@ def exception_handler(exc, context):
95104
else:
96105
data = {'detail': exc.detail}
97106

98-
set_rollback()
107+
set_rollback(context['request'])
99108
return Response(data, status=exc.status_code, headers=headers)
100109

101110
return None
@@ -223,9 +232,9 @@ def get_exception_handler_context(self):
223232
"""
224233
return {
225234
'view': self,
226-
'args': getattr(self, 'args', ()),
227-
'kwargs': getattr(self, 'kwargs', {}),
228-
'request': getattr(self, 'request', None)
235+
'args': self.args,
236+
'kwargs': self.kwargs,
237+
'request': self.request,
229238
}
230239

231240
def get_view_name(self):

tests/test_atomic_requests.py

+30-9
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from django.conf.urls import url
44
from django.db import connection, connections, transaction
55
from django.http import Http404
6-
from django.test import TestCase, TransactionTestCase, override_settings
6+
from django.test import TestCase, override_settings
77

88
from rest_framework import status
99
from rest_framework.exceptions import APIException
@@ -39,12 +39,24 @@ def dispatch(self, *args, **kwargs):
3939
return super().dispatch(*args, **kwargs)
4040

4141
def get(self, request, *args, **kwargs):
42-
BasicModel.objects.all()
42+
list(BasicModel.objects.all())
43+
raise Http404
44+
45+
46+
class UrlDecoratedNonAtomicAPIExceptionView(APIView):
47+
def get(self, request, *args, **kwargs):
48+
list(BasicModel.objects.all())
4349
raise Http404
4450

4551

4652
urlpatterns = (
47-
url(r'^$', NonAtomicAPIExceptionView.as_view()),
53+
url(r'^non-atomic-exception$', NonAtomicAPIExceptionView.as_view()),
54+
url(
55+
r'^url-decorated-non-atomic-exception$',
56+
transaction.non_atomic_requests(
57+
UrlDecoratedNonAtomicAPIExceptionView.as_view()
58+
),
59+
),
4860
)
4961

5062

@@ -94,8 +106,8 @@ def test_generic_exception_delegate_transaction_management(self):
94106
# 1 - begin savepoint
95107
# 2 - insert
96108
# 3 - release savepoint
97-
with transaction.atomic():
98-
self.assertRaises(Exception, self.view, request)
109+
with transaction.atomic(), self.assertRaises(Exception):
110+
self.view(request)
99111
assert not transaction.get_rollback()
100112
assert BasicModel.objects.count() == 1
101113

@@ -135,16 +147,25 @@ def test_api_exception_rollback_transaction(self):
135147
"'atomic' requires transactions and savepoints."
136148
)
137149
@override_settings(ROOT_URLCONF='tests.test_atomic_requests')
138-
class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase):
150+
class NonAtomicDBTransactionAPIExceptionTests(TestCase):
139151
def setUp(self):
140152
connections.databases['default']['ATOMIC_REQUESTS'] = True
141153

142154
def tearDown(self):
143155
connections.databases['default']['ATOMIC_REQUESTS'] = False
144156

145157
def test_api_exception_rollback_transaction_non_atomic_view(self):
146-
response = self.client.get('/')
158+
response = self.client.get('/non-atomic-exception')
147159

148-
# without checking connection.in_atomic_block view raises 500
149-
# due attempt to rollback without transaction
150160
assert response.status_code == status.HTTP_404_NOT_FOUND
161+
assert not transaction.get_rollback()
162+
# Check we can still perform DB queries
163+
list(BasicModel.objects.all())
164+
165+
def test_api_exception_rollback_transaction_url_decorated_non_atomic_view(self):
166+
response = self.client.get('/url-decorated-non-atomic-exception')
167+
168+
assert response.status_code == status.HTTP_404_NOT_FOUND
169+
assert not transaction.get_rollback()
170+
# Check we can still perform DB queries
171+
list(BasicModel.objects.all())

0 commit comments

Comments
 (0)