3
3
use crate :: {
4
4
attributes:: {
5
5
self , take_attributes, take_pyo3_options, CrateAttribute , ModuleAttribute , NameAttribute ,
6
+ SubmoduleAttribute ,
6
7
} ,
7
8
get_doc,
8
9
pyclass:: PyClassPyO3Option ,
@@ -27,6 +28,7 @@ pub struct PyModuleOptions {
27
28
krate : Option < CrateAttribute > ,
28
29
name : Option < syn:: Ident > ,
29
30
module : Option < ModuleAttribute > ,
31
+ is_submodule : bool ,
30
32
}
31
33
32
34
impl PyModuleOptions {
@@ -38,6 +40,7 @@ impl PyModuleOptions {
38
40
PyModulePyO3Option :: Name ( name) => options. set_name ( name. value . 0 ) ?,
39
41
PyModulePyO3Option :: Crate ( path) => options. set_crate ( path) ?,
40
42
PyModulePyO3Option :: Module ( module) => options. set_module ( module) ?,
43
+ PyModulePyO3Option :: Submodule ( submod) => options. set_submodule ( submod) ?,
41
44
}
42
45
}
43
46
@@ -73,9 +76,22 @@ impl PyModuleOptions {
73
76
self . module = Some ( name) ;
74
77
Ok ( ( ) )
75
78
}
79
+
80
+ fn set_submodule ( & mut self , submod : SubmoduleAttribute ) -> Result < ( ) > {
81
+ ensure_spanned ! (
82
+ !self . is_submodule,
83
+ submod. span( ) => "`submodule` may only be specified once"
84
+ ) ;
85
+
86
+ self . is_submodule = true ;
87
+ Ok ( ( ) )
88
+ }
76
89
}
77
90
78
- pub fn pymodule_module_impl ( mut module : syn:: ItemMod ) -> Result < TokenStream > {
91
+ pub fn pymodule_module_impl (
92
+ mut module : syn:: ItemMod ,
93
+ mut is_submodule : bool ,
94
+ ) -> Result < TokenStream > {
79
95
let syn:: ItemMod {
80
96
attrs,
81
97
vis,
@@ -100,6 +116,7 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
100
116
} else {
101
117
name. to_string ( )
102
118
} ;
119
+ is_submodule = is_submodule || options. is_submodule ;
103
120
104
121
let mut module_items = Vec :: new ( ) ;
105
122
let mut module_items_cfg_attrs = Vec :: new ( ) ;
@@ -297,7 +314,7 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
297
314
)
298
315
}
299
316
} } ;
300
- let initialization = module_initialization ( & name, ctx, module_def) ;
317
+ let initialization = module_initialization ( & name, ctx, module_def, is_submodule ) ;
301
318
Ok ( quote ! (
302
319
#( #attrs) *
303
320
#vis mod #ident {
@@ -331,7 +348,7 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result<TokenStream>
331
348
let vis = & function. vis ;
332
349
let doc = get_doc ( & function. attrs , None , ctx) ;
333
350
334
- let initialization = module_initialization ( & name, ctx, quote ! { MakeDef :: make_def( ) } ) ;
351
+ let initialization = module_initialization ( & name, ctx, quote ! { MakeDef :: make_def( ) } , false ) ;
335
352
336
353
// Module function called with optional Python<'_> marker as first arg, followed by the module.
337
354
let mut module_args = Vec :: new ( ) ;
@@ -396,28 +413,37 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result<TokenStream>
396
413
} )
397
414
}
398
415
399
- fn module_initialization ( name : & syn:: Ident , ctx : & Ctx , module_def : TokenStream ) -> TokenStream {
416
+ fn module_initialization (
417
+ name : & syn:: Ident ,
418
+ ctx : & Ctx ,
419
+ module_def : TokenStream ,
420
+ is_submodule : bool ,
421
+ ) -> TokenStream {
400
422
let Ctx { pyo3_path, .. } = ctx;
401
423
let pyinit_symbol = format ! ( "PyInit_{}" , name) ;
402
424
let name = name. to_string ( ) ;
403
425
let pyo3_name = LitCStr :: new ( CString :: new ( name) . unwrap ( ) , Span :: call_site ( ) , ctx) ;
404
426
405
- quote ! {
427
+ let mut result = quote ! {
406
428
#[ doc( hidden) ]
407
429
pub const __PYO3_NAME: & ' static :: std:: ffi:: CStr = #pyo3_name;
408
430
409
431
pub ( super ) struct MakeDef ;
410
432
#[ doc( hidden) ]
411
433
pub static _PYO3_DEF: #pyo3_path:: impl_:: pymodule:: ModuleDef = #module_def;
412
-
413
- /// This autogenerated function is called by the python interpreter when importing
414
- /// the module.
415
- #[ doc( hidden) ]
416
- #[ export_name = #pyinit_symbol]
417
- pub unsafe extern "C" fn __pyo3_init( ) -> * mut #pyo3_path:: ffi:: PyObject {
418
- #pyo3_path:: impl_:: trampoline:: module_init( |py| _PYO3_DEF. make_module( py) )
419
- }
434
+ } ;
435
+ if !is_submodule {
436
+ result. extend ( quote ! {
437
+ /// This autogenerated function is called by the python interpreter when importing
438
+ /// the module.
439
+ #[ doc( hidden) ]
440
+ #[ export_name = #pyinit_symbol]
441
+ pub unsafe extern "C" fn __pyo3_init( ) -> * mut #pyo3_path:: ffi:: PyObject {
442
+ #pyo3_path:: impl_:: trampoline:: module_init( |py| _PYO3_DEF. make_module( py) )
443
+ }
444
+ } ) ;
420
445
}
446
+ result
421
447
}
422
448
423
449
/// Finds and takes care of the #[pyfn(...)] in `#[pymodule]`
@@ -557,6 +583,7 @@ fn has_pyo3_module_declared<T: Parse>(
557
583
}
558
584
559
585
enum PyModulePyO3Option {
586
+ Submodule ( SubmoduleAttribute ) ,
560
587
Crate ( CrateAttribute ) ,
561
588
Name ( NameAttribute ) ,
562
589
Module ( ModuleAttribute ) ,
@@ -571,6 +598,8 @@ impl Parse for PyModulePyO3Option {
571
598
input. parse ( ) . map ( PyModulePyO3Option :: Crate )
572
599
} else if lookahead. peek ( attributes:: kw:: module) {
573
600
input. parse ( ) . map ( PyModulePyO3Option :: Module )
601
+ } else if lookahead. peek ( attributes:: kw:: submodule) {
602
+ input. parse ( ) . map ( PyModulePyO3Option :: Submodule )
574
603
} else {
575
604
Err ( lookahead. error ( ) )
576
605
}
0 commit comments