Skip to content

Commit 76d89d9

Browse files
committed
add_newtype_macro
1 parent fd9b5df commit 76d89d9

File tree

6 files changed

+221
-16
lines changed

6 files changed

+221
-16
lines changed

bindgen-macros/src/lib.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
use proc_macro::TokenStream;
1010

11+
mod newtype;
1112
mod udf;
1213
mod udt;
1314

@@ -20,3 +21,8 @@ pub fn user_defined_type(attrs: TokenStream, item: TokenStream) -> TokenStream {
2021
pub fn scylla_bindgen(attrs: TokenStream, item: TokenStream) -> TokenStream {
2122
udf::scylla_bindgen(attrs, item)
2223
}
24+
25+
#[proc_macro_attribute]
26+
pub fn scylla_newtype(attrs: TokenStream, item: TokenStream) -> TokenStream {
27+
newtype::scylla_newtype(attrs, item)
28+
}

bindgen-macros/src/newtype.rs

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
use proc_macro::TokenStream;
2+
use proc_macro2::TokenStream as TokenStream2;
3+
use quote::quote;
4+
use syn::Fields;
5+
6+
struct NewtypeStruct {
7+
struct_name: syn::Ident,
8+
field_type: syn::Type,
9+
generics: syn::Generics,
10+
}
11+
fn get_newtype_struct(st: &syn::ItemStruct) -> Result<NewtypeStruct, TokenStream2> {
12+
let struct_name = &st.ident;
13+
let struct_fields = match &st.fields {
14+
Fields::Unnamed(named_fields) => named_fields,
15+
_ => {
16+
return Err(syn::Error::new_spanned(
17+
st,
18+
"#[scylla_newtype] error: struct has named fields.",
19+
)
20+
.to_compile_error());
21+
}
22+
};
23+
if struct_fields.unnamed.len() > 1 {
24+
return Err(syn::Error::new_spanned(
25+
st,
26+
"#[scylla_newtype] error: struct has more than 1 field.",
27+
)
28+
.to_compile_error());
29+
}
30+
let field_type = match struct_fields.unnamed.first() {
31+
Some(field) => &field.ty,
32+
None => {
33+
return Err(syn::Error::new_spanned(
34+
st,
35+
"#[scylla_newtype] error: struct has no fields.",
36+
)
37+
.to_compile_error());
38+
}
39+
};
40+
41+
Ok(NewtypeStruct {
42+
struct_name: struct_name.clone(),
43+
field_type: field_type.clone(),
44+
generics: st.generics.clone(),
45+
})
46+
}
47+
48+
fn impl_wasm_convertible(nst: &NewtypeStruct) -> TokenStream2 {
49+
let struct_name = &nst.struct_name;
50+
let struct_type = &nst.field_type;
51+
let (impl_generics, ty_generics, where_clause) = nst.generics.split_for_impl();
52+
quote! {
53+
impl #impl_generics ::scylla_bindgen::_macro_internal::WasmConvertible for #struct_name #ty_generics #where_clause {
54+
type WasmType = <#struct_type as ::scylla_bindgen::_macro_internal::WasmConvertible>::WasmType;
55+
unsafe fn from_wasmtype(arg: Self::WasmType) -> Self {
56+
#struct_name(<#struct_type as ::scylla_bindgen::_macro_internal::WasmConvertible>::from_wasmtype(arg))
57+
}
58+
fn as_wasmtype(&self) -> Self::WasmType {
59+
<#struct_type as ::scylla_bindgen::_macro_internal::WasmConvertible>::as_wasmtype(&self.0)
60+
}
61+
}
62+
}
63+
}
64+
65+
fn impl_to_col_type(nst: &NewtypeStruct) -> TokenStream2 {
66+
let struct_name = &nst.struct_name;
67+
let struct_type = &nst.field_type;
68+
let (impl_generics, ty_generics, where_clause) = nst.generics.split_for_impl();
69+
quote! {
70+
impl #impl_generics ::scylla_bindgen::_macro_internal::ToColumnType for #struct_name #ty_generics #where_clause {
71+
fn to_column_type() -> ::scylla_bindgen::_macro_internal::ColumnType {
72+
#struct_type::to_column_type()
73+
}
74+
}
75+
}
76+
}
77+
fn impl_value(nst: &NewtypeStruct) -> TokenStream2 {
78+
let struct_name = &nst.struct_name;
79+
let struct_type = &nst.field_type;
80+
let (impl_generics, ty_generics, where_clause) = nst.generics.split_for_impl();
81+
82+
quote! {
83+
impl #impl_generics ::scylla_bindgen::_macro_internal::Value for #struct_name #ty_generics #where_clause {
84+
fn serialize(&self, buf: &mut ::std::vec::Vec<::core::primitive::u8>) -> ::std::result::Result<(), ::scylla_bindgen::_macro_internal::ValueTooBig> {
85+
<#struct_type as ::scylla_bindgen::_macro_internal::Value>::serialize(&self.0, buf)
86+
}
87+
}
88+
}
89+
}
90+
91+
fn impl_from_cql_val(nst: &NewtypeStruct) -> TokenStream2 {
92+
let struct_name = &nst.struct_name;
93+
let struct_type = &nst.field_type;
94+
let (impl_generics, ty_generics, where_clause) = nst.generics.split_for_impl();
95+
96+
quote! {
97+
impl #impl_generics ::scylla_bindgen::_macro_internal::FromCqlVal<::scylla_bindgen::_macro_internal::CqlValue> for #struct_name #ty_generics #where_clause {
98+
fn from_cql(val: ::scylla_bindgen::_macro_internal::CqlValue) -> ::std::result::Result<Self, ::scylla_bindgen::_macro_internal::FromCqlValError> {
99+
<#struct_type as ::scylla_bindgen::_macro_internal::FromCqlVal<::scylla_bindgen::_macro_internal::CqlValue>>::from_cql(val).map(|v| #struct_name(v))
100+
}
101+
}
102+
}
103+
}
104+
105+
pub(crate) fn scylla_newtype(_attrs: TokenStream, item: TokenStream) -> TokenStream {
106+
let mut st = syn::parse_macro_input!(item as syn::ItemStruct);
107+
// We have to make the struct public, otherwise we can't use "<$struct_name as WasmConvertible>::WasmType"
108+
// in the exported function signature even if WasmType is public.
109+
st.vis = syn::Visibility::Public(syn::VisPublic {
110+
pub_token: syn::token::Pub(proc_macro2::Span::call_site()),
111+
});
112+
let newtype_struct = match get_newtype_struct(&st) {
113+
Ok(nst) => nst,
114+
Err(e) => return e.into(),
115+
};
116+
let wasm_convertible = impl_wasm_convertible(&newtype_struct);
117+
let to_col_type = impl_to_col_type(&newtype_struct);
118+
let value = impl_value(&newtype_struct);
119+
let from_cql_val = impl_from_cql_val(&newtype_struct);
120+
quote! {
121+
#st
122+
#wasm_convertible
123+
#to_col_type
124+
#value
125+
#from_cql_val
126+
}
127+
.into()
128+
}

bindgen/src/lib.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,26 @@ pub use bindgen_macros::scylla_bindgen;
5959
/// ```
6060
pub use bindgen_macros::user_defined_type;
6161

62+
/// The macro takes a "newtype" struct (tuple struct with only one field) and generates all implementations for traits used in the
63+
/// scylla_bindgen macro by treating the struct as the inner type itself.
64+
/// For example, for a function using a newtype struct:
65+
/// ```
66+
/// #[scylla_newtype]
67+
/// struct MyInt(i32);
68+
///
69+
/// #[scylla_bindgen]
70+
/// fn foo(arg: MyInt) -> MyInt {
71+
/// ...
72+
/// }
73+
/// ```
74+
/// and a table:
75+
/// ```
76+
/// CREATE TABLE table (x int PRIMARY KEY);
77+
/// ```
78+
/// you can use the function in a query:
79+
/// ```
80+
/// SELECT foo(x) FROM table;
81+
/// ```
82+
pub use bindgen_macros::scylla_newtype;
83+
6284
pub use scylla_cql::frame::value::{Counter, CqlDuration, Time, Timestamp};

examples/fib.rs

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1-
use scylla_bindgen::scylla_bindgen;
1+
use scylla_bindgen::*;
2+
3+
#[scylla_newtype]
4+
struct FibInputNumber(i32);
5+
6+
#[scylla_newtype]
7+
struct FibReturnNumber(i64);
8+
29
#[scylla_bindgen]
3-
fn fib(i: u32) -> u64 {
4-
if i <= 2 {
10+
fn fib(i: FibInputNumber) -> FibReturnNumber {
11+
FibReturnNumber(if i.0 <= 2 {
512
1
613
} else {
7-
fib(i - 1) + fib(i - 2)
8-
}
14+
fib(FibInputNumber(i.0 - 1)).0 + fib(FibInputNumber(i.0 - 2)).0
15+
})
916
}

examples/topn_reduce.rs

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,32 @@
11
use scylla_bindgen::scylla_bindgen;
22
use std::collections::BTreeSet;
3+
4+
#[scylla_newtype]
5+
struct StringLen(String);
6+
7+
impl std::cmp::PartialEq for StringLen {
8+
fn eq(&self, other: &Self) -> bool {
9+
self.0.len() == other.0.len()
10+
}
11+
}
12+
impl std::cmp::Eq for StringLen {}
13+
14+
impl std::cmp::PartialOrd for StringLen {
15+
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
16+
Some(self.0.len().cmp(&other.0.len()))
17+
}
18+
}
19+
impl std::cmp::Ord for StringLen {
20+
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
21+
self.0.len().cmp(&other.0.len())
22+
}
23+
}
24+
325
#[scylla_bindgen]
4-
fn top10_reduce(mut acc1: BTreeSet<i32>, mut acc2: BTreeSet<i32>) -> BTreeSet<i32> {
26+
fn topn_reduce((n1, mut acc1): (i32, BTreeSet<StringLen>), (_, mut acc1): (i32, BTreeSet<StringLen>)) -> (i32, BTreeSet<StringLen>) {
527
acc1.append(&mut acc2);
6-
while acc1.len() > 10 {
28+
while acc1.len() > n1 {
729
acc1.pop_first();
830
}
9-
acc1
31+
(n1, acc1)
1032
}

examples/topn_row.rs

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,34 @@
1-
use scylla_bindgen::scylla_bindgen;
1+
use scylla_bindgen::*;
22
use std::collections::BTreeSet;
3+
4+
#[scylla_newtype]
5+
struct StringLen(String);
6+
7+
impl std::cmp::PartialEq for StringLen {
8+
fn eq(&self, other: &Self) -> bool {
9+
self.0.len() == other.0.len()
10+
}
11+
}
12+
impl std::cmp::Eq for StringLen {}
13+
14+
impl std::cmp::PartialOrd for StringLen {
15+
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
16+
Some(self.0.len().cmp(&other.0.len()))
17+
}
18+
}
19+
impl std::cmp::Ord for StringLen {
20+
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
21+
self.0.len().cmp(&other.0.len())
22+
}
23+
}
24+
25+
// Store the top N strings by length, without repetitions.
326
#[scylla_bindgen]
4-
fn topn_row(
5-
(n, mut acc, mut idx): (i32, BTreeSet<(String, i32)>, i32),
6-
v: String,
7-
) -> (i32, BTreeSet<(String, i32)>, i32) {
8-
acc.insert((v, idx));
9-
idx += 1;
27+
fn topn_row((n, mut acc): (i32, BTreeSet<StringLen>), v: StringLen) -> (i32, BTreeSet<StringLen>) {
28+
acc.insert(v);
1029
while acc.len() > n as usize {
1130
acc.pop_first();
1231
}
13-
(n, acc, idx)
32+
33+
(n, acc)
1434
}

0 commit comments

Comments
 (0)