@@ -4,6 +4,7 @@ use pyo3::class::PyTraverseError;
4
4
use pyo3:: class:: PyVisit ;
5
5
use pyo3:: prelude:: * ;
6
6
use pyo3:: { py_run, AsPyPointer , PyCell , PyTryInto } ;
7
+ use std:: cell:: Cell ;
7
8
use std:: sync:: atomic:: { AtomicBool , Ordering } ;
8
9
use std:: sync:: Arc ;
9
10
@@ -248,22 +249,10 @@ impl TraversableClass {
248
249
}
249
250
}
250
251
251
- unsafe fn get_type_traverse ( tp : * mut pyo3:: ffi:: PyTypeObject ) -> Option < pyo3:: ffi:: traverseproc > {
252
- std:: mem:: transmute ( pyo3:: ffi:: PyType_GetSlot ( tp, pyo3:: ffi:: Py_tp_traverse ) )
253
- }
254
-
255
252
#[ test]
256
253
fn gc_during_borrow ( ) {
257
254
Python :: with_gil ( |py| {
258
255
unsafe {
259
- // declare a dummy visitor function
260
- extern "C" fn novisit (
261
- _object : * mut pyo3:: ffi:: PyObject ,
262
- _arg : * mut core:: ffi:: c_void ,
263
- ) -> std:: os:: raw:: c_int {
264
- 0
265
- }
266
-
267
256
// get the traverse function
268
257
let ty = py. get_type :: < TraversableClass > ( ) . as_type_ptr ( ) ;
269
258
let traverse = get_type_traverse ( ty) . unwrap ( ) ;
@@ -290,18 +279,18 @@ fn gc_during_borrow() {
290
279
}
291
280
292
281
#[ pyclass]
293
- struct PanickyTraverse {
282
+ struct PartialTraverse {
294
283
member : PyObject ,
295
284
}
296
285
297
- impl PanickyTraverse {
286
+ impl PartialTraverse {
298
287
fn new ( py : Python < ' _ > ) -> Self {
299
288
Self { member : py. None ( ) }
300
289
}
301
290
}
302
291
303
292
#[ pymethods]
304
- impl PanickyTraverse {
293
+ impl PartialTraverse {
305
294
fn __traverse__ ( & self , visit : PyVisit < ' _ > ) -> Result < ( ) , PyTraverseError > {
306
295
visit. call ( & self . member ) ?;
307
296
// In the test, we expect this to never be hit
@@ -310,29 +299,53 @@ impl PanickyTraverse {
310
299
}
311
300
312
301
#[ test]
313
- fn traverse_error ( ) {
302
+ fn traverse_partial ( ) {
314
303
Python :: with_gil ( |py| unsafe {
315
- // declare a visitor function which errors (returns nonzero code)
316
- extern "C" fn visit_error (
317
- _object : * mut pyo3:: ffi:: PyObject ,
318
- _arg : * mut core:: ffi:: c_void ,
319
- ) -> std:: os:: raw:: c_int {
320
- -1
321
- }
322
-
323
304
// get the traverse function
324
- let ty = py. get_type :: < PanickyTraverse > ( ) . as_type_ptr ( ) ;
305
+ let ty = py. get_type :: < PartialTraverse > ( ) . as_type_ptr ( ) ;
325
306
let traverse = get_type_traverse ( ty) . unwrap ( ) ;
326
307
327
308
// confirm that traversing errors
328
- let obj = Py :: new ( py, PanickyTraverse :: new ( py) ) . unwrap ( ) ;
309
+ let obj = Py :: new ( py, PartialTraverse :: new ( py) ) . unwrap ( ) ;
329
310
assert_eq ! (
330
311
traverse( obj. as_ptr( ) , visit_error, std:: ptr:: null_mut( ) ) ,
331
312
-1
332
313
) ;
333
314
} )
334
315
}
335
316
317
+ #[ pyclass]
318
+ struct PanickyTraverse {
319
+ member : PyObject ,
320
+ }
321
+
322
+ impl PanickyTraverse {
323
+ fn new ( py : Python < ' _ > ) -> Self {
324
+ Self { member : py. None ( ) }
325
+ }
326
+ }
327
+
328
+ #[ pymethods]
329
+ impl PanickyTraverse {
330
+ fn __traverse__ ( & self , visit : PyVisit < ' _ > ) -> Result < ( ) , PyTraverseError > {
331
+ visit. call ( & self . member ) ?;
332
+ panic ! ( "at the disco" ) ;
333
+ }
334
+ }
335
+
336
+ #[ test]
337
+ fn traverse_panic ( ) {
338
+ Python :: with_gil ( |py| unsafe {
339
+ // get the traverse function
340
+ let ty = py. get_type :: < PanickyTraverse > ( ) . as_type_ptr ( ) ;
341
+ let traverse = get_type_traverse ( ty) . unwrap ( ) ;
342
+
343
+ // confirm that traversing errors
344
+ let obj = Py :: new ( py, PanickyTraverse :: new ( py) ) . unwrap ( ) ;
345
+ assert_eq ! ( traverse( obj. as_ptr( ) , novisit, std:: ptr:: null_mut( ) ) , -1 ) ;
346
+ } )
347
+ }
348
+
336
349
#[ pyclass]
337
350
struct TriesGILInTraverse { }
338
351
@@ -346,14 +359,6 @@ impl TriesGILInTraverse {
346
359
#[ test]
347
360
fn tries_gil_in_traverse ( ) {
348
361
Python :: with_gil ( |py| unsafe {
349
- // declare a visitor function which errors (returns nonzero code)
350
- extern "C" fn novisit (
351
- _object : * mut pyo3:: ffi:: PyObject ,
352
- _arg : * mut core:: ffi:: c_void ,
353
- ) -> std:: os:: raw:: c_int {
354
- 0
355
- }
356
-
357
362
// get the traverse function
358
363
let ty = py. get_type :: < TriesGILInTraverse > ( ) . as_type_ptr ( ) ;
359
364
let traverse = get_type_traverse ( ty) . unwrap ( ) ;
@@ -363,3 +368,163 @@ fn tries_gil_in_traverse() {
363
368
assert_eq ! ( traverse( obj. as_ptr( ) , novisit, std:: ptr:: null_mut( ) ) , -1 ) ;
364
369
} )
365
370
}
371
+
372
+ #[ pyclass]
373
+ struct HijackedTraverse {
374
+ traversed : Cell < bool > ,
375
+ hijacked : Cell < bool > ,
376
+ }
377
+
378
+ impl HijackedTraverse {
379
+ fn new ( ) -> Self {
380
+ Self {
381
+ traversed : Cell :: new ( false ) ,
382
+ hijacked : Cell :: new ( false ) ,
383
+ }
384
+ }
385
+
386
+ fn traversed_and_hijacked ( & self ) -> ( bool , bool ) {
387
+ ( self . traversed . get ( ) , self . hijacked . get ( ) )
388
+ }
389
+ }
390
+
391
+ #[ pymethods]
392
+ impl HijackedTraverse {
393
+ #[ allow( clippy:: unnecessary_wraps) ]
394
+ fn __traverse__ ( & self , _visit : PyVisit < ' _ > ) -> Result < ( ) , PyTraverseError > {
395
+ self . traversed . set ( true ) ;
396
+ Ok ( ( ) )
397
+ }
398
+ }
399
+
400
+ trait Traversable {
401
+ fn __traverse__ ( & self , visit : PyVisit < ' _ > ) -> Result < ( ) , PyTraverseError > ;
402
+ }
403
+
404
+ impl < ' a > Traversable for PyRef < ' a , HijackedTraverse > {
405
+ fn __traverse__ ( & self , _visit : PyVisit < ' _ > ) -> Result < ( ) , PyTraverseError > {
406
+ self . hijacked . set ( true ) ;
407
+ Ok ( ( ) )
408
+ }
409
+ }
410
+
411
+ #[ test]
412
+ fn traverse_cannot_be_hijacked ( ) {
413
+ Python :: with_gil ( |py| unsafe {
414
+ // get the traverse function
415
+ let ty = py. get_type :: < HijackedTraverse > ( ) . as_type_ptr ( ) ;
416
+ let traverse = get_type_traverse ( ty) . unwrap ( ) ;
417
+
418
+ let cell = PyCell :: new ( py, HijackedTraverse :: new ( ) ) . unwrap ( ) ;
419
+ let obj = cell. to_object ( py) ;
420
+ assert_eq ! ( cell. borrow( ) . traversed_and_hijacked( ) , ( false , false ) ) ;
421
+ traverse ( obj. as_ptr ( ) , novisit, std:: ptr:: null_mut ( ) ) ;
422
+ assert_eq ! ( cell. borrow( ) . traversed_and_hijacked( ) , ( true , false ) ) ;
423
+ } )
424
+ }
425
+
426
+ #[ allow( dead_code) ]
427
+ #[ pyclass]
428
+ struct DropDuringTraversal {
429
+ cycle : Cell < Option < Py < Self > > > ,
430
+ dropped : TestDropCall ,
431
+ }
432
+
433
+ #[ pymethods]
434
+ impl DropDuringTraversal {
435
+ #[ allow( clippy:: unnecessary_wraps) ]
436
+ fn __traverse__ ( & self , _visit : PyVisit < ' _ > ) -> Result < ( ) , PyTraverseError > {
437
+ self . cycle . take ( ) ;
438
+ Ok ( ( ) )
439
+ }
440
+
441
+ fn __clear__ ( & mut self ) {
442
+ self . cycle . take ( ) ;
443
+ }
444
+ }
445
+
446
+ #[ test]
447
+ fn drop_during_traversal_with_gil ( ) {
448
+ let drop_called = Arc :: new ( AtomicBool :: new ( false ) ) ;
449
+
450
+ Python :: with_gil ( |py| {
451
+ let inst = Py :: new (
452
+ py,
453
+ DropDuringTraversal {
454
+ cycle : Cell :: new ( None ) ,
455
+ dropped : TestDropCall {
456
+ drop_called : Arc :: clone ( & drop_called) ,
457
+ } ,
458
+ } ,
459
+ )
460
+ . unwrap ( ) ;
461
+
462
+ inst. borrow_mut ( py) . cycle . set ( Some ( inst. clone_ref ( py) ) ) ;
463
+
464
+ drop ( inst) ;
465
+ } ) ;
466
+
467
+ // due to the internal GC mechanism, we may need multiple
468
+ // (but not too many) collections to get `inst` actually dropped.
469
+ for _ in 0 ..10 {
470
+ Python :: with_gil ( |py| {
471
+ py. run ( "import gc; gc.collect()" , None , None ) . unwrap ( ) ;
472
+ } ) ;
473
+ }
474
+ assert ! ( drop_called. load( Ordering :: Relaxed ) ) ;
475
+ }
476
+
477
+ #[ test]
478
+ fn drop_during_traversal_without_gil ( ) {
479
+ let drop_called = Arc :: new ( AtomicBool :: new ( false ) ) ;
480
+
481
+ let inst = Python :: with_gil ( |py| {
482
+ let inst = Py :: new (
483
+ py,
484
+ DropDuringTraversal {
485
+ cycle : Cell :: new ( None ) ,
486
+ dropped : TestDropCall {
487
+ drop_called : Arc :: clone ( & drop_called) ,
488
+ } ,
489
+ } ,
490
+ )
491
+ . unwrap ( ) ;
492
+
493
+ inst. borrow_mut ( py) . cycle . set ( Some ( inst. clone_ref ( py) ) ) ;
494
+
495
+ inst
496
+ } ) ;
497
+
498
+ drop ( inst) ;
499
+
500
+ // due to the internal GC mechanism, we may need multiple
501
+ // (but not too many) collections to get `inst` actually dropped.
502
+ for _ in 0 ..10 {
503
+ Python :: with_gil ( |py| {
504
+ py. run ( "import gc; gc.collect()" , None , None ) . unwrap ( ) ;
505
+ } ) ;
506
+ }
507
+ assert ! ( drop_called. load( Ordering :: Relaxed ) ) ;
508
+ }
509
+
510
+ // Manual traversal utilities
511
+
512
+ unsafe fn get_type_traverse ( tp : * mut pyo3:: ffi:: PyTypeObject ) -> Option < pyo3:: ffi:: traverseproc > {
513
+ std:: mem:: transmute ( pyo3:: ffi:: PyType_GetSlot ( tp, pyo3:: ffi:: Py_tp_traverse ) )
514
+ }
515
+
516
+ // a dummy visitor function
517
+ extern "C" fn novisit (
518
+ _object : * mut pyo3:: ffi:: PyObject ,
519
+ _arg : * mut core:: ffi:: c_void ,
520
+ ) -> std:: os:: raw:: c_int {
521
+ 0
522
+ }
523
+
524
+ // a visitor function which errors (returns nonzero code)
525
+ extern "C" fn visit_error (
526
+ _object : * mut pyo3:: ffi:: PyObject ,
527
+ _arg : * mut core:: ffi:: c_void ,
528
+ ) -> std:: os:: raw:: c_int {
529
+ -1
530
+ }
0 commit comments