Skip to content

Commit e1ef7ec

Browse files
authored
Merge pull request #433 from o1-labs/feature/joint-lookup-generalised
Use a single implementation for `combine_table_entry`
2 parents 2079e17 + cbdfc5a commit e1ef7ec

File tree

3 files changed

+76
-58
lines changed

3 files changed

+76
-58
lines changed

kimchi/src/circuits/expr.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1921,6 +1921,14 @@ impl<F: Field> From<u64> for Expr<ConstantExpr<F>> {
19211921
}
19221922
}
19231923

1924+
impl<F: Field> Mul<F> for Expr<ConstantExpr<F>> {
1925+
type Output = Expr<ConstantExpr<F>>;
1926+
1927+
fn mul(self, y: F) -> Self::Output {
1928+
Expr::Constant(ConstantExpr::Literal(y)) * self
1929+
}
1930+
}
1931+
19241932
//
19251933
// Display
19261934
//

kimchi/src/circuits/gate.rs

Lines changed: 56 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
33
use crate::circuits::{constraints::ConstraintSystem, domains::EvaluationDomains, wires::*};
44
use ark_ff::bytes::ToBytes;
5-
use ark_ff::{FftField, Field};
5+
use ark_ff::{FftField, Field, One, Zero};
66
use ark_poly::{Evaluations as E, Radix2EvaluationDomain as D};
77
use num_traits::cast::ToPrimitive;
88
use o1_utils::hasher::CryptoDigest;
99
use serde::{Deserialize, Serialize};
1010
use serde_with::serde_as;
1111
use std::collections::{hash_map::Entry, HashMap, HashSet};
1212
use std::io::{Result as IoResult, Write};
13+
use std::ops::Mul;
1314

1415
type Evaluations<Field> = E<Field, D<Field>>;
1516

@@ -68,40 +69,70 @@ pub struct SingleLookup<F> {
6869
/// analogously using `joint_combiner`.
6970
///
7071
/// This function computes that combined value.
71-
pub fn combine_table_entry<'a, F: Field, I: DoubleEndedIterator<Item = &'a F>>(
72-
joint_combiner: F,
73-
v: I,
74-
) -> F {
75-
v.rev().fold(F::zero(), |acc, x| joint_combiner * acc + x)
72+
pub fn combine_table_entry<'a, F, I>(joint_combiner: F, v: I) -> F
73+
where
74+
F: 'a, // Any references in `F` must have a lifetime longer than `'a`.
75+
F: Zero + One + Clone,
76+
I: DoubleEndedIterator<Item = &'a F>,
77+
{
78+
v.rev()
79+
.fold(F::zero(), |acc, x| joint_combiner.clone() * acc + x.clone())
7680
}
7781

78-
impl<F: Field> SingleLookup<F> {
82+
impl<F: Copy> SingleLookup<F> {
7983
/// Evaluate the linear combination specifying the lookup value to a field element.
80-
pub fn evaluate<G: Fn(LocalPosition) -> F>(&self, eval: G) -> F {
84+
pub fn evaluate<K, G: Fn(LocalPosition) -> K>(&self, eval: G) -> K
85+
where
86+
K: Zero,
87+
K: Mul<F, Output = K>,
88+
{
8189
self.value
8290
.iter()
83-
.fold(F::zero(), |acc, (c, p)| acc + *c * eval(*p))
91+
.fold(K::zero(), |acc, (c, p)| acc + eval(*p) * *c)
8492
}
8593
}
8694

8795
/// A spec for checking that the given vector belongs to a vector-valued lookup table.
8896
#[derive(Clone, Serialize, Deserialize)]
89-
pub struct JointLookup<F> {
90-
pub table_id: usize,
91-
pub entry: Vec<SingleLookup<F>>,
97+
pub struct JointLookup<SingleLookup> {
98+
pub table_id: i32,
99+
pub entry: Vec<SingleLookup>,
92100
}
93101

94-
impl<F: Field> JointLookup<F> {
102+
/// A spec for checking that the given vector belongs to a vector-valued lookup table, where the
103+
/// components of the vector are computed from a linear combination of locally-accessible cells.
104+
pub type JointLookupSpec<F> = JointLookup<SingleLookup<F>>;
105+
106+
impl<F: Zero + One + Clone> JointLookup<F> {
95107
// TODO: Support multiple tables
96108
/// Evaluate the combined value of a joint-lookup.
97-
pub fn evaluate<G: Fn(LocalPosition) -> F>(&self, joint_combiner: F, eval: &G) -> F {
98-
let mut res = F::zero();
99-
let mut c = F::one();
100-
for s in self.entry.iter() {
101-
res += c * s.evaluate(eval);
102-
c *= joint_combiner;
109+
pub fn evaluate(&self, joint_combiner: F) -> F {
110+
combine_table_entry(joint_combiner, self.entry.iter())
111+
}
112+
}
113+
114+
impl<F: Copy> JointLookup<SingleLookup<F>> {
115+
/// Reduce linear combinations in the lookup entries to a single value, resolving local
116+
/// positions using the given function.
117+
pub fn reduce<K, G: Fn(LocalPosition) -> K>(&self, eval: &G) -> JointLookup<K>
118+
where
119+
K: Zero,
120+
K: Mul<F, Output = K>,
121+
{
122+
JointLookup {
123+
table_id: self.table_id,
124+
entry: self.entry.iter().map(|s| s.evaluate(eval)).collect(),
103125
}
104-
res
126+
}
127+
128+
/// Evaluate the combined value of a joint-lookup, resolving local positions using the given
129+
/// function.
130+
pub fn evaluate<K, G: Fn(LocalPosition) -> K>(&self, joint_combiner: K, eval: &G) -> K
131+
where
132+
K: Zero + One + Clone,
133+
K: Mul<F, Output = K>,
134+
{
135+
self.reduce(eval).evaluate(joint_combiner)
105136
}
106137
}
107138

@@ -158,7 +189,7 @@ pub enum GateType {
158189
pub struct LookupInfo<F> {
159190
/// A single lookup constraint is a vector of lookup constraints to be applied at a row.
160191
/// This is a vector of all the kinds of lookup constraints in this configuration.
161-
pub kinds: Vec<Vec<JointLookup<F>>>,
192+
pub kinds: Vec<Vec<JointLookupSpec<F>>>,
162193
/// A map from the kind of gate (and whether it is the current row or next row) to the lookup
163194
/// constraint (given as an index into `kinds`) that should be applied there, if any.
164195
pub kinds_map: HashMap<(GateType, CurrOrNext), usize>,
@@ -170,10 +201,10 @@ pub struct LookupInfo<F> {
170201
/// The maximum joint size of any joint lookup in a constraint in `kinds`. This can be computed from `kinds`.
171202
pub max_joint_size: u32,
172203
/// An empty vector.
173-
empty: Vec<JointLookup<F>>,
204+
empty: Vec<JointLookupSpec<F>>,
174205
}
175206

176-
fn max_lookups_per_row<F>(kinds: &[Vec<JointLookup<F>>]) -> usize {
207+
fn max_lookups_per_row<F>(kinds: &[Vec<JointLookupSpec<F>>]) -> usize {
177208
kinds.iter().fold(0, |acc, x| std::cmp::max(x.len(), acc))
178209
}
179210

@@ -286,7 +317,7 @@ impl<F: FftField> LookupInfo<F> {
286317
}
287318

288319
/// For each row in the circuit, which lookup-constraints should be enforced at that row.
289-
pub fn by_row<'a>(&'a self, gates: &[CircuitGate<F>]) -> Vec<&'a Vec<JointLookup<F>>> {
320+
pub fn by_row<'a>(&'a self, gates: &[CircuitGate<F>]) -> Vec<&'a Vec<JointLookupSpec<F>>> {
290321
let mut kinds = vec![&self.empty; gates.len() + 1];
291322
for i in 0..gates.len() {
292323
let typ = gates[i].typ;
@@ -327,7 +358,7 @@ impl GateType {
327358
///
328359
/// See circuits/kimchi/src/polynomials/chacha.rs for an explanation of
329360
/// how these work.
330-
pub fn lookup_kinds<F: Field>() -> (Vec<Vec<JointLookup<F>>>, Vec<GatesLookupSpec>) {
361+
pub fn lookup_kinds<F: Field>() -> (Vec<Vec<JointLookupSpec<F>>>, Vec<GatesLookupSpec>) {
331362
let curr_row = |column| LocalPosition {
332363
row: CurrOrNext::Curr,
333364
column,

kimchi/src/circuits/polynomials/lookup.rs

Lines changed: 12 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,8 @@
122122
123123
use crate::{
124124
circuits::{
125-
expr::{prologue::*, Column, ConstantExpr, Variable},
126-
gate::{
127-
CircuitGate, CurrOrNext, JointLookup, LocalPosition, LookupInfo, LookupsUsed,
128-
SingleLookup,
129-
},
125+
expr::{prologue::*, Column, ConstantExpr},
126+
gate::{CircuitGate, CurrOrNext, JointLookupSpec, LocalPosition, LookupInfo, LookupsUsed},
130127
wires::COLUMNS,
131128
},
132129
error::ProofError,
@@ -142,29 +139,6 @@ use CurrOrNext::*;
142139
/// Number of constraints produced by the argument.
143140
pub const CONSTRAINTS: u32 = 7;
144141

145-
// TODO: Update for multiple tables
146-
fn single_lookup<F: FftField>(s: &SingleLookup<F>) -> E<F> {
147-
// Combine the linear combination.
148-
s.value
149-
.iter()
150-
.map(|(c, pos)| {
151-
E::literal(*c)
152-
* E::Cell(Variable {
153-
col: Column::Witness(pos.column),
154-
row: pos.row,
155-
})
156-
})
157-
.fold(E::zero(), |acc, e| acc + e)
158-
}
159-
160-
fn joint_lookup<F: FftField>(j: &JointLookup<F>) -> E<F> {
161-
j.entry
162-
.iter()
163-
.enumerate()
164-
.map(|(i, s)| E::constant(ConstantExpr::JointCombiner.pow(i as u64)) * single_lookup(s))
165-
.fold(E::zero(), |acc, x| acc + x)
166-
}
167-
168142
struct AdjacentPairs<A, I: Iterator<Item = A>> {
169143
prev_second_component: Option<A>,
170144
i: I,
@@ -356,7 +330,7 @@ pub trait Entry {
356330

357331
fn evaluate(
358332
p: &Self::Params,
359-
j: &JointLookup<Self::Field>,
333+
j: &JointLookupSpec<Self::Field>,
360334
witness: &[Vec<Self::Field>; COLUMNS],
361335
row: usize,
362336
) -> Self;
@@ -370,7 +344,7 @@ impl<F: Field> Entry for CombinedEntry<F> {
370344

371345
fn evaluate(
372346
joint_combiner: &F,
373-
j: &JointLookup<F>,
347+
j: &JointLookupSpec<F>,
374348
witness: &[Vec<F>; COLUMNS],
375349
row: usize,
376350
) -> CombinedEntry<F> {
@@ -395,7 +369,7 @@ impl<F: Field> Entry for UncombinedEntry<F> {
395369

396370
fn evaluate(
397371
_: &(),
398-
j: &JointLookup<F>,
372+
j: &JointLookupSpec<F>,
399373
witness: &[Vec<F>; COLUMNS],
400374
row: usize,
401375
) -> UncombinedEntry<F> {
@@ -661,16 +635,21 @@ pub fn constraints<F: FftField>(configuration: &LookupConfiguration<F>, d1: D<F>
661635
.collect()
662636
};
663637

638+
let eval = |pos: LocalPosition| witness(pos.column, pos.row);
639+
664640
// This is set up so that on rows that have lookups, chunk will be equal
665641
// to the product over all lookups `f` in that row of `gamma + f`
666642
// and
667643
// on non-lookup rows, will be equal to 1.
668-
let f_term = |spec: &Vec<_>| {
644+
let f_term = |spec: &Vec<JointLookupSpec<_>>| {
669645
assert!(spec.len() <= lookup_info.max_per_row);
670646
let padding = complements_with_beta_term[lookup_info.max_per_row - spec.len()].clone();
671647

672648
spec.iter()
673-
.map(|j| E::Constant(ConstantExpr::Gamma) + joint_lookup(j))
649+
.map(|j| {
650+
E::Constant(ConstantExpr::Gamma)
651+
+ j.evaluate(E::constant(ConstantExpr::JointCombiner), &eval)
652+
})
674653
.fold(E::Constant(padding), |acc: E<F>, x| acc * x)
675654
};
676655
let f_chunk = lookup_info

0 commit comments

Comments
 (0)