diff --git a/kimchi/src/circuits/expr.rs b/kimchi/src/circuits/expr.rs index 4bc6977555..b12b858a67 100644 --- a/kimchi/src/circuits/expr.rs +++ b/kimchi/src/circuits/expr.rs @@ -1921,6 +1921,14 @@ impl From for Expr> { } } +impl Mul for Expr> { + type Output = Expr>; + + fn mul(self, y: F) -> Self::Output { + Expr::Constant(ConstantExpr::Literal(y)) * self + } +} + // // Display // diff --git a/kimchi/src/circuits/gate.rs b/kimchi/src/circuits/gate.rs index aadfe84589..cfa2f756a5 100644 --- a/kimchi/src/circuits/gate.rs +++ b/kimchi/src/circuits/gate.rs @@ -2,7 +2,7 @@ use crate::circuits::{constraints::ConstraintSystem, domains::EvaluationDomains, wires::*}; use ark_ff::bytes::ToBytes; -use ark_ff::{FftField, Field}; +use ark_ff::{FftField, Field, One, Zero}; use ark_poly::{Evaluations as E, Radix2EvaluationDomain as D}; use num_traits::cast::ToPrimitive; use o1_utils::hasher::CryptoDigest; @@ -10,6 +10,7 @@ use serde::{Deserialize, Serialize}; use serde_with::serde_as; use std::collections::{hash_map::Entry, HashMap, HashSet}; use std::io::{Result as IoResult, Write}; +use std::ops::Mul; type Evaluations = E>; @@ -68,40 +69,70 @@ pub struct SingleLookup { /// analogously using `joint_combiner`. /// /// This function computes that combined value. -pub fn combine_table_entry<'a, F: Field, I: DoubleEndedIterator>( - joint_combiner: F, - v: I, -) -> F { - v.rev().fold(F::zero(), |acc, x| joint_combiner * acc + x) +pub fn combine_table_entry<'a, F, I>(joint_combiner: F, v: I) -> F +where + F: 'a, // Any references in `F` must have a lifetime longer than `'a`. + F: Zero + One + Clone, + I: DoubleEndedIterator, +{ + v.rev() + .fold(F::zero(), |acc, x| joint_combiner.clone() * acc + x.clone()) } -impl SingleLookup { +impl SingleLookup { /// Evaluate the linear combination specifying the lookup value to a field element. - pub fn evaluate F>(&self, eval: G) -> F { + pub fn evaluate K>(&self, eval: G) -> K + where + K: Zero, + K: Mul, + { self.value .iter() - .fold(F::zero(), |acc, (c, p)| acc + *c * eval(*p)) + .fold(K::zero(), |acc, (c, p)| acc + eval(*p) * *c) } } /// A spec for checking that the given vector belongs to a vector-valued lookup table. #[derive(Clone, Serialize, Deserialize)] -pub struct JointLookup { - pub table_id: usize, - pub entry: Vec>, +pub struct JointLookup { + pub table_id: i32, + pub entry: Vec, } -impl JointLookup { +/// A spec for checking that the given vector belongs to a vector-valued lookup table, where the +/// components of the vector are computed from a linear combination of locally-accessible cells. +pub type JointLookupSpec = JointLookup>; + +impl JointLookup { // TODO: Support multiple tables /// Evaluate the combined value of a joint-lookup. - pub fn evaluate F>(&self, joint_combiner: F, eval: &G) -> F { - let mut res = F::zero(); - let mut c = F::one(); - for s in self.entry.iter() { - res += c * s.evaluate(eval); - c *= joint_combiner; + pub fn evaluate(&self, joint_combiner: F) -> F { + combine_table_entry(joint_combiner, self.entry.iter()) + } +} + +impl JointLookup> { + /// Reduce linear combinations in the lookup entries to a single value, resolving local + /// positions using the given function. + pub fn reduce K>(&self, eval: &G) -> JointLookup + where + K: Zero, + K: Mul, + { + JointLookup { + table_id: self.table_id, + entry: self.entry.iter().map(|s| s.evaluate(eval)).collect(), } - res + } + + /// Evaluate the combined value of a joint-lookup, resolving local positions using the given + /// function. + pub fn evaluate K>(&self, joint_combiner: K, eval: &G) -> K + where + K: Zero + One + Clone, + K: Mul, + { + self.reduce(eval).evaluate(joint_combiner) } } @@ -158,7 +189,7 @@ pub enum GateType { pub struct LookupInfo { /// A single lookup constraint is a vector of lookup constraints to be applied at a row. /// This is a vector of all the kinds of lookup constraints in this configuration. - pub kinds: Vec>>, + pub kinds: Vec>>, /// A map from the kind of gate (and whether it is the current row or next row) to the lookup /// constraint (given as an index into `kinds`) that should be applied there, if any. pub kinds_map: HashMap<(GateType, CurrOrNext), usize>, @@ -170,10 +201,10 @@ pub struct LookupInfo { /// The maximum joint size of any joint lookup in a constraint in `kinds`. This can be computed from `kinds`. pub max_joint_size: u32, /// An empty vector. - empty: Vec>, + empty: Vec>, } -fn max_lookups_per_row(kinds: &[Vec>]) -> usize { +fn max_lookups_per_row(kinds: &[Vec>]) -> usize { kinds.iter().fold(0, |acc, x| std::cmp::max(x.len(), acc)) } @@ -286,7 +317,7 @@ impl LookupInfo { } /// For each row in the circuit, which lookup-constraints should be enforced at that row. - pub fn by_row<'a>(&'a self, gates: &[CircuitGate]) -> Vec<&'a Vec>> { + pub fn by_row<'a>(&'a self, gates: &[CircuitGate]) -> Vec<&'a Vec>> { let mut kinds = vec![&self.empty; gates.len() + 1]; for i in 0..gates.len() { let typ = gates[i].typ; @@ -327,7 +358,7 @@ impl GateType { /// /// See circuits/kimchi/src/polynomials/chacha.rs for an explanation of /// how these work. - pub fn lookup_kinds() -> (Vec>>, Vec) { + pub fn lookup_kinds() -> (Vec>>, Vec) { let curr_row = |column| LocalPosition { row: CurrOrNext::Curr, column, diff --git a/kimchi/src/circuits/polynomials/lookup.rs b/kimchi/src/circuits/polynomials/lookup.rs index b3aea32c0e..f10a328494 100644 --- a/kimchi/src/circuits/polynomials/lookup.rs +++ b/kimchi/src/circuits/polynomials/lookup.rs @@ -122,11 +122,8 @@ use crate::{ circuits::{ - expr::{prologue::*, Column, ConstantExpr, Variable}, - gate::{ - CircuitGate, CurrOrNext, JointLookup, LocalPosition, LookupInfo, LookupsUsed, - SingleLookup, - }, + expr::{prologue::*, Column, ConstantExpr}, + gate::{CircuitGate, CurrOrNext, JointLookupSpec, LocalPosition, LookupInfo, LookupsUsed}, wires::COLUMNS, }, error::ProofError, @@ -142,29 +139,6 @@ use CurrOrNext::*; /// Number of constraints produced by the argument. pub const CONSTRAINTS: u32 = 7; -// TODO: Update for multiple tables -fn single_lookup(s: &SingleLookup) -> E { - // Combine the linear combination. - s.value - .iter() - .map(|(c, pos)| { - E::literal(*c) - * E::Cell(Variable { - col: Column::Witness(pos.column), - row: pos.row, - }) - }) - .fold(E::zero(), |acc, e| acc + e) -} - -fn joint_lookup(j: &JointLookup) -> E { - j.entry - .iter() - .enumerate() - .map(|(i, s)| E::constant(ConstantExpr::JointCombiner.pow(i as u64)) * single_lookup(s)) - .fold(E::zero(), |acc, x| acc + x) -} - struct AdjacentPairs> { prev_second_component: Option, i: I, @@ -356,7 +330,7 @@ pub trait Entry { fn evaluate( p: &Self::Params, - j: &JointLookup, + j: &JointLookupSpec, witness: &[Vec; COLUMNS], row: usize, ) -> Self; @@ -370,7 +344,7 @@ impl Entry for CombinedEntry { fn evaluate( joint_combiner: &F, - j: &JointLookup, + j: &JointLookupSpec, witness: &[Vec; COLUMNS], row: usize, ) -> CombinedEntry { @@ -395,7 +369,7 @@ impl Entry for UncombinedEntry { fn evaluate( _: &(), - j: &JointLookup, + j: &JointLookupSpec, witness: &[Vec; COLUMNS], row: usize, ) -> UncombinedEntry { @@ -661,16 +635,21 @@ pub fn constraints(configuration: &LookupConfiguration, d1: D .collect() }; + let eval = |pos: LocalPosition| witness(pos.column, pos.row); + // This is set up so that on rows that have lookups, chunk will be equal // to the product over all lookups `f` in that row of `gamma + f` // and // on non-lookup rows, will be equal to 1. - let f_term = |spec: &Vec<_>| { + let f_term = |spec: &Vec>| { assert!(spec.len() <= lookup_info.max_per_row); let padding = complements_with_beta_term[lookup_info.max_per_row - spec.len()].clone(); spec.iter() - .map(|j| E::Constant(ConstantExpr::Gamma) + joint_lookup(j)) + .map(|j| { + E::Constant(ConstantExpr::Gamma) + + j.evaluate(E::constant(ConstantExpr::JointCombiner), &eval) + }) .fold(E::Constant(padding), |acc: E, x| acc * x) }; let f_chunk = lookup_info