Skip to content

Commit 8897012

Browse files
bpo-43977: Properly update the tp_flags of existing subclasses when their parents are registered (GH-26864)
(cherry picked from commit ca2009d72a52a98bf43aafa9ad270a4fcfabfc89) Co-authored-by: Brandt Bucher <[email protected]>
1 parent 8bec9fb commit 8897012

File tree

4 files changed

+129
-34
lines changed

4 files changed

+129
-34
lines changed

Doc/library/dis.rst

+8-5
Original file line numberDiff line numberDiff line change
@@ -770,17 +770,20 @@ iterations of the loop.
770770

771771
.. opcode:: MATCH_MAPPING
772772

773-
If TOS is an instance of :class:`collections.abc.Mapping`, push ``True`` onto
774-
the stack. Otherwise, push ``False``.
773+
If TOS is an instance of :class:`collections.abc.Mapping` (or, more technically: if
774+
it has the :const:`Py_TPFLAGS_MAPPING` flag set in its
775+
:c:member:`~PyTypeObject.tp_flags`), push ``True`` onto the stack. Otherwise, push
776+
``False``.
775777

776778
.. versionadded:: 3.10
777779

778780

779781
.. opcode:: MATCH_SEQUENCE
780782

781-
If TOS is an instance of :class:`collections.abc.Sequence` and is *not* an
782-
instance of :class:`str`/:class:`bytes`/:class:`bytearray`, push ``True``
783-
onto the stack. Otherwise, push ``False``.
783+
If TOS is an instance of :class:`collections.abc.Sequence` and is *not* an instance
784+
of :class:`str`/:class:`bytes`/:class:`bytearray` (or, more technically: if it has
785+
the :const:`Py_TPFLAGS_SEQUENCE` flag set in its :c:member:`~PyTypeObject.tp_flags`),
786+
push ``True`` onto the stack. Otherwise, push ``False``.
784787

785788
.. versionadded:: 3.10
786789

Lib/test/test_patma.py

+87-23
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,43 @@ def test_refleaks(self):
2424

2525
class TestInheritance(unittest.TestCase):
2626

27-
def test_multiple_inheritance(self):
27+
@staticmethod
28+
def check_sequence_then_mapping(x):
29+
match x:
30+
case [*_]:
31+
return "seq"
32+
case {}:
33+
return "map"
34+
35+
@staticmethod
36+
def check_mapping_then_sequence(x):
37+
match x:
38+
case {}:
39+
return "map"
40+
case [*_]:
41+
return "seq"
42+
43+
def test_multiple_inheritance_mapping(self):
44+
class C:
45+
pass
46+
class M1(collections.UserDict, collections.abc.Sequence):
47+
pass
48+
class M2(C, collections.UserDict, collections.abc.Sequence):
49+
pass
50+
class M3(collections.UserDict, C, list):
51+
pass
52+
class M4(dict, collections.abc.Sequence, C):
53+
pass
54+
self.assertEqual(self.check_sequence_then_mapping(M1()), "map")
55+
self.assertEqual(self.check_sequence_then_mapping(M2()), "map")
56+
self.assertEqual(self.check_sequence_then_mapping(M3()), "map")
57+
self.assertEqual(self.check_sequence_then_mapping(M4()), "map")
58+
self.assertEqual(self.check_mapping_then_sequence(M1()), "map")
59+
self.assertEqual(self.check_mapping_then_sequence(M2()), "map")
60+
self.assertEqual(self.check_mapping_then_sequence(M3()), "map")
61+
self.assertEqual(self.check_mapping_then_sequence(M4()), "map")
62+
63+
def test_multiple_inheritance_sequence(self):
2864
class C:
2965
pass
3066
class S1(collections.UserList, collections.abc.Mapping):
@@ -35,32 +71,60 @@ class S3(list, C, collections.abc.Mapping):
3571
pass
3672
class S4(collections.UserList, dict, C):
3773
pass
38-
class M1(collections.UserDict, collections.abc.Sequence):
74+
self.assertEqual(self.check_sequence_then_mapping(S1()), "seq")
75+
self.assertEqual(self.check_sequence_then_mapping(S2()), "seq")
76+
self.assertEqual(self.check_sequence_then_mapping(S3()), "seq")
77+
self.assertEqual(self.check_sequence_then_mapping(S4()), "seq")
78+
self.assertEqual(self.check_mapping_then_sequence(S1()), "seq")
79+
self.assertEqual(self.check_mapping_then_sequence(S2()), "seq")
80+
self.assertEqual(self.check_mapping_then_sequence(S3()), "seq")
81+
self.assertEqual(self.check_mapping_then_sequence(S4()), "seq")
82+
83+
def test_late_registration_mapping(self):
84+
class Parent:
3985
pass
40-
class M2(C, collections.UserDict, collections.abc.Sequence):
86+
class ChildPre(Parent):
4187
pass
42-
class M3(collections.UserDict, C, list):
88+
class GrandchildPre(ChildPre):
4389
pass
44-
class M4(dict, collections.abc.Sequence, C):
90+
collections.abc.Mapping.register(Parent)
91+
class ChildPost(Parent):
4592
pass
46-
def f(x):
47-
match x:
48-
case []:
49-
return "seq"
50-
case {}:
51-
return "map"
52-
def g(x):
53-
match x:
54-
case {}:
55-
return "map"
56-
case []:
57-
return "seq"
58-
for Seq in (S1, S2, S3, S4):
59-
self.assertEqual(f(Seq()), "seq")
60-
self.assertEqual(g(Seq()), "seq")
61-
for Map in (M1, M2, M3, M4):
62-
self.assertEqual(f(Map()), "map")
63-
self.assertEqual(g(Map()), "map")
93+
class GrandchildPost(ChildPost):
94+
pass
95+
self.assertEqual(self.check_sequence_then_mapping(Parent()), "map")
96+
self.assertEqual(self.check_sequence_then_mapping(ChildPre()), "map")
97+
self.assertEqual(self.check_sequence_then_mapping(GrandchildPre()), "map")
98+
self.assertEqual(self.check_sequence_then_mapping(ChildPost()), "map")
99+
self.assertEqual(self.check_sequence_then_mapping(GrandchildPost()), "map")
100+
self.assertEqual(self.check_mapping_then_sequence(Parent()), "map")
101+
self.assertEqual(self.check_mapping_then_sequence(ChildPre()), "map")
102+
self.assertEqual(self.check_mapping_then_sequence(GrandchildPre()), "map")
103+
self.assertEqual(self.check_mapping_then_sequence(ChildPost()), "map")
104+
self.assertEqual(self.check_mapping_then_sequence(GrandchildPost()), "map")
105+
106+
def test_late_registration_sequence(self):
107+
class Parent:
108+
pass
109+
class ChildPre(Parent):
110+
pass
111+
class GrandchildPre(ChildPre):
112+
pass
113+
collections.abc.Sequence.register(Parent)
114+
class ChildPost(Parent):
115+
pass
116+
class GrandchildPost(ChildPost):
117+
pass
118+
self.assertEqual(self.check_sequence_then_mapping(Parent()), "seq")
119+
self.assertEqual(self.check_sequence_then_mapping(ChildPre()), "seq")
120+
self.assertEqual(self.check_sequence_then_mapping(GrandchildPre()), "seq")
121+
self.assertEqual(self.check_sequence_then_mapping(ChildPost()), "seq")
122+
self.assertEqual(self.check_sequence_then_mapping(GrandchildPost()), "seq")
123+
self.assertEqual(self.check_mapping_then_sequence(Parent()), "seq")
124+
self.assertEqual(self.check_mapping_then_sequence(ChildPre()), "seq")
125+
self.assertEqual(self.check_mapping_then_sequence(GrandchildPre()), "seq")
126+
self.assertEqual(self.check_mapping_then_sequence(ChildPost()), "seq")
127+
self.assertEqual(self.check_mapping_then_sequence(GrandchildPost()), "seq")
64128

65129

66130
class TestPatma(unittest.TestCase):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Set the proper :const:`Py_TPFLAGS_MAPPING` and :const:`Py_TPFLAGS_SEQUENCE`
2+
flags for subclasses created before a parent has been registered as a
3+
:class:`collections.abc.Mapping` or :class:`collections.abc.Sequence`.

Modules/_abc.c

+31-6
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,32 @@ _abc__abc_init(PyObject *module, PyObject *self)
481481
Py_RETURN_NONE;
482482
}
483483

484+
static void
485+
set_collection_flag_recursive(PyTypeObject *child, unsigned long flag)
486+
{
487+
assert(flag == Py_TPFLAGS_MAPPING || flag == Py_TPFLAGS_SEQUENCE);
488+
if (PyType_HasFeature(child, Py_TPFLAGS_IMMUTABLETYPE) ||
489+
(child->tp_flags & COLLECTION_FLAGS) == flag)
490+
{
491+
return;
492+
}
493+
child->tp_flags &= ~COLLECTION_FLAGS;
494+
child->tp_flags |= flag;
495+
PyObject *grandchildren = child->tp_subclasses;
496+
if (grandchildren == NULL) {
497+
return;
498+
}
499+
assert(PyDict_CheckExact(grandchildren));
500+
Py_ssize_t i = 0;
501+
while (PyDict_Next(grandchildren, &i, NULL, &grandchildren)) {
502+
assert(PyWeakref_CheckRef(grandchildren));
503+
PyObject *grandchild = PyWeakref_GET_OBJECT(grandchildren);
504+
if (PyType_Check(grandchild)) {
505+
set_collection_flag_recursive((PyTypeObject *)grandchild, flag);
506+
}
507+
}
508+
}
509+
484510
/*[clinic input]
485511
_abc._abc_register
486512
@@ -532,12 +558,11 @@ _abc__abc_register_impl(PyObject *module, PyObject *self, PyObject *subclass)
532558
get_abc_state(module)->abc_invalidation_counter++;
533559

534560
/* Set Py_TPFLAGS_SEQUENCE or Py_TPFLAGS_MAPPING flag */
535-
if (PyType_Check(self) &&
536-
!PyType_HasFeature((PyTypeObject *)subclass, Py_TPFLAGS_IMMUTABLETYPE) &&
537-
((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS)
538-
{
539-
((PyTypeObject *)subclass)->tp_flags &= ~COLLECTION_FLAGS;
540-
((PyTypeObject *)subclass)->tp_flags |= (((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS);
561+
if (PyType_Check(self)) {
562+
unsigned long collection_flag = ((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS;
563+
if (collection_flag) {
564+
set_collection_flag_recursive((PyTypeObject *)subclass, collection_flag);
565+
}
541566
}
542567
Py_INCREF(subclass);
543568
return subclass;

0 commit comments

Comments
 (0)