Skip to content

Support for type arguments in pyclass #303

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 129 additions & 55 deletions pyo3-derive-backend/src/py_class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,21 @@
use method::{FnArg, FnSpec, FnType};
use proc_macro2::{Span, TokenStream};
use py_method::{impl_py_getter_def, impl_py_setter_def, impl_wrap_getter, impl_wrap_setter};
use std::collections::HashMap;
use quote::ToTokens;
use syn;
use utils;

#[derive(Default, Debug)]
struct PyClassAttributes {
flags: Vec<syn::Expr>,
freelist: Option<syn::Expr>,
name: Option<syn::Expr>,
base: Option<syn::TypePath>,
variants: Option<Vec<(String, syn::AngleBracketedGenericArguments)>>,
}

pub fn build_py_class(class: &mut syn::ItemStruct, attr: &Vec<syn::Expr>) -> TokenStream {
let (params, flags, base) = parse_attribute(attr);
let attrs = parse_attribute(attr);
let doc = utils::get_doc(&class.attrs, true);
let mut descriptors = Vec::new();

Expand All @@ -23,7 +32,13 @@ pub fn build_py_class(class: &mut syn::ItemStruct, attr: &Vec<syn::Expr>) -> Tok
panic!("#[pyclass] can only be used with C-style structs")
}

impl_class(&class.ident, &base, doc, params, flags, descriptors)
impl_class(
&class.ident,
&attrs,
doc,
descriptors,
class.generics.clone(),
)
}

fn parse_descriptors(item: &mut syn::Field) -> Vec<FnType> {
Expand Down Expand Up @@ -62,21 +77,40 @@ fn parse_descriptors(item: &mut syn::Field) -> Vec<FnType> {

fn impl_class(
cls: &syn::Ident,
base: &syn::TypePath,
attrs: &PyClassAttributes,
doc: syn::Lit,
params: HashMap<&'static str, syn::Expr>,
flags: Vec<syn::Expr>,
descriptors: Vec<(syn::Field, Vec<FnType>)>,
mut generics: syn::Generics,
) -> TokenStream {
let cls_name = match params.get("name") {
Some(name) => quote! { #name }.to_string(),
let cls_name = match attrs.name {
Some(ref name) => quote! { #name }.to_string(),
None => quote! { #cls }.to_string(),
};

if attrs.variants.is_none() && generics.params.len() != 0 {
panic!(
"The `variants` parameter is required when using generic structs, \
e.g. `#[pyclass(variants(\"{}U32<u32>\", \"{}F32<f32>\"))]`.",
cls_name, cls_name,
);
}

// Split generics into pieces for impls using them.
generics.make_where_clause();
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let mut where_clause = where_clause.unwrap().clone();

// Insert `MyStruct<T>: PyTypeInfo` bound.
where_clause.predicates.push(parse_quote! {
#cls #ty_generics: ::pyo3::typeob::PyTypeInfo
});

let extra = {
if let Some(freelist) = params.get("freelist") {
if let Some(ref freelist) = attrs.freelist {
quote! {
impl ::pyo3::freelist::PyObjectWithFreeList for #cls {
impl #impl_generics ::pyo3::freelist::PyObjectWithFreeList
for #cls #ty_generics #where_clause
{
#[inline]
fn get_free_list() -> &'static mut ::pyo3::freelist::FreeList<*mut ::pyo3::ffi::PyObject> {
static mut FREELIST: *mut ::pyo3::freelist::FreeList<*mut ::pyo3::ffi::PyObject> = 0 as *mut _;
Expand All @@ -85,7 +119,7 @@ fn impl_class(
FREELIST = Box::into_raw(Box::new(
::pyo3::freelist::FreeList::with_capacity(#freelist)));

<#cls as ::pyo3::typeob::PyTypeCreate>::init_type();
<#cls #ty_generics as ::pyo3::typeob::PyTypeCreate>::init_type();
}
&mut *FREELIST
}
Expand All @@ -94,14 +128,20 @@ fn impl_class(
}
} else {
quote! {
impl ::pyo3::typeob::PyObjectAlloc for #cls {}
impl #impl_generics ::pyo3::typeob::PyObjectAlloc for #cls #ty_generics #where_clause {}
}
}
};

let extra = if !descriptors.is_empty() {
let ty = syn::parse_str(&cls.to_string()).expect("no name");
let desc_impls = impl_descriptors(&ty, descriptors);
let desc_impls = impl_descriptors(
&ty,
&impl_generics,
&ty_generics,
&where_clause,
descriptors,
);
quote! {
#desc_impls
#extra
Expand All @@ -113,7 +153,7 @@ fn impl_class(
// insert space for weak ref
let mut has_weakref = false;
let mut has_dict = false;
for f in flags.iter() {
for f in attrs.flags.iter() {
if let syn::Expr::Path(ref epath) = f {
if epath.path == parse_quote! {::pyo3::typeob::PY_TYPE_FLAG_WEAKREF} {
has_weakref = true;
Expand All @@ -133,25 +173,43 @@ fn impl_class(
quote! {0}
};

quote! {
impl ::pyo3::typeob::PyTypeInfo for #cls {
type Type = #cls;
// Create a variant of our generics with lifetime 'a prepended.
let mut gen_with_a = generics.clone();
gen_with_a.params.insert(0, parse_quote! { 'a });
let impl_with_a = gen_with_a.split_for_impl().0;

// Generate one PyTypeInfo per generic variant.
let variants= match attrs.variants {
Some(ref x) => x
.clone()
.into_iter()
.map(|(a, b)| (a, b.into_token_stream()))
.collect(),
None => vec![(cls_name, TokenStream::new())],
};

let base = &attrs.base;
let flags = &attrs.flags;
let type_info_impls: Vec<_> = variants.iter().map(|(name, for_ty)| quote! {
impl ::pyo3::typeob::PyTypeInfo for #cls #for_ty {
type Type = #cls #for_ty;
type BaseType = #base;

const NAME: &'static str = #cls_name;
const NAME: &'static str = #name;
const DESCRIPTION: &'static str = #doc;
const FLAGS: usize = #(#flags)|*;

const SIZE: usize = {
Self::OFFSET as usize +
::std::mem::size_of::<#cls>() + #weakref + #dict
::std::mem::size_of::<Self>() + #weakref + #dict
};
const OFFSET: isize = {
// round base_size up to next multiple of align
(
(<#base as ::pyo3::typeob::PyTypeInfo>::SIZE +
::std::mem::align_of::<#cls>() - 1) /
::std::mem::align_of::<#cls>() * ::std::mem::align_of::<#cls>()
::std::mem::align_of::<Self>() - 1) /
::std::mem::align_of::<Self>() *
::std::mem::align_of::<Self>()
) as isize
};

Expand All @@ -161,33 +219,38 @@ fn impl_class(
&mut TYPE_OBJECT
}
}
}).collect();

quote! {
#(#type_info_impls)*

// TBH I'm not sure what exactely this does and I'm sure there's a better way,
// but for now it works and it only safe code and it is required to return custom
// objects, so for now I'm keeping it
impl ::pyo3::IntoPyObject for #cls {
impl #impl_generics ::pyo3::IntoPyObject for #cls #ty_generics #where_clause {
fn into_object(self, py: ::pyo3::Python) -> ::pyo3::PyObject {
::pyo3::Py::new(py, || self).unwrap().into_object(py)
}
}

impl ::pyo3::ToPyObject for #cls {
impl #impl_generics ::pyo3::ToPyObject for #cls #ty_generics #where_clause {
fn to_object(&self, py: ::pyo3::Python) -> ::pyo3::PyObject {
use ::pyo3::python::ToPyPointer;
unsafe { ::pyo3::PyObject::from_borrowed_ptr(py, self.as_ptr()) }
}
}

impl ::pyo3::ToPyPointer for #cls {
impl #impl_generics ::pyo3::ToPyPointer for #cls #ty_generics #where_clause {
fn as_ptr(&self) -> *mut ::pyo3::ffi::PyObject {
unsafe {
{self as *const _ as *mut u8}
.offset(-<#cls as ::pyo3::typeob::PyTypeInfo>::OFFSET) as *mut ::pyo3::ffi::PyObject
(self as *const _ as *mut u8).offset(
-<Self as ::pyo3::typeob::PyTypeInfo>::OFFSET
) as *mut ::pyo3::ffi::PyObject
}
}
}

impl<'a> ::pyo3::ToPyObject for &'a mut #cls {
impl #impl_with_a ::pyo3::ToPyObject for &'a mut #cls #ty_generics #where_clause {
fn to_object(&self, py: ::pyo3::Python) -> ::pyo3::PyObject {
use ::pyo3::python::ToPyPointer;
unsafe { ::pyo3::PyObject::from_borrowed_ptr(py, self.as_ptr()) }
Expand All @@ -198,7 +261,13 @@ fn impl_class(
}
}

fn impl_descriptors(cls: &syn::Type, descriptors: Vec<(syn::Field, Vec<FnType>)>) -> TokenStream {
fn impl_descriptors(
cls: &syn::Type,
impl_generics: &syn::ImplGenerics,
ty_generics: &syn::TypeGenerics,
where_clause: &syn::WhereClause,
descriptors: Vec<(syn::Field, Vec<FnType>)>,
) -> TokenStream {
let methods: Vec<TokenStream> = descriptors
.iter()
.flat_map(|&(ref field, ref fns)| {
Expand All @@ -209,7 +278,7 @@ fn impl_descriptors(cls: &syn::Type, descriptors: Vec<(syn::Field, Vec<FnType>)>
match *desc {
FnType::Getter(_) => {
quote! {
impl #cls {
impl #impl_generics #cls #ty_generics #where_clause {
fn #name(&self) -> ::pyo3::PyResult<#field_ty> {
Ok(self.#name.clone())
}
Expand All @@ -220,7 +289,7 @@ fn impl_descriptors(cls: &syn::Type, descriptors: Vec<(syn::Field, Vec<FnType>)>
let setter_name =
syn::Ident::new(&format!("set_{}", name), Span::call_site());
quote! {
impl #cls {
impl #impl_generics #cls #ty_generics #where_clause {
fn #setter_name(&mut self, value: #field_ty) -> ::pyo3::PyResult<()> {
self.#name = value;
Ok(())
Expand Down Expand Up @@ -284,7 +353,9 @@ fn impl_descriptors(cls: &syn::Type, descriptors: Vec<(syn::Field, Vec<FnType>)>
quote! {
#(#methods)*

impl ::pyo3::class::methods::PyPropMethodsProtocolImpl for #cls {
impl #impl_generics ::pyo3::class::methods::PyPropMethodsProtocolImpl
for #cls #ty_generics #where_clause
{
fn py_methods() -> &'static [::pyo3::class::PyMethodDefType] {
static METHODS: &'static [::pyo3::class::PyMethodDefType] = &[
#(#py_methods),*
Expand All @@ -295,23 +366,21 @@ fn impl_descriptors(cls: &syn::Type, descriptors: Vec<(syn::Field, Vec<FnType>)>
}
}

fn parse_attribute(
args: &Vec<syn::Expr>,
) -> (
HashMap<&'static str, syn::Expr>,
Vec<syn::Expr>,
syn::TypePath,
) {
let mut params = HashMap::new();
// We need the 0 as value for the constant we're later building using quote for when there
// are no other flags
let mut flags = vec![parse_quote! {0}];
let mut base: syn::TypePath = parse_quote! {::pyo3::types::PyObjectRef};
fn parse_attribute(args: &Vec<syn::Expr>) -> PyClassAttributes {
use syn::Expr::*;

let mut attrs = PyClassAttributes {
// We need the 0 as value for the constant we're later building using
// quote for when there are no other flags
flags: vec![parse_quote! {0}],
base: Some(parse_quote! {::pyo3::types::PyObjectRef}),
..Default::default()
};

for expr in args.iter() {
match expr {
// Match a single flag
syn::Expr::Path(ref exp) if exp.path.segments.len() == 1 => {
Path(ref exp) if exp.path.segments.len() == 1 => {
let flag = exp.path.segments.first().unwrap().value().ident.to_string();
let path = match flag.as_str() {
"gc" => {
Expand All @@ -329,13 +398,13 @@ fn parse_attribute(
param => panic!("Unsupported parameter: {}", param),
};

flags.push(syn::Expr::Path(path));
attrs.flags.push(Path(path));
}

// Match a key/value flag
syn::Expr::Assign(ref ass) => {
Assign(ref ass) => {
let key = match *ass.left {
syn::Expr::Path(ref exp) if exp.path.segments.len() == 1 => {
Path(ref exp) if exp.path.segments.len() == 1 => {
exp.path.segments.first().unwrap().value().ident.to_string()
}
_ => panic!("could not parse argument: {:?}", ass),
Expand All @@ -344,20 +413,20 @@ fn parse_attribute(
match key.as_str() {
"freelist" => {
// TODO: check if int literal
params.insert("freelist", *ass.right.clone());
attrs.freelist = Some(*ass.right.clone());
}
"name" => match *ass.right {
syn::Expr::Path(ref exp) if exp.path.segments.len() == 1 => {
params.insert("name", exp.clone().into());
Path(ref exp) if exp.path.segments.len() == 1 => {
attrs.name = Some(exp.clone().into());
}
_ => panic!("Wrong 'name' format: {:?}", *ass.right),
},
"extends" => match *ass.right {
syn::Expr::Path(ref exp) => {
base = syn::TypePath {
Path(ref exp) => {
attrs.base = Some(syn::TypePath {
path: exp.path.clone(),
qself: None,
};
});
}
_ => panic!("Wrong 'base' format: {:?}", *ass.right),
},
Expand All @@ -367,9 +436,14 @@ fn parse_attribute(
}
}

_ => panic!("could not parse arguments"),
// Match variants (e.g. `variants("MyTypeU32<u32>", "MyTypeF32<f32>")`)
Call(ref call) => {
attrs.variants = Some(utils::parse_variants(call));
}

_ => panic!("Could not parse arguments"),
}
}

(params, flags, base)
attrs
}
Loading