Skip to content

Commit 86e82dd

Browse files
authored
Add support for shared_ptr<const T> in py::init() with smart_holder (#5731)
* Add overload to enable `.def(py::init(&rtrn_shcp))`. Also uncomment `.def(py::init(&rtrn_uqcp))` and `.def(py::init(&rtrn_udcp))`, which happen to work already (not sure what change in the past made those work). * Introduce `construct_from_shared_ptr()` helper for DRY-ness.
1 parent 365d41a commit 86e82dd

File tree

3 files changed

+29
-19
lines changed

3 files changed

+29
-19
lines changed

include/pybind11/detail/init.h

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -246,20 +246,38 @@ void construct(value_and_holder &v_h,
246246
v_h.type->init_instance(v_h.inst, &smhldr);
247247
}
248248

249-
template <typename Class, detail::enable_if_t<is_smart_holder<Holder<Class>>::value, int> = 0>
250-
void construct(value_and_holder &v_h, std::shared_ptr<Cpp<Class>> &&shd_ptr, bool need_alias) {
251-
PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(need_alias);
249+
template <typename PtrType, typename Class>
250+
void construct_from_shared_ptr(value_and_holder &v_h,
251+
std::shared_ptr<PtrType> &&shd_ptr,
252+
bool need_alias) {
253+
static_assert(std::is_same<PtrType, Cpp<Class>>::value
254+
|| std::is_same<PtrType, const Cpp<Class>>::value,
255+
"Expected (const) Cpp<Class> as shared_ptr pointee");
252256
auto *ptr = shd_ptr.get();
253257
no_nullptr(ptr);
254258
if (Class::has_alias && need_alias && !is_alias<Class>(ptr)) {
255259
throw type_error("pybind11::init(): construction failed: returned std::shared_ptr pointee "
256260
"is not an alias instance");
257261
}
258-
auto smhldr = smart_holder::from_shared_ptr(shd_ptr);
259-
v_h.value_ptr() = ptr;
262+
// Cast to non-const if needed, consistent with internal design
263+
auto smhldr
264+
= smart_holder::from_shared_ptr(std::const_pointer_cast<Cpp<Class>>(std::move(shd_ptr)));
265+
v_h.value_ptr() = const_cast<Cpp<Class> *>(ptr);
260266
v_h.type->init_instance(v_h.inst, &smhldr);
261267
}
262268

269+
template <typename Class, detail::enable_if_t<is_smart_holder<Holder<Class>>::value, int> = 0>
270+
void construct(value_and_holder &v_h, std::shared_ptr<Cpp<Class>> &&shd_ptr, bool need_alias) {
271+
construct_from_shared_ptr<Cpp<Class>, Class>(v_h, std::move(shd_ptr), need_alias);
272+
}
273+
274+
template <typename Class, detail::enable_if_t<is_smart_holder<Holder<Class>>::value, int> = 0>
275+
void construct(value_and_holder &v_h,
276+
std::shared_ptr<const Cpp<Class>> &&shd_ptr,
277+
bool need_alias) {
278+
construct_from_shared_ptr<const Cpp<Class>, Class>(v_h, std::move(shd_ptr), need_alias);
279+
}
280+
263281
template <typename Class, detail::enable_if_t<is_smart_holder<Holder<Class>>::value, int> = 0>
264282
void construct(value_and_holder &v_h,
265283
std::shared_ptr<Alias<Class>> &&shd_ptr,

tests/test_class_sh_factory_constructors.cpp

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -108,31 +108,23 @@ TEST_SUBMODULE(class_sh_factory_constructors, m) {
108108
.def("get_mtxt", get_mtxt<atyp_shmp>);
109109

110110
py::classh<atyp_shcp>(m, "atyp_shcp")
111-
// py::class_<atyp_shcp, std::shared_ptr<atyp_shcp>>(m, "atyp_shcp")
112-
// class_: ... must return a compatible ...
113-
// classh: ... cannot pass object of non-trivial type ...
114-
// .def(py::init(&rtrn_shcp))
111+
.def(py::init(&rtrn_shcp))
115112
.def("get_mtxt", get_mtxt<atyp_shcp>);
116113

117114
py::classh<atyp_uqmp>(m, "atyp_uqmp")
118115
.def(py::init(&rtrn_uqmp))
119116
.def("get_mtxt", get_mtxt<atyp_uqmp>);
120117

121118
py::classh<atyp_uqcp>(m, "atyp_uqcp")
122-
// class_: ... cannot pass object of non-trivial type ...
123-
// classh: ... cannot pass object of non-trivial type ...
124-
// .def(py::init(&rtrn_uqcp))
119+
.def(py::init(&rtrn_uqcp))
125120
.def("get_mtxt", get_mtxt<atyp_uqcp>);
126121

127122
py::classh<atyp_udmp>(m, "atyp_udmp")
128123
.def(py::init(&rtrn_udmp))
129124
.def("get_mtxt", get_mtxt<atyp_udmp>);
130125

131126
py::classh<atyp_udcp>(m, "atyp_udcp")
132-
// py::class_<atyp_udcp, std::unique_ptr<atyp_udcp, sddc>>(m, "atyp_udcp")
133-
// class_: ... must return a compatible ...
134-
// classh: ... cannot pass object of non-trivial type ...
135-
// .def(py::init(&rtrn_udcp))
127+
.def(py::init(&rtrn_udcp))
136128
.def("get_mtxt", get_mtxt<atyp_udcp>);
137129

138130
py::classh<with_alias, with_alias_alias>(m, "with_alias")

tests/test_class_sh_factory_constructors.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ def test_atyp_factories():
1313
# sert m.atyp_cptr().get_mtxt() == "Cptr"
1414
assert m.atyp_mptr().get_mtxt() == "Mptr"
1515
assert m.atyp_shmp().get_mtxt() == "Shmp"
16-
# sert m.atyp_shcp().get_mtxt() == "Shcp"
16+
assert m.atyp_shcp().get_mtxt() == "Shcp"
1717
assert m.atyp_uqmp().get_mtxt() == "Uqmp"
18-
# sert m.atyp_uqcp().get_mtxt() == "Uqcp"
18+
assert m.atyp_uqcp().get_mtxt() == "Uqcp"
1919
assert m.atyp_udmp().get_mtxt() == "Udmp"
20-
# sert m.atyp_udcp().get_mtxt() == "Udcp"
20+
assert m.atyp_udcp().get_mtxt() == "Udcp"
2121

2222

2323
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)