Skip to content

Add derive EnumCommonFields #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Provides traits and "derives" for enum items in the Rust programming language:
- EnumIter
- EnumIterator
- EnumVariantName
- EnumCommonFields

### Traits ###
- Index
Expand Down
29 changes: 29 additions & 0 deletions lib/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,35 @@
//! Simple traits for builtin enum items.
//! Primarily used by `enum_traits_macros` when automatically deriving types.
//! The crate `enum_traits_macros` is required for the derives.
//!
//! ## EnumCommonFields
//!
//! Using `#[derive(EnumCommonFields)]` on enums containing exclusively struct variants
//! will implement member functions for all struct fields with the same name and type.
//!
//! ### Examples
//!
//! ```ignore
//! #[derive(Debug, EnumCommonFields)]
//! enum Enum {
//! Cat{age: u32},
//! Dog{age: u32},
//! Robot{age: u32},
//! }
//! assert_eq!(Enum::Dog{age: 3}.age(), &3);
//! ```
//!
//! ```ignore
//! #[derive(Debug, PartialEq, EnumCommonFields)]
//! enum Enum {
//! Cat{age: u32},
//! Dog{age: u32},
//! Robot{age: u32},
//! }
//! let mut d = Enum::Dog{age: 3};
//! *d.age_mut() = 5;
//! assert_eq!(d, Enum::Dog{age: 5});
//! ```

#![feature(associated_consts)]

Expand Down
95 changes: 95 additions & 0 deletions macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -631,3 +631,98 @@ pub fn derive_EnumVariantName(input: TokenStream) -> TokenStream {
}
derive_enum(input, gen_impl)
}

#[proc_macro_derive(EnumCommonFields)]
pub fn derive_EnumCommonFields(input: TokenStream) -> TokenStream {
fn gen_impl(ident: &Ident, item: &MacroInput, data: &Vec<Variant>) -> Tokens {
let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl();

let mut field_list: Vec<(Ident, syn::Ty)> = Vec::new();
// collect all struct fields
for variant in data {
match variant.data {
VariantData::Unit |
VariantData::Tuple(_) => {
panic!("`derive(EnumCommonFields)` may only be applied to enum items with struct variants")
}
VariantData::Struct(ref fields) => {
for field in fields.iter().filter(|f| f.ident.is_some()) {
let value_name = field.ident.as_ref().unwrap().to_owned();
let value_type = field.ty.clone();
if !field_list.iter().any(|f| &f.0 == &value_name) {
field_list.push((value_name, value_type));
}
}
}
}
}

// remove struct fields, that are not available in all variants
if !field_list.is_empty() {
for variant in data {
match variant.data {
VariantData::Unit |
VariantData::Tuple(_) => {
unreachable!()
}
VariantData::Struct(ref fields) => {
let mut local_field_list = Vec::new();
for field in fields.iter().filter(|f| f.ident.is_some()) {
let value_name = field.ident.as_ref().unwrap().to_owned();
let value_type = field.ty.clone();
local_field_list.push((value_name, value_type));
}
field_list.retain(|f| local_field_list.iter().any(|lf| f == lf));
}
}
}
}

if field_list.is_empty() {
panic!("`derive(EnumCommonFields)` may only be applied to enum items that share at least one common struct field")
}

// create functions
let functions = field_list.iter().map(|&(ref value_name, ref value_type)| {
let match_arms = data.iter().map(|variant| {
let variant_ident = &variant.ident;
quote! { #ident::#variant_ident{ref #value_name, ..} => #value_name, }
});

quote!{
pub fn #value_name(&self) -> &#value_type{
match *self {
#( #match_arms )*
}
}
}
});

let functions_mut = field_list.iter().map(|&(ref value_name, ref value_type)| {
let match_arms = data.iter().map(|variant| {
let variant_ident = &variant.ident;
quote! { #ident::#variant_ident{ref mut #value_name, ..} => #value_name, }
});

let value_name_mut = quote::Ident::from(format!("{}_mut", value_name));

quote!{
pub fn #value_name_mut(&mut self) -> &mut #value_type{
match *self {
#( #match_arms )*
}
}
}
});

quote!{
#[automatically_derived]
#[allow(unused)]
impl #impl_generics #ident #ty_generics #where_clause{
#( #functions )*
#( #functions_mut )*
}
}
}
derive_enum(input, gen_impl)
}
67 changes: 67 additions & 0 deletions tests/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,73 @@ fn variant_name() {
assert_eq!(Enum::Robot{speed: 0.0}.variant_name(), "Robot");
}

#[test]
#[allow(dead_code)]
fn common_fields() {
{
#[derive(Debug,EnumCommonFields)]
enum Enum {
Dog{age: u32},
Cat{age: u32},
Robot{age: u32},
}
assert_eq!(Enum::Dog{age: 3}.age(), &3);
}
{
#[derive(Debug,EnumCommonFields)]
enum Enum {
Dog{age: u32},
Cat{age: u32},
Robot{age: u32, speed: f32},
}
assert_eq!(Enum::Dog{age: 3}.age(), &3);
assert_eq!(Enum::Robot{age: 7, speed: 90.0}.age(), &7);
}
{
#[derive(Debug,EnumCommonFields)]
enum Enum {
Dog{age: u32, speed: f32},
Cat{age: u32, speed: f32},
Robot{age: u32, speed: f32},
}
assert_eq!(Enum::Dog{age: 3, speed: 20.0}.age(), &3);
assert_eq!(Enum::Robot{age: 7, speed: 90.0}.speed(), &90.0);
}
{
#[derive(Debug,EnumCommonFields)]
enum Enum {
Dog{age: u32},
Cat{age: u32},
Robot{age: u32},
}

impl Enum {
pub fn zero(&mut self) -> usize {
0
}
}

assert_eq!(Enum::Dog{age: 3}.age(), &3);
assert_eq!(Enum::Dog{age: 3}.zero(), 0);
}
}

#[test]
#[allow(dead_code)]
fn common_fields_mut() {
{
#[derive(Debug,PartialEq,EnumCommonFields)]
enum Enum {
Dog{age: u32},
Cat{age: u32},
Robot{age: u32},
}
let mut d = Enum::Dog{age: 3};
*d.age_mut() = 5;
assert_eq!(d, Enum::Dog{age: 5});
}
}

#[test]
#[allow(dead_code)]
fn f1(){
Expand Down