Skip to content

Commit 316b643

Browse files
Add documentation for ColsRef macro
1 parent 12321d4 commit 316b643

File tree

6 files changed

+503
-36
lines changed

6 files changed

+503
-36
lines changed

crates/circuits/primitives/derive/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ license.workspace = true
1212
proc-macro = true
1313

1414
[dependencies]
15-
syn = { version = "2.0", features = ["parsing"] }
15+
syn = { version = "2.0", features = ["parsing", "extra-traits"] }
1616
quote = "1.0"
1717
itertools = { workspace = true, default-features = true }
1818
proc-macro2 = "1.0"
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# ColsRef macro
2+
3+
The `ColsRef` procedural macro is used in constraint generation to create column structs that have dynamic sizes.
4+
5+
Note: this macro was originally created for use in the SHA-2 VM extension, where we reuse the same constraint generation code for three different circuits (SHA-256, SHA-512, and SHA-384).
6+
See the [SHA-2 VM extension](../../../../../../extensions/sha2/circuit/src/sha2_chip/air.rs) for an example of how to use the `ColsRef` macro to reuse constraint generation code over multiple circuits.
7+
8+
## Overview
9+
10+
As an illustrative example, consider the following columns struct:
11+
```rust
12+
struct ExampleCols<T, const N: usize> {
13+
arr: [T; N],
14+
sum: T,
15+
}
16+
```
17+
Let's say we want to constrain `sum` to be the sum of the elements of `arr`, and `N` can be either 5 or 10.
18+
We can define a trait that stores the config parameters.
19+
```rust
20+
pub trait ExampleConfig {
21+
const N: usize;
22+
}
23+
```
24+
and then implement it for the two different configs.
25+
```rust
26+
pub struct ExampleConfigImplA;
27+
impl ExampleConfig for ExampleConfigImplA {
28+
const N: usize = 5;
29+
}
30+
pub struct ExampleConfigImplB;
31+
impl ExampleConfig for ExampleConfigImplB {
32+
const N: usize = 10;
33+
}
34+
```
35+
Then we can use the `ColsRef` macro like this
36+
```rust
37+
#[derive(ColsRef)]
38+
#[config(ExampleConfig)]
39+
struct ExampleCols<T, const N: usize> {
40+
arr: [T; N],
41+
sum: T,
42+
}
43+
```
44+
which will generate a columns struct that uses references to the fields.
45+
```rust
46+
struct ExampleColsRef<'a, T, const N: usize> {
47+
arr: &'a [T; N],
48+
sum: &'a T,
49+
}
50+
```
51+
The `ColsRef` macro will also generate a `from` method that takes a slice of the correct length and returns an instance of the columns struct.
52+
The `from` method is parameterized by a struct that implements the `ExampleConfig` trait, and it uses the associated constants to determine how to split the input slice into the fields of the columns struct.
53+
54+
So, the constraint generation code can be written as
55+
```rust
56+
impl<AB: InteractionBuilder, C: ExampleConfig> Air<AB> for ExampleAir<C> {
57+
fn eval(&self, builder: &mut AB) {
58+
let main = builder.main();
59+
let (local, _) = (main.row_slice(0), main.row_slice(1));
60+
let local_cols = ExampleColsRef::<AB::Var>::from::<C>(&local[..C::N + 1]);
61+
let sum = local_cols.arr.iter().sum();
62+
builder.assert_eq(local_cols.sum, sum);
63+
}
64+
}
65+
```
66+
Notes:
67+
- the `arr` and `sum` fields of `ExampleColsRef` are references to the elements of the `local` slice.
68+
- the name, `N`, of the const generic parameter must match the name of the associated constant `N` in the `ExampleConfig` trait.
69+
70+
The `ColsRef` macro also generates a `ExampleColsRefMut` struct that stores mutable references to the fields, for use in trace generation.
71+
72+
The `ColsRef` macro supports more than just variable-length array fields.
73+
The field types can also be:
74+
- any type that derives `AlignedBorrow` via `#[derive(AlignedBorrow)]`
75+
- any type that derives `ColsRef` via `#[derive(ColsRef)]`
76+
- (possibly nested) arrays of `T` or (possibly nested) arrays of a type that derives `AlignedBorrow`
77+
78+
Note that we currently do not support arrays of types that derive `ColsRef`.
79+
80+
## Specification
81+
82+
Annotating a struct named `ExampleCols` with `#[derive(ColsRef)]` and `#[config(ExampleConfig)]` produces two structs, `ExampleColsRef` and `ExampleColsRefMut`.
83+
- we assume `ExampleCols` has exactly one generic type parameter, typically named `T`, and any number of const generic parameters. Each const generic parameter must have a name that matches an associated constant in the `ExampleConfig` trait
84+
85+
The fields of `ExampleColsRef` have the same names as the fields of `ExampleCols`, but their types are transformed as follows:
86+
- type `T` becomes `&T`
87+
- type `[T; LEN]` becomes `&ArrayView1<T>` (see [ndarray](https://docs.rs/ndarray/latest/ndarray/index.html)) where `LEN` is an associated constant in `ExampleConfig`
88+
- the `ExampleColsRef::from` method will correctly infer the length of the array from the config
89+
- fields with names that end in `Cols` are assumed to be a columns struct that derives `ColsRef` and are transformed into the appropriate `ColsRef` type recursively
90+
- one restriction is that any nested `ColsRef` type must have the same config as the outer `ColsRef` type
91+
- fields that are annotated with `#[aligned_borrow]` are assumed to derive `AlignedBorrow` and are borrowed from the input slice. The new type is a reference to the `AlignedBorrow` type
92+
- nested arrays of `U` become `&ArrayViewX<U>` where `X` is the number of dimensions in the nested array type
93+
- `U` can be either the generic type `T` or a type that derives `AlignedBorrow`. In the latter case, the field must be annotated with `#[aligned_borrow]`
94+
- the `ArrayViewX` type provides a `X`-dimensional view into the row slice
95+
96+
The fields of `ExampleColsRefMut` are almost the same as the fields of `ExampleColsRef`, but they are mutable references.
97+
- the `ArrayViewMutX` type is used instead of `ArrayViewX` for the array fields.
98+
- fields that derive `ColsRef` are transformed into the appropriate `ColsRefMut` type recursively.
99+
100+
Each of the `ExampleColsRef` and `ExampleColsRefMut` types has the following methods implemented:
101+
```rust
102+
// Takes a slice of the correct length and returns an instance of the columns struct.
103+
pub const fn from<C: ExampleConfig>(slice: &[T]) -> Self;
104+
// Returns the number of cells in the struct
105+
pub const fn width<C: ExampleConfig>() -> usize;
106+
```
107+
Note that the `width` method on both structs returns the same value.
108+
109+
Additionally, the `ExampleColsRef` struct has a `from_mut` method that takes a `ExampleColsRefMut` and returns a `ExampleColsRef`.
110+
This may be useful in trace generation to pass a `ExampleColsRefMut` to a function that expects a `ExampleColsRef`.
111+
112+
See the [tests](../../tests/test_cols_ref.rs) for concrete examples of how the `ColsRef` macro handles each of the supported field types.

crates/circuits/primitives/derive/src/cols_ref.rs renamed to crates/circuits/primitives/derive/src/cols_ref/mod.rs

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
extern crate proc_macro;
2+
13
use itertools::Itertools;
24
use quote::{format_ident, quote};
35
use syn::{parse_quote, DeriveInput};
@@ -147,7 +149,6 @@ fn make_struct(struct_info: StructInfo, config: &proc_macro2::Ident) -> proc_mac
147149
}
148150
}
149151

150-
// TODO: make this return the size in bytes (to support fields of constant size)
151152
// returns number of cells in the struct (where each cell has type T)
152153
pub const fn width<C: #config>() -> usize {
153154
0 #( + #length_exprs )*
@@ -257,13 +258,8 @@ fn get_const_cols_ref_fields(
257258
.iter()
258259
.any(|attr| attr.path().is_ident("aligned_borrow"));
259260

260-
let has_plain_array_attribute = f.attrs.iter().any(|attr| attr.path().is_ident("array"));
261261
let is_array = matches!(f.ty, syn::Type::Array(_));
262262

263-
if has_plain_array_attribute && !is_array {
264-
panic!("field marked with `plain_array` attribute must be an array");
265-
}
266-
267263
if is_array {
268264
let ArrayInfo { dims, elem_type } = get_array_info(&f.ty, const_generics);
269265
debug_assert!(
@@ -286,33 +282,7 @@ fn get_const_cols_ref_fields(
286282
})
287283
.collect_vec();
288284

289-
if has_plain_array_attribute {
290-
Err("unsupported currently".to_string())
291-
/*
292-
debug_assert!(
293-
dims.len() == 1,
294-
"field marked with `plain_array` attribute must be a 1D array"
295-
);
296-
297-
let length_expr = quote! {
298-
1 #(* #dim_exprs)*
299-
};
300-
301-
Ok(FieldInfo {
302-
ty: parse_quote! {
303-
& #f.ty
304-
},
305-
length_expr: length_expr.clone(),
306-
prepare_subslice: quote! {
307-
let (#slice_var, slice) = slice.split_at(#length_expr);
308-
let #slice_var = #slice_var.try_into().unwrap();
309-
},
310-
initializer: quote! {
311-
#slice_var
312-
},
313-
})
314-
*/
315-
} else if derives_aligned_borrow {
285+
if derives_aligned_borrow {
316286
let length_expr = quote! {
317287
<#elem_type>::width() #(* #dim_exprs)*
318288
};

crates/circuits/primitives/derive/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -402,8 +402,8 @@ pub fn bytes_stateful_derive(input: TokenStream) -> TokenStream {
402402
}
403403
}
404404

405-
#[proc_macro_derive(ColsRef, attributes(aligned_borrow, config, plain_array))]
406-
pub fn cols_ref(input: TokenStream) -> TokenStream {
405+
#[proc_macro_derive(ColsRef, attributes(aligned_borrow, config))]
406+
pub fn cols_ref_derive(input: TokenStream) -> TokenStream {
407407
let derive_input: DeriveInput = parse_macro_input!(input as DeriveInput);
408408

409409
let config = derive_input
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
use openvm_circuit_primitives_derive::ColsRef;
2+
3+
pub trait ExampleConfig {
4+
const N: usize;
5+
}
6+
pub struct ExampleConfigImplA;
7+
impl ExampleConfig for ExampleConfigImplA {
8+
const N: usize = 5;
9+
}
10+
pub struct ExampleConfigImplB;
11+
impl ExampleConfig for ExampleConfigImplB {
12+
const N: usize = 10;
13+
}
14+
15+
#[allow(dead_code)]
16+
#[derive(ColsRef)]
17+
#[config(ExampleConfig)]
18+
struct ExampleCols<T, const N: usize> {
19+
arr: [T; N],
20+
sum: T,
21+
}
22+
23+
#[test]
24+
fn example() {
25+
let input = [1, 2, 3, 4, 5, 15];
26+
let test: ExampleColsRef<u32> = ExampleColsRef::from::<ExampleConfigImplA>(&input);
27+
println!("{}, {}", test.arr, test.sum);
28+
}
29+
30+
/*
31+
* For reference, this is what the ColsRef macro expands to.
32+
* The `cargo expand` tool is helpful for understanding how the ColsRef macro works.
33+
* See https://github.com/dtolnay/cargo-expand
34+
35+
#[derive(Debug, Clone)]
36+
struct ExampleColsRef<'a, T> {
37+
pub arr: ndarray::ArrayView1<'a, T>,
38+
pub sum: &'a T,
39+
}
40+
41+
impl<'a, T> ExampleColsRef<'a, T> {
42+
pub fn from<C: ExampleConfig>(slice: &'a [T]) -> Self {
43+
let (arr_slice, slice) = slice.split_at(1 * C::N);
44+
let arr_slice = ndarray::ArrayView1::from_shape((C::N), arr_slice).unwrap();
45+
let sum_length = 1;
46+
let (sum_slice, slice) = slice.split_at(sum_length);
47+
Self {
48+
arr: arr_slice,
49+
sum: &sum_slice[0],
50+
}
51+
}
52+
pub const fn width<C: ExampleConfig>() -> usize {
53+
0 + 1 * C::N + 1
54+
}
55+
}
56+
57+
impl<'b, T> ExampleColsRef<'b, T> {
58+
pub fn from_mut<'a, C: ExampleConfig>(other: &'b ExampleColsRefMut<'a, T>) -> Self {
59+
Self {
60+
arr: other.arr.view(),
61+
sum: &other.sum,
62+
}
63+
}
64+
}
65+
66+
#[derive(Debug)]
67+
struct ExampleColsRefMut<'a, T> {
68+
pub arr: ndarray::ArrayViewMut1<'a, T>,
69+
pub sum: &'a mut T,
70+
}
71+
72+
impl<'a, T> ExampleColsRefMut<'a, T> {
73+
pub fn from<C: ExampleConfig>(slice: &'a mut [T]) -> Self {
74+
let (arr_slice, slice) = slice.split_at_mut(1 * C::N);
75+
let arr_slice = ndarray::ArrayViewMut1::from_shape((C::N), arr_slice).unwrap();
76+
let sum_length = 1;
77+
let (sum_slice, slice) = slice.split_at_mut(sum_length);
78+
Self {
79+
arr: arr_slice,
80+
sum: &mut sum_slice[0],
81+
}
82+
}
83+
pub const fn width<C: ExampleConfig>() -> usize {
84+
0 + 1 * C::N + 1
85+
}
86+
}
87+
*/

0 commit comments

Comments
 (0)