Skip to content

Commit f4c04ca

Browse files
committed
refs PyO3#4286 -- allow setting submodule on declarative pymodules
1 parent 8f7450e commit f4c04ca

File tree

5 files changed

+44
-19
lines changed

5 files changed

+44
-19
lines changed

guide/src/module.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,8 @@ The `#[pymodule]` macro automatically sets the `module` attribute of the `#[pycl
154154
For nested modules, the name of the parent module is automatically added.
155155
In the following example, the `Unit` class will have for `module` `my_extension.submodule` because it is properly nested
156156
but the `Ext` class will have for `module` the default `builtins` because it not nested.
157+
158+
You can provide the `submodule` argument to `pymodule()` for modules that are not top-level modules.
157159
```rust
158160
# mod declarative_module_module_attr_test {
159161
use pyo3::prelude::*;
@@ -168,7 +170,7 @@ mod my_extension {
168170
#[pymodule_export]
169171
use super::Ext;
170172

171-
#[pymodule]
173+
#[pymodule(submodule)]
172174
mod submodule {
173175
use super::*;
174176
// This is a submodule

newsfragments/4301.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
allow setting `submodule` on declarative `#[pymodule]`s

pyo3-macros-backend/src/module.rs

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ impl PyModuleOptions {
7575
}
7676
}
7777

78-
pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
78+
pub fn pymodule_module_impl(mut module: syn::ItemMod, is_submodule: bool) -> Result<TokenStream> {
7979
let syn::ItemMod {
8080
attrs,
8181
vis,
@@ -286,7 +286,7 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
286286
}
287287
}
288288

289-
let initialization = module_initialization(&name, ctx);
289+
let initialization = module_initialization(&name, ctx, is_submodule);
290290
Ok(quote!(
291291
#(#attrs)*
292292
#vis mod #ident {
@@ -335,7 +335,7 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result<TokenStream>
335335
let vis = &function.vis;
336336
let doc = get_doc(&function.attrs, None, ctx);
337337

338-
let initialization = module_initialization(&name, ctx);
338+
let initialization = module_initialization(&name, ctx, false);
339339

340340
// Module function called with optional Python<'_> marker as first arg, followed by the module.
341341
let mut module_args = Vec::new();
@@ -400,28 +400,34 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result<TokenStream>
400400
})
401401
}
402402

403-
fn module_initialization(name: &syn::Ident, ctx: &Ctx) -> TokenStream {
403+
fn module_initialization(name: &syn::Ident, ctx: &Ctx, is_submodule: bool) -> TokenStream {
404404
let Ctx { pyo3_path, .. } = ctx;
405405
let pyinit_symbol = format!("PyInit_{}", name);
406406
let name = name.to_string();
407407
let pyo3_name = LitCStr::new(CString::new(name).unwrap(), Span::call_site(), ctx);
408408

409-
quote! {
409+
let mut base = quote! {
410410
#[doc(hidden)]
411411
pub const __PYO3_NAME: &'static ::std::ffi::CStr = #pyo3_name;
412412

413413
pub(super) struct MakeDef;
414414
#[doc(hidden)]
415415
pub static _PYO3_DEF: #pyo3_path::impl_::pymodule::ModuleDef = MakeDef::make_def();
416-
417-
/// This autogenerated function is called by the python interpreter when importing
418-
/// the module.
419-
#[doc(hidden)]
420-
#[export_name = #pyinit_symbol]
421-
pub unsafe extern "C" fn __pyo3_init() -> *mut #pyo3_path::ffi::PyObject {
422-
#pyo3_path::impl_::trampoline::module_init(|py| _PYO3_DEF.make_module(py))
423-
}
416+
};
417+
if !is_submodule {
418+
base = quote! {
419+
#base
420+
421+
/// This autogenerated function is called by the python interpreter when importing
422+
/// the module.
423+
#[doc(hidden)]
424+
#[export_name = #pyinit_symbol]
425+
pub unsafe extern "C" fn __pyo3_init() -> *mut #pyo3_path::ffi::PyObject {
426+
#pyo3_path::impl_::trampoline::module_init(|py| _PYO3_DEF.make_module(py))
427+
}
428+
};
424429
}
430+
base
425431
}
426432

427433
/// Finds and takes care of the #[pyfn(...)] in `#[pymodule]`

pyo3-macros/src/lib.rs

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
44
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
55
use proc_macro::TokenStream;
6-
use proc_macro2::TokenStream as TokenStream2;
6+
use proc_macro2::{Span, TokenStream as TokenStream2};
77
use pyo3_macros_backend::{
88
build_derive_from_pyobject, build_py_class, build_py_enum, build_py_function, build_py_methods,
99
pymodule_function_impl, pymodule_module_impl, PyClassArgs, PyClassMethodsType,
@@ -35,10 +35,26 @@ use syn::{parse::Nothing, parse_macro_input, Item};
3535
/// [1]: https://pyo3.rs/latest/module.html
3636
#[proc_macro_attribute]
3737
pub fn pymodule(args: TokenStream, input: TokenStream) -> TokenStream {
38-
parse_macro_input!(args as Nothing);
3938
match parse_macro_input!(input as Item) {
40-
Item::Mod(module) => pymodule_module_impl(module),
41-
Item::Fn(function) => pymodule_function_impl(function),
39+
Item::Mod(module) => {
40+
let is_submodule = match parse_macro_input!(args as Option<syn::Ident>) {
41+
Some(i) if i == "submodule" => true,
42+
Some(_) => {
43+
return syn::Error::new(
44+
Span::call_site(),
45+
"#[pymodule] only accepts submodule as an argument",
46+
)
47+
.into_compile_error()
48+
.into();
49+
}
50+
None => false,
51+
};
52+
pymodule_module_impl(module, is_submodule)
53+
}
54+
Item::Fn(function) => {
55+
parse_macro_input!(args as Nothing);
56+
pymodule_function_impl(function)
57+
}
4258
unsupported => Err(syn::Error::new_spanned(
4359
unsupported,
4460
"#[pymodule] only supports modules and functions.",

tests/test_declarative_module.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ mod declarative_module {
108108
}
109109
}
110110

111-
#[pymodule]
111+
#[pymodule(submodule)]
112112
#[pyo3(module = "custom_root")]
113113
mod inner_custom_root {
114114
use super::*;

0 commit comments

Comments
 (0)