Skip to content

Commit eb79f34

Browse files
Add documentation for ColsRef macro
1 parent 12321d4 commit eb79f34

File tree

6 files changed

+533
-36
lines changed

6 files changed

+533
-36
lines changed

crates/circuits/primitives/derive/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ license.workspace = true
1212
proc-macro = true
1313

1414
[dependencies]
15-
syn = { version = "2.0", features = ["parsing"] }
15+
syn = { version = "2.0", features = ["parsing", "extra-traits"] }
1616
quote = "1.0"
1717
itertools = { workspace = true, default-features = true }
1818
proc-macro2 = "1.0"
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# ColsRef macro
2+
3+
The `ColsRef` procedural macro is used in constraint generation to create column structs that have dynamic sizes.
4+
5+
Note: this macro was originally created for use in the SHA-2 VM extension, where we reuse the same constraint generation code for three different circuits (SHA-256, SHA-512, and SHA-384).
6+
See the [SHA-2 VM extension](../../../../../../extensions/sha2/circuit/src/sha2_chip/air.rs) for an example of how to use the `ColsRef` macro to reuse constraint generation code over multiple circuits.
7+
8+
## Overview
9+
10+
As an illustrative example, consider the following columns struct:
11+
```rust
12+
struct ExampleCols<T, const N: usize> {
13+
arr: [T; N],
14+
sum: T,
15+
}
16+
```
17+
Let's say we want to constrain `sum` to be the sum of the elements of `arr`, and `N` can be either 5 or 10.
18+
We can define a trait that stores the config parameters.
19+
```rust
20+
pub trait ExampleConfig {
21+
const N: usize;
22+
}
23+
```
24+
and then implement it for the two different configs.
25+
```rust
26+
pub struct ExampleConfigImplA;
27+
impl ExampleConfig for ExampleConfigImplA {
28+
const N: usize = 5;
29+
}
30+
pub struct ExampleConfigImplB;
31+
impl ExampleConfig for ExampleConfigImplB {
32+
const N: usize = 10;
33+
}
34+
```
35+
Then we can use the `ColsRef` macro like this
36+
```rust
37+
#[derive(ColsRef)]
38+
#[config(ExampleConfig)]
39+
struct ExampleCols<T, const N: usize> {
40+
arr: [T; N],
41+
sum: T,
42+
}
43+
```
44+
which will generate a columns struct that uses references to the fields.
45+
```rust
46+
struct ExampleColsRef<'a, T, const N: usize> {
47+
arr: &'a [T; N],
48+
sum: &'a T,
49+
}
50+
```
51+
The `ColsRef` macro will also generate a `from` method that takes a slice of the correct length and returns an instance of the columns struct.
52+
The `from` method is parameterized by a struct that implements the `ExampleConfig` trait, and it uses the associated constants to determine how to split the input slice into the fields of the columns struct.
53+
54+
So, the constraint generation code can be written as
55+
```rust
56+
impl<AB: InteractionBuilder, C: ExampleConfig> Air<AB> for ExampleAir<C> {
57+
fn eval(&self, builder: &mut AB) {
58+
let main = builder.main();
59+
let (local, _) = (main.row_slice(0), main.row_slice(1));
60+
let local_cols = ExampleColsRef::<AB::Var>::from::<C>(&local[..C::N + 1]);
61+
let sum = local_cols.arr.iter().sum();
62+
builder.assert_eq(local_cols.sum, sum);
63+
}
64+
}
65+
```
66+
Notes:
67+
- the `arr` and `sum` fields of `ExampleColsRef` are references to the elements of the `local` slice.
68+
- the name, `N`, of the const generic parameter must match the name of the associated constant `N` in the `ExampleConfig` trait.
69+
70+
The `ColsRef` macro also generates a `ExampleColsRefMut` struct that stores mutable references to the fields, for use in trace generation.
71+
72+
The `ColsRef` macro supports more than just variable-length array fields.
73+
The field types can also be:
74+
- any type that derives `AlignedBorrow` via `#[derive(AlignedBorrow)]`
75+
- any type that derives `ColsRef` via `#[derive(ColsRef)]`
76+
- (possibly nested) arrays of `T` or (possibly nested) arrays of a type that derives `AlignedBorrow`
77+
78+
Note that we currently do not support arrays of types that derive `ColsRef`.
79+
80+
## Specification
81+
82+
Annotating a struct named `ExampleCols` with `#[derive(ColsRef)]` and `#[config(ExampleConfig)]` produces two structs, `ExampleColsRef` and `ExampleColsRefMut`.
83+
- we assume `ExampleCols` has exactly one generic type parameter, typically named `T`, and any number of const generic parameters. Each const generic parameter must have a name that matches an associated constant in the `ExampleConfig` trait
84+
85+
The fields of `ExampleColsRef` have the same names as the fields of `ExampleCols`, but their types are transformed as follows:
86+
- type `T` becomes `&T`
87+
- type `[T; LEN]` becomes `&ArrayView1<T>` (see [ndarray](https://docs.rs/ndarray/latest/ndarray/index.html)) where `LEN` is an associated constant in `ExampleConfig`
88+
- the `ExampleColsRef::from` method will correctly infer the length of the array from the config
89+
- fields with names that end in `Cols` are assumed to be a columns struct that derives `ColsRef` and are transformed into the appropriate `ColsRef` type recursively
90+
- one restriction is that any nested `ColsRef` type must have the same config as the outer `ColsRef` type
91+
- fields that are annotated with `#[aligned_borrow]` are assumed to derive `AlignedBorrow` and are borrowed from the input slice. The new type is a reference to the `AlignedBorrow` type
92+
- nested arrays of `U` become `&ArrayViewX<U>` where `X` is the number of dimensions in the nested array type
93+
- `U` can be either the generic type `T` or a type that derives `AlignedBorrow`. In the latter case, the field must be annotated with `#[aligned_borrow]`
94+
- the `ArrayViewX` type provides a `X`-dimensional view into the row slice
95+
96+
The fields of `ExampleColsRefMut` are almost the same as the fields of `ExampleColsRef`, but they are mutable references.
97+
- the `ArrayViewMutX` type is used instead of `ArrayViewX` for the array fields.
98+
- fields that derive `ColsRef` are transformed into the appropriate `ColsRefMut` type recursively.
99+
100+
Each of the `ExampleColsRef` and `ExampleColsRefMut` types has the following methods implemented:
101+
```rust
102+
// Takes a slice of the correct length and returns an instance of the columns struct.
103+
pub const fn from<C: ExampleConfig>(slice: &[T]) -> Self;
104+
// Returns the number of cells in the struct
105+
pub const fn width<C: ExampleConfig>() -> usize;
106+
```
107+
Note that the `width` method on both structs returns the same value.
108+
109+
Additionally, the `ExampleColsRef` struct has a `from_mut` method that takes a `ExampleColsRefMut` and returns a `ExampleColsRef`.
110+
This may be useful in trace generation to pass a `ExampleColsRefMut` to a function that expects a `ExampleColsRef`.
111+
112+
See the [tests](../../tests/test_cols_ref.rs) for concrete examples of how the `ColsRef` macro handles each of the supported field types.

crates/circuits/primitives/derive/src/cols_ref.rs renamed to crates/circuits/primitives/derive/src/cols_ref/mod.rs

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
extern crate proc_macro;
2+
13
use itertools::Itertools;
24
use quote::{format_ident, quote};
35
use syn::{parse_quote, DeriveInput};
@@ -36,16 +38,21 @@ pub fn cols_ref_impl(
3638

3739
match data {
3840
syn::Data::Struct(data_struct) => {
41+
// Process the fields of the struct, transforming the types for use in ColsRef struct
3942
let const_field_infos: Vec<FieldInfo> = data_struct
4043
.fields
4144
.iter()
4245
.map(|f| get_const_cols_ref_fields(f, generic_type, &const_generics))
4346
.collect::<Result<Vec<_>, String>>()
4447
.map_err(|e| format!("Failed to process fields. {}", e))?;
4548

49+
// The ColsRef struct is named by appending `Ref` to the struct name
4650
let const_cols_ref_name = syn::Ident::new(&format!("{}Ref", ident), ident.span());
51+
52+
// the args to the `from` method will be different for the ColsRef and ColsRefMut structs
4753
let from_args = quote! { slice: &'a [#generic_type] };
4854

55+
// Package all the necessary information to generate the ColsRef struct
4956
let struct_info = StructInfo {
5057
name: const_cols_ref_name,
5158
vis: vis.clone(),
@@ -56,20 +63,27 @@ pub fn cols_ref_impl(
5663
derive_clone: true,
5764
};
5865

66+
// Generate the ColsRef struct
5967
let const_cols_ref_struct = make_struct(struct_info.clone(), &config);
6068

69+
// Generate the `from_mut` method for the ColsRef struct
6170
let from_mut_impl = make_from_mut(struct_info, &config)?;
6271

72+
// Process the fields of the struct, transforming the types for use in ColsRefMut struct
6373
let mut_field_infos: Vec<FieldInfo> = data_struct
6474
.fields
6575
.iter()
6676
.map(|f| get_mut_cols_ref_fields(f, generic_type, &const_generics))
6777
.collect::<Result<Vec<_>, String>>()
6878
.map_err(|e| format!("Failed to process fields. {}", e))?;
6979

80+
// The ColsRefMut struct is named by appending `RefMut` to the struct name
7081
let mut_cols_ref_name = syn::Ident::new(&format!("{}RefMut", ident), ident.span());
82+
83+
// the args to the `from` method will be different for the ColsRef and ColsRefMut structs
7184
let from_args = quote! { slice: &'a mut [#generic_type] };
7285

86+
// Package all the necessary information to generate the ColsRefMut struct
7387
let struct_info = StructInfo {
7488
name: mut_cols_ref_name,
7589
vis,
@@ -80,6 +94,7 @@ pub fn cols_ref_impl(
8094
derive_clone: false,
8195
};
8296

97+
// Generate the ColsRefMut struct
8398
let mut_cols_ref_struct = make_struct(struct_info, &config);
8499

85100
Ok(quote! {
@@ -103,6 +118,12 @@ struct StructInfo {
103118
derive_clone: bool,
104119
}
105120

121+
// Generate the ColsRef and ColsRefMut structs, depending on the value of `struct_info`
122+
// This function is meant to reduce code duplication between the code needed to generate the two structs
123+
// Notable differences between the two structs are:
124+
// - the types of the fields
125+
// - ColsRef derives Clone, but ColsRefMut cannot (since it stores mutable references)
126+
// - the `from` method parameter is a reference to a slice for ColsRef and a mutable reference to a slice for ColsRefMut
106127
fn make_struct(struct_info: StructInfo, config: &proc_macro2::Ident) -> proc_macro2::TokenStream {
107128
let StructInfo {
108129
name,
@@ -147,7 +168,6 @@ fn make_struct(struct_info: StructInfo, config: &proc_macro2::Ident) -> proc_mac
147168
}
148169
}
149170

150-
// TODO: make this return the size in bytes (to support fields of constant size)
151171
// returns number of cells in the struct (where each cell has type T)
152172
pub const fn width<C: #config>() -> usize {
153173
0 #( + #length_exprs )*
@@ -156,6 +176,7 @@ fn make_struct(struct_info: StructInfo, config: &proc_macro2::Ident) -> proc_mac
156176
}
157177
}
158178

179+
// Generate the `from_mut` method for the ColsRef struct
159180
fn make_from_mut(
160181
struct_info: StructInfo,
161182
config: &proc_macro2::Ident,
@@ -183,21 +204,25 @@ fn make_from_mut(
183204
let is_array = matches!(f.ty, syn::Type::Array(_));
184205

185206
if is_array {
207+
// calling view() on ArrayViewMut returns an ArrayView
186208
Ok(quote! {
187209
other.#ident.view()
188210
})
189211
} else if derives_aligned_borrow {
212+
// implicitly converts a mutable reference to an immutable reference, so leave the field value unchanged
190213
Ok(quote! {
191214
other.#ident
192215
})
193216
} else if is_columns_struct(&f.ty) {
194217
// lifetime 'b is used in from_mut to allow more flexible lifetime of return value
195218
let cols_ref_type =
196219
get_const_cols_ref_type(&f.ty, &generic_type, parse_quote! { 'b });
220+
// Recursively call `from_mut` on the ColsRef field
197221
Ok(quote! {
198222
<#cols_ref_type>::from_mut::<C>(&other.#ident)
199223
})
200224
} else if is_generic_type(&f.ty, &generic_type) {
225+
// implicitly converts a mutable reference to an immutable reference, so leave the field value unchanged
201226
Ok(quote! {
202227
&other.#ident
203228
})
@@ -230,6 +255,8 @@ fn make_from_mut(
230255
})
231256
}
232257

258+
// Information about a field that is used to generate the ColsRef and ColsRefMut structs
259+
// See the `make_struct` function to see how this information is used
233260
#[derive(Debug, Clone)]
234261
struct FieldInfo {
235262
// type for struct definition
@@ -257,13 +284,8 @@ fn get_const_cols_ref_fields(
257284
.iter()
258285
.any(|attr| attr.path().is_ident("aligned_borrow"));
259286

260-
let has_plain_array_attribute = f.attrs.iter().any(|attr| attr.path().is_ident("array"));
261287
let is_array = matches!(f.ty, syn::Type::Array(_));
262288

263-
if has_plain_array_attribute && !is_array {
264-
panic!("field marked with `plain_array` attribute must be an array");
265-
}
266-
267289
if is_array {
268290
let ArrayInfo { dims, elem_type } = get_array_info(&f.ty, const_generics);
269291
debug_assert!(
@@ -286,33 +308,7 @@ fn get_const_cols_ref_fields(
286308
})
287309
.collect_vec();
288310

289-
if has_plain_array_attribute {
290-
Err("unsupported currently".to_string())
291-
/*
292-
debug_assert!(
293-
dims.len() == 1,
294-
"field marked with `plain_array` attribute must be a 1D array"
295-
);
296-
297-
let length_expr = quote! {
298-
1 #(* #dim_exprs)*
299-
};
300-
301-
Ok(FieldInfo {
302-
ty: parse_quote! {
303-
& #f.ty
304-
},
305-
length_expr: length_expr.clone(),
306-
prepare_subslice: quote! {
307-
let (#slice_var, slice) = slice.split_at(#length_expr);
308-
let #slice_var = #slice_var.try_into().unwrap();
309-
},
310-
initializer: quote! {
311-
#slice_var
312-
},
313-
})
314-
*/
315-
} else if derives_aligned_borrow {
311+
if derives_aligned_borrow {
316312
let length_expr = quote! {
317313
<#elem_type>::width() #(* #dim_exprs)*
318314
};
@@ -552,6 +548,8 @@ fn get_mut_cols_ref_fields(
552548
}
553549
}
554550

551+
// Helper functions
552+
555553
fn is_columns_struct(ty: &syn::Type) -> bool {
556554
if let syn::Type::Path(type_path) = ty {
557555
type_path
@@ -666,6 +664,8 @@ fn get_elem_type(ty: &syn::Type) -> syn::Type {
666664
}
667665
}
668666

667+
// Get a vector of the dimensions of the array
668+
// Each dimension is either a constant generic or a literal integer value
669669
fn get_dims(ty: &syn::Type, const_generics: &[&syn::Ident]) -> Vec<Dimension> {
670670
get_dims_impl(ty, const_generics)
671671
.into_iter()

crates/circuits/primitives/derive/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -402,8 +402,8 @@ pub fn bytes_stateful_derive(input: TokenStream) -> TokenStream {
402402
}
403403
}
404404

405-
#[proc_macro_derive(ColsRef, attributes(aligned_borrow, config, plain_array))]
406-
pub fn cols_ref(input: TokenStream) -> TokenStream {
405+
#[proc_macro_derive(ColsRef, attributes(aligned_borrow, config))]
406+
pub fn cols_ref_derive(input: TokenStream) -> TokenStream {
407407
let derive_input: DeriveInput = parse_macro_input!(input as DeriveInput);
408408

409409
let config = derive_input

0 commit comments

Comments
 (0)