Skip to content

Commit 2ed1d70

Browse files
Merge #3184
3184: Additional tests for #3168 r=adamreichold a=adamreichold These were a part of tests `@lifthrasiir` was preparing for #3165, and I believe it's worthy to add them (any single of them fails in the current main branch). Co-authored-by: Kang Seonghoon <[email protected]>
2 parents 32c335e + e884327 commit 2ed1d70

File tree

1 file changed

+199
-34
lines changed

1 file changed

+199
-34
lines changed

tests/test_gc.rs

Lines changed: 199 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use pyo3::class::PyTraverseError;
44
use pyo3::class::PyVisit;
55
use pyo3::prelude::*;
66
use pyo3::{py_run, AsPyPointer, PyCell, PyTryInto};
7+
use std::cell::Cell;
78
use std::sync::atomic::{AtomicBool, Ordering};
89
use std::sync::Arc;
910

@@ -248,22 +249,10 @@ impl TraversableClass {
248249
}
249250
}
250251

251-
unsafe fn get_type_traverse(tp: *mut pyo3::ffi::PyTypeObject) -> Option<pyo3::ffi::traverseproc> {
252-
std::mem::transmute(pyo3::ffi::PyType_GetSlot(tp, pyo3::ffi::Py_tp_traverse))
253-
}
254-
255252
#[test]
256253
fn gc_during_borrow() {
257254
Python::with_gil(|py| {
258255
unsafe {
259-
// declare a dummy visitor function
260-
extern "C" fn novisit(
261-
_object: *mut pyo3::ffi::PyObject,
262-
_arg: *mut core::ffi::c_void,
263-
) -> std::os::raw::c_int {
264-
0
265-
}
266-
267256
// get the traverse function
268257
let ty = py.get_type::<TraversableClass>().as_type_ptr();
269258
let traverse = get_type_traverse(ty).unwrap();
@@ -290,18 +279,18 @@ fn gc_during_borrow() {
290279
}
291280

292281
#[pyclass]
293-
struct PanickyTraverse {
282+
struct PartialTraverse {
294283
member: PyObject,
295284
}
296285

297-
impl PanickyTraverse {
286+
impl PartialTraverse {
298287
fn new(py: Python<'_>) -> Self {
299288
Self { member: py.None() }
300289
}
301290
}
302291

303292
#[pymethods]
304-
impl PanickyTraverse {
293+
impl PartialTraverse {
305294
fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
306295
visit.call(&self.member)?;
307296
// In the test, we expect this to never be hit
@@ -310,29 +299,53 @@ impl PanickyTraverse {
310299
}
311300

312301
#[test]
313-
fn traverse_error() {
302+
fn traverse_partial() {
314303
Python::with_gil(|py| unsafe {
315-
// declare a visitor function which errors (returns nonzero code)
316-
extern "C" fn visit_error(
317-
_object: *mut pyo3::ffi::PyObject,
318-
_arg: *mut core::ffi::c_void,
319-
) -> std::os::raw::c_int {
320-
-1
321-
}
322-
323304
// get the traverse function
324-
let ty = py.get_type::<PanickyTraverse>().as_type_ptr();
305+
let ty = py.get_type::<PartialTraverse>().as_type_ptr();
325306
let traverse = get_type_traverse(ty).unwrap();
326307

327308
// confirm that traversing errors
328-
let obj = Py::new(py, PanickyTraverse::new(py)).unwrap();
309+
let obj = Py::new(py, PartialTraverse::new(py)).unwrap();
329310
assert_eq!(
330311
traverse(obj.as_ptr(), visit_error, std::ptr::null_mut()),
331312
-1
332313
);
333314
})
334315
}
335316

317+
#[pyclass]
318+
struct PanickyTraverse {
319+
member: PyObject,
320+
}
321+
322+
impl PanickyTraverse {
323+
fn new(py: Python<'_>) -> Self {
324+
Self { member: py.None() }
325+
}
326+
}
327+
328+
#[pymethods]
329+
impl PanickyTraverse {
330+
fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
331+
visit.call(&self.member)?;
332+
panic!("at the disco");
333+
}
334+
}
335+
336+
#[test]
337+
fn traverse_panic() {
338+
Python::with_gil(|py| unsafe {
339+
// get the traverse function
340+
let ty = py.get_type::<PanickyTraverse>().as_type_ptr();
341+
let traverse = get_type_traverse(ty).unwrap();
342+
343+
// confirm that traversing errors
344+
let obj = Py::new(py, PanickyTraverse::new(py)).unwrap();
345+
assert_eq!(traverse(obj.as_ptr(), novisit, std::ptr::null_mut()), -1);
346+
})
347+
}
348+
336349
#[pyclass]
337350
struct TriesGILInTraverse {}
338351

@@ -346,14 +359,6 @@ impl TriesGILInTraverse {
346359
#[test]
347360
fn tries_gil_in_traverse() {
348361
Python::with_gil(|py| unsafe {
349-
// declare a visitor function which errors (returns nonzero code)
350-
extern "C" fn novisit(
351-
_object: *mut pyo3::ffi::PyObject,
352-
_arg: *mut core::ffi::c_void,
353-
) -> std::os::raw::c_int {
354-
0
355-
}
356-
357362
// get the traverse function
358363
let ty = py.get_type::<TriesGILInTraverse>().as_type_ptr();
359364
let traverse = get_type_traverse(ty).unwrap();
@@ -363,3 +368,163 @@ fn tries_gil_in_traverse() {
363368
assert_eq!(traverse(obj.as_ptr(), novisit, std::ptr::null_mut()), -1);
364369
})
365370
}
371+
372+
#[pyclass]
373+
struct HijackedTraverse {
374+
traversed: Cell<bool>,
375+
hijacked: Cell<bool>,
376+
}
377+
378+
impl HijackedTraverse {
379+
fn new() -> Self {
380+
Self {
381+
traversed: Cell::new(false),
382+
hijacked: Cell::new(false),
383+
}
384+
}
385+
386+
fn traversed_and_hijacked(&self) -> (bool, bool) {
387+
(self.traversed.get(), self.hijacked.get())
388+
}
389+
}
390+
391+
#[pymethods]
392+
impl HijackedTraverse {
393+
#[allow(clippy::unnecessary_wraps)]
394+
fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
395+
self.traversed.set(true);
396+
Ok(())
397+
}
398+
}
399+
400+
trait Traversable {
401+
fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError>;
402+
}
403+
404+
impl<'a> Traversable for PyRef<'a, HijackedTraverse> {
405+
fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
406+
self.hijacked.set(true);
407+
Ok(())
408+
}
409+
}
410+
411+
#[test]
412+
fn traverse_cannot_be_hijacked() {
413+
Python::with_gil(|py| unsafe {
414+
// get the traverse function
415+
let ty = py.get_type::<HijackedTraverse>().as_type_ptr();
416+
let traverse = get_type_traverse(ty).unwrap();
417+
418+
let cell = PyCell::new(py, HijackedTraverse::new()).unwrap();
419+
let obj = cell.to_object(py);
420+
assert_eq!(cell.borrow().traversed_and_hijacked(), (false, false));
421+
traverse(obj.as_ptr(), novisit, std::ptr::null_mut());
422+
assert_eq!(cell.borrow().traversed_and_hijacked(), (true, false));
423+
})
424+
}
425+
426+
#[allow(dead_code)]
427+
#[pyclass]
428+
struct DropDuringTraversal {
429+
cycle: Cell<Option<Py<Self>>>,
430+
dropped: TestDropCall,
431+
}
432+
433+
#[pymethods]
434+
impl DropDuringTraversal {
435+
#[allow(clippy::unnecessary_wraps)]
436+
fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
437+
self.cycle.take();
438+
Ok(())
439+
}
440+
441+
fn __clear__(&mut self) {
442+
self.cycle.take();
443+
}
444+
}
445+
446+
#[test]
447+
fn drop_during_traversal_with_gil() {
448+
let drop_called = Arc::new(AtomicBool::new(false));
449+
450+
Python::with_gil(|py| {
451+
let inst = Py::new(
452+
py,
453+
DropDuringTraversal {
454+
cycle: Cell::new(None),
455+
dropped: TestDropCall {
456+
drop_called: Arc::clone(&drop_called),
457+
},
458+
},
459+
)
460+
.unwrap();
461+
462+
inst.borrow_mut(py).cycle.set(Some(inst.clone_ref(py)));
463+
464+
drop(inst);
465+
});
466+
467+
// due to the internal GC mechanism, we may need multiple
468+
// (but not too many) collections to get `inst` actually dropped.
469+
for _ in 0..10 {
470+
Python::with_gil(|py| {
471+
py.run("import gc; gc.collect()", None, None).unwrap();
472+
});
473+
}
474+
assert!(drop_called.load(Ordering::Relaxed));
475+
}
476+
477+
#[test]
478+
fn drop_during_traversal_without_gil() {
479+
let drop_called = Arc::new(AtomicBool::new(false));
480+
481+
let inst = Python::with_gil(|py| {
482+
let inst = Py::new(
483+
py,
484+
DropDuringTraversal {
485+
cycle: Cell::new(None),
486+
dropped: TestDropCall {
487+
drop_called: Arc::clone(&drop_called),
488+
},
489+
},
490+
)
491+
.unwrap();
492+
493+
inst.borrow_mut(py).cycle.set(Some(inst.clone_ref(py)));
494+
495+
inst
496+
});
497+
498+
drop(inst);
499+
500+
// due to the internal GC mechanism, we may need multiple
501+
// (but not too many) collections to get `inst` actually dropped.
502+
for _ in 0..10 {
503+
Python::with_gil(|py| {
504+
py.run("import gc; gc.collect()", None, None).unwrap();
505+
});
506+
}
507+
assert!(drop_called.load(Ordering::Relaxed));
508+
}
509+
510+
// Manual traversal utilities
511+
512+
unsafe fn get_type_traverse(tp: *mut pyo3::ffi::PyTypeObject) -> Option<pyo3::ffi::traverseproc> {
513+
std::mem::transmute(pyo3::ffi::PyType_GetSlot(tp, pyo3::ffi::Py_tp_traverse))
514+
}
515+
516+
// a dummy visitor function
517+
extern "C" fn novisit(
518+
_object: *mut pyo3::ffi::PyObject,
519+
_arg: *mut core::ffi::c_void,
520+
) -> std::os::raw::c_int {
521+
0
522+
}
523+
524+
// a visitor function which errors (returns nonzero code)
525+
extern "C" fn visit_error(
526+
_object: *mut pyo3::ffi::PyObject,
527+
_arg: *mut core::ffi::c_void,
528+
) -> std::os::raw::c_int {
529+
-1
530+
}

0 commit comments

Comments
 (0)