diff --git a/pyo3-derive-backend/src/py_class.rs b/pyo3-derive-backend/src/py_class.rs index 84c46641a84..1409cad05e9 100644 --- a/pyo3-derive-backend/src/py_class.rs +++ b/pyo3-derive-backend/src/py_class.rs @@ -3,12 +3,21 @@ use method::{FnArg, FnSpec, FnType}; use proc_macro2::{Span, TokenStream}; use py_method::{impl_py_getter_def, impl_py_setter_def, impl_wrap_getter, impl_wrap_setter}; -use std::collections::HashMap; +use quote::ToTokens; use syn; use utils; +#[derive(Default, Debug)] +struct PyClassAttributes { + flags: Vec, + freelist: Option, + name: Option, + base: Option, + variants: Option>, +} + pub fn build_py_class(class: &mut syn::ItemStruct, attr: &Vec) -> TokenStream { - let (params, flags, base) = parse_attribute(attr); + let attrs = parse_attribute(attr); let doc = utils::get_doc(&class.attrs, true); let mut descriptors = Vec::new(); @@ -23,7 +32,13 @@ pub fn build_py_class(class: &mut syn::ItemStruct, attr: &Vec) -> Tok panic!("#[pyclass] can only be used with C-style structs") } - impl_class(&class.ident, &base, doc, params, flags, descriptors) + impl_class( + &class.ident, + &attrs, + doc, + descriptors, + class.generics.clone(), + ) } fn parse_descriptors(item: &mut syn::Field) -> Vec { @@ -62,21 +77,40 @@ fn parse_descriptors(item: &mut syn::Field) -> Vec { fn impl_class( cls: &syn::Ident, - base: &syn::TypePath, + attrs: &PyClassAttributes, doc: syn::Lit, - params: HashMap<&'static str, syn::Expr>, - flags: Vec, descriptors: Vec<(syn::Field, Vec)>, + mut generics: syn::Generics, ) -> TokenStream { - let cls_name = match params.get("name") { - Some(name) => quote! { #name }.to_string(), + let cls_name = match attrs.name { + Some(ref name) => quote! { #name }.to_string(), None => quote! { #cls }.to_string(), }; + if attrs.variants.is_none() && generics.params.len() != 0 { + panic!( + "The `variants` parameter is required when using generic structs, \ + e.g. `#[pyclass(variants(\"{}U32\", \"{}F32\"))]`.", + cls_name, cls_name, + ); + } + + // Split generics into pieces for impls using them. + generics.make_where_clause(); + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let mut where_clause = where_clause.unwrap().clone(); + + // Insert `MyStruct: PyTypeInfo` bound. + where_clause.predicates.push(parse_quote! { + #cls #ty_generics: ::pyo3::typeob::PyTypeInfo + }); + let extra = { - if let Some(freelist) = params.get("freelist") { + if let Some(ref freelist) = attrs.freelist { quote! { - impl ::pyo3::freelist::PyObjectWithFreeList for #cls { + impl #impl_generics ::pyo3::freelist::PyObjectWithFreeList + for #cls #ty_generics #where_clause + { #[inline] fn get_free_list() -> &'static mut ::pyo3::freelist::FreeList<*mut ::pyo3::ffi::PyObject> { static mut FREELIST: *mut ::pyo3::freelist::FreeList<*mut ::pyo3::ffi::PyObject> = 0 as *mut _; @@ -85,7 +119,7 @@ fn impl_class( FREELIST = Box::into_raw(Box::new( ::pyo3::freelist::FreeList::with_capacity(#freelist))); - <#cls as ::pyo3::typeob::PyTypeCreate>::init_type(); + <#cls #ty_generics as ::pyo3::typeob::PyTypeCreate>::init_type(); } &mut *FREELIST } @@ -94,14 +128,20 @@ fn impl_class( } } else { quote! { - impl ::pyo3::typeob::PyObjectAlloc for #cls {} + impl #impl_generics ::pyo3::typeob::PyObjectAlloc for #cls #ty_generics #where_clause {} } } }; let extra = if !descriptors.is_empty() { let ty = syn::parse_str(&cls.to_string()).expect("no name"); - let desc_impls = impl_descriptors(&ty, descriptors); + let desc_impls = impl_descriptors( + &ty, + &impl_generics, + &ty_generics, + &where_clause, + descriptors, + ); quote! { #desc_impls #extra @@ -113,7 +153,7 @@ fn impl_class( // insert space for weak ref let mut has_weakref = false; let mut has_dict = false; - for f in flags.iter() { + for f in attrs.flags.iter() { if let syn::Expr::Path(ref epath) = f { if epath.path == parse_quote! {::pyo3::typeob::PY_TYPE_FLAG_WEAKREF} { has_weakref = true; @@ -133,25 +173,43 @@ fn impl_class( quote! {0} }; - quote! { - impl ::pyo3::typeob::PyTypeInfo for #cls { - type Type = #cls; + // Create a variant of our generics with lifetime 'a prepended. + let mut gen_with_a = generics.clone(); + gen_with_a.params.insert(0, parse_quote! { 'a }); + let impl_with_a = gen_with_a.split_for_impl().0; + + // Generate one PyTypeInfo per generic variant. + let variants= match attrs.variants { + Some(ref x) => x + .clone() + .into_iter() + .map(|(a, b)| (a, b.into_token_stream())) + .collect(), + None => vec![(cls_name, TokenStream::new())], + }; + + let base = &attrs.base; + let flags = &attrs.flags; + let type_info_impls: Vec<_> = variants.iter().map(|(name, for_ty)| quote! { + impl ::pyo3::typeob::PyTypeInfo for #cls #for_ty { + type Type = #cls #for_ty; type BaseType = #base; - const NAME: &'static str = #cls_name; + const NAME: &'static str = #name; const DESCRIPTION: &'static str = #doc; const FLAGS: usize = #(#flags)|*; const SIZE: usize = { Self::OFFSET as usize + - ::std::mem::size_of::<#cls>() + #weakref + #dict + ::std::mem::size_of::() + #weakref + #dict }; const OFFSET: isize = { // round base_size up to next multiple of align ( (<#base as ::pyo3::typeob::PyTypeInfo>::SIZE + - ::std::mem::align_of::<#cls>() - 1) / - ::std::mem::align_of::<#cls>() * ::std::mem::align_of::<#cls>() + ::std::mem::align_of::() - 1) / + ::std::mem::align_of::() * + ::std::mem::align_of::() ) as isize }; @@ -161,33 +219,38 @@ fn impl_class( &mut TYPE_OBJECT } } + }).collect(); + + quote! { + #(#type_info_impls)* // TBH I'm not sure what exactely this does and I'm sure there's a better way, // but for now it works and it only safe code and it is required to return custom // objects, so for now I'm keeping it - impl ::pyo3::IntoPyObject for #cls { + impl #impl_generics ::pyo3::IntoPyObject for #cls #ty_generics #where_clause { fn into_object(self, py: ::pyo3::Python) -> ::pyo3::PyObject { ::pyo3::Py::new(py, || self).unwrap().into_object(py) } } - impl ::pyo3::ToPyObject for #cls { + impl #impl_generics ::pyo3::ToPyObject for #cls #ty_generics #where_clause { fn to_object(&self, py: ::pyo3::Python) -> ::pyo3::PyObject { use ::pyo3::python::ToPyPointer; unsafe { ::pyo3::PyObject::from_borrowed_ptr(py, self.as_ptr()) } } } - impl ::pyo3::ToPyPointer for #cls { + impl #impl_generics ::pyo3::ToPyPointer for #cls #ty_generics #where_clause { fn as_ptr(&self) -> *mut ::pyo3::ffi::PyObject { unsafe { - {self as *const _ as *mut u8} - .offset(-<#cls as ::pyo3::typeob::PyTypeInfo>::OFFSET) as *mut ::pyo3::ffi::PyObject + (self as *const _ as *mut u8).offset( + -::OFFSET + ) as *mut ::pyo3::ffi::PyObject } } } - impl<'a> ::pyo3::ToPyObject for &'a mut #cls { + impl #impl_with_a ::pyo3::ToPyObject for &'a mut #cls #ty_generics #where_clause { fn to_object(&self, py: ::pyo3::Python) -> ::pyo3::PyObject { use ::pyo3::python::ToPyPointer; unsafe { ::pyo3::PyObject::from_borrowed_ptr(py, self.as_ptr()) } @@ -198,7 +261,13 @@ fn impl_class( } } -fn impl_descriptors(cls: &syn::Type, descriptors: Vec<(syn::Field, Vec)>) -> TokenStream { +fn impl_descriptors( + cls: &syn::Type, + impl_generics: &syn::ImplGenerics, + ty_generics: &syn::TypeGenerics, + where_clause: &syn::WhereClause, + descriptors: Vec<(syn::Field, Vec)>, +) -> TokenStream { let methods: Vec = descriptors .iter() .flat_map(|&(ref field, ref fns)| { @@ -209,7 +278,7 @@ fn impl_descriptors(cls: &syn::Type, descriptors: Vec<(syn::Field, Vec)> match *desc { FnType::Getter(_) => { quote! { - impl #cls { + impl #impl_generics #cls #ty_generics #where_clause { fn #name(&self) -> ::pyo3::PyResult<#field_ty> { Ok(self.#name.clone()) } @@ -220,7 +289,7 @@ fn impl_descriptors(cls: &syn::Type, descriptors: Vec<(syn::Field, Vec)> let setter_name = syn::Ident::new(&format!("set_{}", name), Span::call_site()); quote! { - impl #cls { + impl #impl_generics #cls #ty_generics #where_clause { fn #setter_name(&mut self, value: #field_ty) -> ::pyo3::PyResult<()> { self.#name = value; Ok(()) @@ -284,7 +353,9 @@ fn impl_descriptors(cls: &syn::Type, descriptors: Vec<(syn::Field, Vec)> quote! { #(#methods)* - impl ::pyo3::class::methods::PyPropMethodsProtocolImpl for #cls { + impl #impl_generics ::pyo3::class::methods::PyPropMethodsProtocolImpl + for #cls #ty_generics #where_clause + { fn py_methods() -> &'static [::pyo3::class::PyMethodDefType] { static METHODS: &'static [::pyo3::class::PyMethodDefType] = &[ #(#py_methods),* @@ -295,23 +366,21 @@ fn impl_descriptors(cls: &syn::Type, descriptors: Vec<(syn::Field, Vec)> } } -fn parse_attribute( - args: &Vec, -) -> ( - HashMap<&'static str, syn::Expr>, - Vec, - syn::TypePath, -) { - let mut params = HashMap::new(); - // We need the 0 as value for the constant we're later building using quote for when there - // are no other flags - let mut flags = vec![parse_quote! {0}]; - let mut base: syn::TypePath = parse_quote! {::pyo3::types::PyObjectRef}; +fn parse_attribute(args: &Vec) -> PyClassAttributes { + use syn::Expr::*; + + let mut attrs = PyClassAttributes { + // We need the 0 as value for the constant we're later building using + // quote for when there are no other flags + flags: vec![parse_quote! {0}], + base: Some(parse_quote! {::pyo3::types::PyObjectRef}), + ..Default::default() + }; for expr in args.iter() { match expr { // Match a single flag - syn::Expr::Path(ref exp) if exp.path.segments.len() == 1 => { + Path(ref exp) if exp.path.segments.len() == 1 => { let flag = exp.path.segments.first().unwrap().value().ident.to_string(); let path = match flag.as_str() { "gc" => { @@ -329,13 +398,13 @@ fn parse_attribute( param => panic!("Unsupported parameter: {}", param), }; - flags.push(syn::Expr::Path(path)); + attrs.flags.push(Path(path)); } // Match a key/value flag - syn::Expr::Assign(ref ass) => { + Assign(ref ass) => { let key = match *ass.left { - syn::Expr::Path(ref exp) if exp.path.segments.len() == 1 => { + Path(ref exp) if exp.path.segments.len() == 1 => { exp.path.segments.first().unwrap().value().ident.to_string() } _ => panic!("could not parse argument: {:?}", ass), @@ -344,20 +413,20 @@ fn parse_attribute( match key.as_str() { "freelist" => { // TODO: check if int literal - params.insert("freelist", *ass.right.clone()); + attrs.freelist = Some(*ass.right.clone()); } "name" => match *ass.right { - syn::Expr::Path(ref exp) if exp.path.segments.len() == 1 => { - params.insert("name", exp.clone().into()); + Path(ref exp) if exp.path.segments.len() == 1 => { + attrs.name = Some(exp.clone().into()); } _ => panic!("Wrong 'name' format: {:?}", *ass.right), }, "extends" => match *ass.right { - syn::Expr::Path(ref exp) => { - base = syn::TypePath { + Path(ref exp) => { + attrs.base = Some(syn::TypePath { path: exp.path.clone(), qself: None, - }; + }); } _ => panic!("Wrong 'base' format: {:?}", *ass.right), }, @@ -367,9 +436,14 @@ fn parse_attribute( } } - _ => panic!("could not parse arguments"), + // Match variants (e.g. `variants("MyTypeU32", "MyTypeF32")`) + Call(ref call) => { + attrs.variants = Some(utils::parse_variants(call)); + } + + _ => panic!("Could not parse arguments"), } } - (params, flags, base) + attrs } diff --git a/pyo3-derive-backend/src/py_impl.rs b/pyo3-derive-backend/src/py_impl.rs index a20e26a6d36..040d18f6ccc 100644 --- a/pyo3-derive-backend/src/py_impl.rs +++ b/pyo3-derive-backend/src/py_impl.rs @@ -1,42 +1,118 @@ // Copyright (c) 2017-present PyO3 Project and Contributors +use method::FnSpec; use proc_macro2::TokenStream; use py_method; use syn; +use utils; -pub fn build_py_methods(ast: &mut syn::ItemImpl) -> TokenStream { +pub fn build_py_methods(ast: &mut syn::ItemImpl, attrs: &Vec) -> TokenStream { if ast.trait_.is_some() { panic!("#[pymethods] can not be used only with trait impl block"); - } else if ast.generics != Default::default() { - panic!("#[pymethods] can not ve used with lifetime parameters or generics"); - } else { - impl_methods(&ast.self_ty, &mut ast.items) } + + impl_methods(&ast.self_ty, &mut ast.items, attrs, &ast.generics) } -pub fn impl_methods(ty: &syn::Type, impls: &mut Vec) -> TokenStream { - // get method names in impl block - let mut methods = Vec::new(); - for iimpl in impls.iter_mut() { - if let syn::ImplItem::Method(ref mut meth) = iimpl { - let name = meth.sig.ident.clone(); - methods.push(py_method::gen_py_method( - ty, - &name, - &mut meth.sig, - &mut meth.attrs, - )); +pub fn impl_methods( + ty: &syn::Type, + impls: &mut Vec, + attrs: &Vec, + generics: &syn::Generics, +) -> TokenStream { + use syn::PathArguments::AngleBracketed; + + // If there are generics, we expect a `variants` directive. + let variants = if !generics.params.is_empty() { + if let Some(syn::Expr::Call(ref call)) = attrs.first() { + utils::parse_variants(call) + .into_iter() + .map(|(_, x)| AngleBracketed(x)) + .collect() + } else { + panic!("`variants` annotation is required when using generics"); } - } + } else { + vec![syn::PathArguments::None] + }; + + // Parse method info. + let untouched_impls = impls.clone(); + let fn_specs: Vec<_> = impls + .iter_mut() + .filter_map(|x| match x { + syn::ImplItem::Method(meth) => { + Some(FnSpec::parse(&meth.sig.ident, &meth.sig, &mut meth.attrs)) + } + _ => None, + }) + .collect(); + + // Emit one `PyMethodsProtocolImpl` impl for each variant. + let impls = variants.into_iter().map(|variant_args| { + // Replace generic path arguments with concrete variant type arguments and generate + // `type T1 = ConcreteT1` statements for use in the wrapper methods. + // + // Why do aliasing instead of just replacing the types in the arg and return types, you may + // ask. I originally wrote a function recursively traversing and replacing generic types in + // all arguments and the return val, however it turned out to be a pretty complex beast + // that would also be guaranteed to be a burden in maintenance to keep up with all Rust + // syntax additions. This simple aliasing approach doesn't have these problems. + let mut variant_ty = ty.clone(); + let ty_map = if let syn::Type::Path(syn::TypePath { ref mut path, .. }) = variant_ty { + let tail = path.segments.iter_mut().last().unwrap(); + let generic_args = std::mem::replace(&mut tail.arguments, variant_args); - quote! { - impl ::pyo3::class::methods::PyMethodsProtocolImpl for #ty { - fn py_methods() -> &'static [::pyo3::class::PyMethodDefType] { - static METHODS: &'static [::pyo3::class::PyMethodDefType] = &[ - #(#methods),* - ]; - METHODS + match (&generic_args, &mut tail.arguments) { + (AngleBracketed(generic), AngleBracketed(variant)) => { + // Some generated methods require the type in turbo-fish syntax. + variant.colon2_token = parse_quote! { :: }; + + generic + .args + .iter() + .zip(variant.args.iter()) + .map(|(a, b)| { + quote! { + type #a = #b; + } + }) + .collect() + } + _ => Vec::new(), + } + } else { + Vec::new() + }; + + // Generate wrappers for Python methods. + let mut methods = Vec::new(); + for (iimpl, fn_spec) in untouched_impls.iter().zip(&fn_specs) { + if let syn::ImplItem::Method(meth) = iimpl { + methods.push(py_method::gen_py_method( + &variant_ty, + &meth.sig.ident, + &meth.sig, + &meth.attrs, + fn_spec, + )); } } - } + + // Emit the `PyMethodsProtocolImpl` impl for this struct variant. + quote! { + impl ::pyo3::class::methods::PyMethodsProtocolImpl for #variant_ty { + fn py_methods() -> &'static [::pyo3::class::PyMethodDefType] { + #(#ty_map)* + static METHODS: &'static [::pyo3::class::PyMethodDefType] = &[ + #(#methods),* + ]; + METHODS + } + } + } + }); + + // Merge everything. + quote! { #(#impls)* } } diff --git a/pyo3-derive-backend/src/py_method.rs b/pyo3-derive-backend/src/py_method.rs index 1e641405d49..c70447b537d 100644 --- a/pyo3-derive-backend/src/py_method.rs +++ b/pyo3-derive-backend/src/py_method.rs @@ -10,23 +10,80 @@ use utils; pub fn gen_py_method<'a>( cls: &syn::Type, name: &syn::Ident, - sig: &mut syn::MethodSig, - meth_attrs: &mut Vec, + sig: &syn::MethodSig, + meth_attrs: &Vec, + spec: &FnSpec, ) -> TokenStream { check_generic(name, sig); let doc = utils::get_doc(&meth_attrs, true); - let spec = FnSpec::parse(name, sig, meth_attrs); + + macro_rules! make_py_method_def { + ($def_type:ident, $meth_type:ident, $flags:expr, $wrapper:expr $(,)*) => {{ + let wrapper = $wrapper; + quote! { + ::pyo3::class::PyMethodDefType::$def_type({ + #wrapper + + ::pyo3::class::PyMethodDef { + ml_name: stringify!(#name), + ml_meth: ::pyo3::class::PyMethodType::$meth_type(__wrap), + ml_flags: $flags, + ml_doc: #doc, + } + }) + } + }}; + } match spec.tp { - FnType::Fn => impl_py_method_def(name, doc, &spec, &impl_wrap(cls, name, &spec, true)), - FnType::FnNew => impl_py_method_def_new(name, doc, &impl_wrap_new(cls, name, &spec)), - FnType::FnInit => impl_py_method_def_init(name, doc, &impl_wrap_init(cls, name, &spec)), - FnType::FnCall => impl_py_method_def_call(name, doc, &impl_wrap(cls, name, &spec, false)), - FnType::FnClass => impl_py_method_def_class(name, doc, &impl_wrap_class(cls, name, &spec)), - FnType::FnStatic => { - impl_py_method_def_static(name, doc, &impl_wrap_static(cls, name, &spec)) + FnType::Fn => { + if spec.args.is_empty() { + make_py_method_def!( + Method, + PyNoArgsFunction, + ::pyo3::ffi::METH_NOARGS, + &impl_wrap(cls, name, &spec, true), + ) + } else { + make_py_method_def!( + Method, + PyCFunctionWithKeywords, + ::pyo3::ffi::METH_VARARGS | ::pyo3::ffi::METH_KEYWORDS, + &impl_wrap(cls, name, &spec, true), + ) + } } + FnType::FnNew => make_py_method_def!( + New, + PyNewFunc, + ::pyo3::ffi::METH_VARARGS | ::pyo3::ffi::METH_KEYWORDS, + &impl_wrap_new(cls, name, &spec), + ), + FnType::FnInit => make_py_method_def!( + Init, + PyInitFunc, + ::pyo3::ffi::METH_VARARGS | ::pyo3::ffi::METH_KEYWORDS, + &impl_wrap_init(cls, name, &spec), + ), + FnType::FnCall => make_py_method_def!( + Call, + PyCFunctionWithKeywords, + ::pyo3::ffi::METH_VARARGS | ::pyo3::ffi::METH_KEYWORDS, + &impl_wrap(cls, name, &spec, false), + ), + FnType::FnClass => make_py_method_def!( + Class, + PyCFunctionWithKeywords, + ::pyo3::ffi::METH_VARARGS | ::pyo3::ffi::METH_KEYWORDS | ::pyo3::ffi::METH_CLASS, + &impl_wrap_class(cls, name, &spec), + ), + FnType::FnStatic => make_py_method_def!( + Static, + PyCFunctionWithKeywords, + ::pyo3::ffi::METH_VARARGS | ::pyo3::ffi::METH_KEYWORDS | ::pyo3::ffi::METH_STATIC, + &impl_wrap_static(cls, name, &spec), + ), FnType::Getter(ref getter) => { impl_py_getter_def(name, doc, getter, &impl_wrap_getter(cls, name)) } @@ -81,7 +138,8 @@ pub fn impl_wrap(cls: &syn::Type, name: &syn::Ident, spec: &FnSpec, noargs: bool unsafe extern "C" fn __wrap( _slf: *mut ::pyo3::ffi::PyObject, _args: *mut ::pyo3::ffi::PyObject, - _kwargs: *mut ::pyo3::ffi::PyObject) -> *mut ::pyo3::ffi::PyObject + _kwargs: *mut ::pyo3::ffi::PyObject, + ) -> *mut ::pyo3::ffi::PyObject { const _LOCATION: &'static str = concat!( stringify!(#cls), ".", stringify!(#name), "()"); @@ -334,7 +392,11 @@ pub(crate) fn impl_wrap_getter(cls: &syn::Type, name: &syn::Ident) -> TokenStrea /// Generate functiona wrapper (PyCFunction, PyCFunctionWithKeywords) pub(crate) fn impl_wrap_setter(cls: &syn::Type, name: &syn::Ident, spec: &FnSpec) -> TokenStream { if spec.args.len() < 1 { - println!("Not enough arguments for setter {}::{}", quote!{#cls}, name); + println!( + "Not enough arguments for setter {}::{}", + quote! {#cls}, + name + ); } let val_ty = spec.args[0].ty; @@ -555,137 +617,6 @@ fn impl_arg_param(arg: &FnArg, spec: &FnSpec, body: &TokenStream, idx: usize) -> } } -pub fn impl_py_method_def( - name: &syn::Ident, - doc: syn::Lit, - spec: &FnSpec, - wrapper: &TokenStream, -) -> TokenStream { - if spec.args.is_empty() { - quote! { - ::pyo3::class::PyMethodDefType::Method({ - #wrapper - - ::pyo3::class::PyMethodDef { - ml_name: stringify!(#name), - ml_meth: ::pyo3::class::PyMethodType::PyNoArgsFunction(__wrap), - ml_flags: ::pyo3::ffi::METH_NOARGS, - ml_doc: #doc, - } - }) - } - } else { - quote! { - ::pyo3::class::PyMethodDefType::Method({ - #wrapper - - ::pyo3::class::PyMethodDef { - ml_name: stringify!(#name), - ml_meth: ::pyo3::class::PyMethodType::PyCFunctionWithKeywords(__wrap), - ml_flags: ::pyo3::ffi::METH_VARARGS | ::pyo3::ffi::METH_KEYWORDS, - ml_doc: #doc, - } - }) - } - } -} - -pub fn impl_py_method_def_new( - name: &syn::Ident, - doc: syn::Lit, - wrapper: &TokenStream, -) -> TokenStream { - quote! { - ::pyo3::class::PyMethodDefType::New({ - #wrapper - - ::pyo3::class::PyMethodDef { - ml_name: stringify!(#name), - ml_meth: ::pyo3::class::PyMethodType::PyNewFunc(__wrap), - ml_flags: ::pyo3::ffi::METH_VARARGS | ::pyo3::ffi::METH_KEYWORDS, - ml_doc: #doc, - } - }) - } -} - -pub fn impl_py_method_def_init( - name: &syn::Ident, - doc: syn::Lit, - wrapper: &TokenStream, -) -> TokenStream { - quote! { - ::pyo3::class::PyMethodDefType::Init({ - #wrapper - - ::pyo3::class::PyMethodDef { - ml_name: stringify!(#name), - ml_meth: ::pyo3::class::PyMethodType::PyInitFunc(__wrap), - ml_flags: ::pyo3::ffi::METH_VARARGS | ::pyo3::ffi::METH_KEYWORDS, - ml_doc: #doc, - } - }) - } -} - -pub fn impl_py_method_def_class( - name: &syn::Ident, - doc: syn::Lit, - wrapper: &TokenStream, -) -> TokenStream { - quote! { - ::pyo3::class::PyMethodDefType::Class({ - #wrapper - - ::pyo3::class::PyMethodDef { - ml_name: stringify!(#name), - ml_meth: ::pyo3::class::PyMethodType::PyCFunctionWithKeywords(__wrap), - ml_flags: ::pyo3::ffi::METH_VARARGS | ::pyo3::ffi::METH_KEYWORDS | - ::pyo3::ffi::METH_CLASS, - ml_doc: #doc, - } - }) - } -} - -pub fn impl_py_method_def_static( - name: &syn::Ident, - doc: syn::Lit, - wrapper: &TokenStream, -) -> TokenStream { - quote! { - ::pyo3::class::PyMethodDefType::Static({ - #wrapper - - ::pyo3::class::PyMethodDef { - ml_name: stringify!(#name), - ml_meth: ::pyo3::class::PyMethodType::PyCFunctionWithKeywords(__wrap), - ml_flags: ::pyo3::ffi::METH_VARARGS | ::pyo3::ffi::METH_KEYWORDS | ::pyo3::ffi::METH_STATIC, - ml_doc: #doc, - } - }) - } -} - -pub fn impl_py_method_def_call( - name: &syn::Ident, - doc: syn::Lit, - wrapper: &TokenStream, -) -> TokenStream { - quote! { - ::pyo3::class::PyMethodDefType::Call({ - #wrapper - - ::pyo3::class::PyMethodDef { - ml_name: stringify!(#name), - ml_meth: ::pyo3::class::PyMethodType::PyCFunctionWithKeywords(__wrap), - ml_flags: ::pyo3::ffi::METH_VARARGS | ::pyo3::ffi::METH_KEYWORDS, - ml_doc: #doc, - } - }) - } -} - pub(crate) fn impl_py_setter_def( name: &syn::Ident, doc: syn::Lit, diff --git a/pyo3-derive-backend/src/utils.rs b/pyo3-derive-backend/src/utils.rs index b013a0b8a6d..31e50e07dd8 100644 --- a/pyo3-derive-backend/src/utils.rs +++ b/pyo3-derive-backend/src/utils.rs @@ -1,5 +1,7 @@ // Copyright (c) 2017-present PyO3 Project and Contributors use syn; +use syn::parse::Parser; +use syn::punctuated::Punctuated; use proc_macro2::TokenStream; @@ -7,6 +9,70 @@ pub fn print_err(msg: String, t: TokenStream) { println!("Error: {} in '{}'", msg, t.to_string()); } +/// Parse the macro arguments into a list of expressions. +pub fn parse_attrs(tokens: proc_macro::TokenStream) -> Vec { + let parser = Punctuated::::parse_terminated; + let error_message = "The macro attributes should be a list of comma separated expressions"; + + parser + .parse(tokens) + .expect(error_message) + .into_iter() + .collect() +} + +/// Parses variant attributes like `variants("MyTypeU32", "MyTypeF32")` into pairs +/// of names and type arguments. +pub fn parse_variants(call: &syn::ExprCall) -> Vec<(String, syn::AngleBracketedGenericArguments)> { + use syn::Expr::*; + + let path = match *call.func { + Path(ref expr_path) => expr_path, + _ => panic!("Unsupported argument syntax"), + }; + let path_segments = &path.path.segments; + + if path_segments.len() != 1 + || path_segments.first().unwrap().value().ident.to_string() != "variants" + { + panic!("Unsupported argument syntax"); + } + + call.args + .iter() + .map(|x| { + // Extract string argument. + let lit = match x { + Lit(syn::ExprLit { + lit: syn::Lit::Str(ref lit), + .. + }) => lit.value(), + _ => panic!("Unsupported argument syntax"), + }; + + // Parse string as type. + let ty: syn::Type = syn::parse_str(&lit).expect("Invalid type definition"); + + let path_segs = match ty { + syn::Type::Path(syn::TypePath { ref path, .. }) => path.segments.clone(), + _ => panic!("Unsupported type syntax"), + }; + + if path_segs.len() != 1 { + panic!("Type path is expected to have exactly one segment."); + } + + let seg = path_segs.iter().nth(0).unwrap(); + let args = match seg.arguments { + syn::PathArguments::AngleBracketed(ref args) => args.clone(), + _ => panic!("Expected angle bracketed type arguments"), + }; + + (seg.ident.to_string(), args) + }) + .collect() +} + // FIXME(althonos): not sure the docstring formatting is on par here. pub fn get_doc(attrs: &Vec, null_terminated: bool) -> syn::Lit { let mut doc = Vec::new(); diff --git a/pyo3cls/src/lib.rs b/pyo3cls/src/lib.rs index 08b98f1892c..259fa0faeb3 100644 --- a/pyo3cls/src/lib.rs +++ b/pyo3cls/src/lib.rs @@ -9,13 +9,10 @@ extern crate proc_macro2; extern crate pyo3_derive_backend; #[macro_use] extern crate quote; -#[macro_use] extern crate syn; use proc_macro2::Span; use pyo3_derive_backend::*; -use syn::parse::Parser; -use syn::punctuated::Punctuated; #[proc_macro_attribute] pub fn pymodule2( @@ -100,15 +97,7 @@ pub fn pyclass( // Parse the token stream into a syntax tree let mut ast: syn::ItemStruct = syn::parse(input).expect("#[pyclass] must be used on a `struct`"); - - // Parse the macro arguments into a list of expressions - let parser = Punctuated::::parse_terminated; - let error_message = "The macro attributes should be a list of comma separated expressions"; - let args = parser - .parse(attr) - .expect(error_message) - .into_iter() - .collect(); + let args = utils::parse_attrs(attr); // Build the output let expanded = py_class::build_py_class(&mut ast, &args); @@ -122,15 +111,16 @@ pub fn pyclass( #[proc_macro_attribute] pub fn pymethods( - _: proc_macro::TokenStream, + attr: proc_macro::TokenStream, input: proc_macro::TokenStream, ) -> proc_macro::TokenStream { // Parse the token stream into a syntax tree let mut ast: syn::ItemImpl = syn::parse(input.clone()).expect("#[pymethods] must be used on an `impl` block"); + let args = utils::parse_attrs(attr); // Build the output - let expanded = py_impl::build_py_methods(&mut ast); + let expanded = py_impl::build_py_methods(&mut ast, &args); quote!( #ast diff --git a/src/typeob.rs b/src/typeob.rs index 1af18a47379..d2c1955969f 100644 --- a/src/typeob.rs +++ b/src/typeob.rs @@ -54,7 +54,7 @@ pub trait PyTypeInfo { } } -/// type object supports python GC +/// Type object supports python GC pub const PY_TYPE_FLAG_GC: usize = 1; /// Type object supports python weak references diff --git a/tests/test_class_basics.rs b/tests/test_class_basics.rs index 70a4e76a0e5..c494da7b631 100644 --- a/tests/test_class_basics.rs +++ b/tests/test_class_basics.rs @@ -75,4 +75,4 @@ fn empty_class_in_module() { .unwrap(), "test_module.nested" ); -} +} \ No newline at end of file diff --git a/tests/test_class_generic.rs b/tests/test_class_generic.rs new file mode 100644 index 00000000000..5e8ed0201c5 --- /dev/null +++ b/tests/test_class_generic.rs @@ -0,0 +1,125 @@ +#![feature(specialization)] + +extern crate pyo3; +use pyo3::prelude::*; + +#[macro_use] +mod common; + +#[pyclass(variants("SimpleGenericU32", "SimpleGenericF32"))] +struct SimpleGeneric { + _foo: T, +} + +#[test] +fn generic_names() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let ty_u32 = py.get_type::>(); + py_assert!(py, ty_u32, "ty_u32.__name__ == 'SimpleGenericU32'"); + + let ty_f32 = py.get_type::>(); + py_assert!(py, ty_f32, "ty_f32.__name__ == 'SimpleGenericF32'"); +} + +#[test] +fn generic_type_eq() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let tup = ( + (SimpleGeneric { _foo: 1u32 }).into_object(py), + (SimpleGeneric { _foo: 1u32 }).into_object(py), + (SimpleGeneric { _foo: 1f32 }).into_object(py), + (SimpleGeneric { _foo: 1f32 }).into_object(py), + ); + + py_assert!(py, tup, "type(tup[0]) == type(tup[1])"); + py_assert!(py, tup, "type(tup[2]) == type(tup[3])"); + py_assert!(py, tup, "type(tup[0]) != type(tup[2])"); +} + +#[pyclass(variants("GenericSquarerU64", "GenericSquarerF64"))] +struct GenericSquarer +where + T: std::ops::Mul + Copy + 'static, +{ + val: T, +} + +#[pymethods(variants("GenericSquarerU64", "GenericSquarerF64"))] +impl GenericSquarer +where + T: std::ops::Mul + Copy + 'static, + // #[pyclass] only implements `PyTypeInfo` for the given variants, + // so we need this constraint below. + GenericSquarer: pyo3::typeob::PyTypeInfo, +{ + #[new] + fn __new__(obj: &PyRawObject, val: T) -> PyResult<()> { + obj.init(|| Self { val }) + } + + #[getter] + fn get_val(&self) -> PyResult { + Ok(self.val) + } + + #[setter] + fn set_val(&mut self, value: T) -> PyResult<()> { + self.val = value; + Ok(()) + } + + fn square(&self) -> PyResult { + Ok(self.val * self.val) + } +} + +#[test] +fn generic_squarer() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let u64_squarer = py.init(|| GenericSquarer { val: 42u64 }).unwrap(); + py_assert!( + py, + u64_squarer, + "type(u64_squarer).__name__ == 'GenericSquarerU64'" + ); + py_assert!(py, u64_squarer, "u64_squarer.square() == 42 ** 2"); + + let f64_squarer = py.init(|| GenericSquarer { val: 42f64 }).unwrap(); + py_assert!( + py, + f64_squarer, + "type(f64_squarer).__name__ == 'GenericSquarerF64'" + ); + py_assert!(py, f64_squarer, "f64_squarer.square() == 42. ** 2."); +} + +#[test] +fn generic_squarer_getter_setter() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let u64_squarer = py.init(|| GenericSquarer { val: 1337u64 }).unwrap(); + py_assert!(py, u64_squarer, "u64_squarer.val == 1337"); + py_run!(py, u64_squarer, "u64_squarer.val = 42"); + py_assert!(py, u64_squarer, "u64_squarer.val == 42"); + py_assert!(py, u64_squarer, "u64_squarer.square() == 42 ** 2"); +} + +#[test] +fn generic_squarer_pynew() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let u64_squarer_ty = py.get_type::>(); + py_assert!( + py, + u64_squarer_ty, + "u64_squarer_ty(111).square() == 111 ** 2" + ); +}