Skip to content

Use a single implementation for combine_table_entry #433

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 18, 2022
8 changes: 8 additions & 0 deletions kimchi/src/circuits/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1921,6 +1921,14 @@ impl<F: Field> From<u64> for Expr<ConstantExpr<F>> {
}
}

impl<F: Field> Mul<F> for Expr<ConstantExpr<F>> {
type Output = Expr<ConstantExpr<F>>;

fn mul(self, y: F) -> Self::Output {
Expr::Constant(ConstantExpr::Literal(y)) * self
}
}

//
// Display
//
Expand Down
81 changes: 56 additions & 25 deletions kimchi/src/circuits/gate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

use crate::circuits::{constraints::ConstraintSystem, domains::EvaluationDomains, wires::*};
use ark_ff::bytes::ToBytes;
use ark_ff::{FftField, Field};
use ark_ff::{FftField, Field, One, Zero};
use ark_poly::{Evaluations as E, Radix2EvaluationDomain as D};
use num_traits::cast::ToPrimitive;
use o1_utils::hasher::CryptoDigest;
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use std::collections::{hash_map::Entry, HashMap, HashSet};
use std::io::{Result as IoResult, Write};
use std::ops::Mul;

type Evaluations<Field> = E<Field, D<Field>>;

Expand Down Expand Up @@ -68,40 +69,70 @@ pub struct SingleLookup<F> {
/// analogously using `joint_combiner`.
///
/// This function computes that combined value.
pub fn combine_table_entry<'a, F: Field, I: DoubleEndedIterator<Item = &'a F>>(
joint_combiner: F,
v: I,
) -> F {
v.rev().fold(F::zero(), |acc, x| joint_combiner * acc + x)
pub fn combine_table_entry<'a, F, I>(joint_combiner: F, v: I) -> F
where
F: 'a, // Any references in `F` must have a lifetime longer than `'a`.
F: Zero + One + Clone,
I: DoubleEndedIterator<Item = &'a F>,
{
v.rev()
.fold(F::zero(), |acc, x| joint_combiner.clone() * acc + x.clone())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you add Copy instead of Clone you can remove the clone here (Field is Copy)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

F is copy but Expr isn't (and we use JointCombiner rather than a field element there).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah I see. My tolerance for generics is usually pretty low ^^ I'm wondering how we can make this clearer. Maybe by changing F with F_or_Expr or something. Food for thought

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and then creating an actual trait F_or_Expr : Zero + One + ... and implement it on what we need

}

impl<F: Field> SingleLookup<F> {
impl<F: Copy> SingleLookup<F> {
/// Evaluate the linear combination specifying the lookup value to a field element.
pub fn evaluate<G: Fn(LocalPosition) -> F>(&self, eval: G) -> F {
pub fn evaluate<K, G: Fn(LocalPosition) -> K>(&self, eval: G) -> K
where
K: Zero,
K: Mul<F, Output = K>,
{
self.value
.iter()
.fold(F::zero(), |acc, (c, p)| acc + *c * eval(*p))
.fold(K::zero(), |acc, (c, p)| acc + eval(*p) * *c)
}
}

/// A spec for checking that the given vector belongs to a vector-valued lookup table.
#[derive(Clone, Serialize, Deserialize)]
pub struct JointLookup<F> {
pub table_id: usize,
pub entry: Vec<SingleLookup<F>>,
pub struct JointLookup<SingleLookup> {
pub table_id: i32,
pub entry: Vec<SingleLookup>,
}

impl<F: Field> JointLookup<F> {
/// A spec for checking that the given vector belongs to a vector-valued lookup table, where the
/// components of the vector are computed from a linear combination of locally-accessible cells.
pub type JointLookupSpec<F> = JointLookup<SingleLookup<F>>;

impl<F: Zero + One + Clone> JointLookup<F> {
// TODO: Support multiple tables
/// Evaluate the combined value of a joint-lookup.
pub fn evaluate<G: Fn(LocalPosition) -> F>(&self, joint_combiner: F, eval: &G) -> F {
let mut res = F::zero();
let mut c = F::one();
for s in self.entry.iter() {
res += c * s.evaluate(eval);
c *= joint_combiner;
pub fn evaluate(&self, joint_combiner: F) -> F {
combine_table_entry(joint_combiner, self.entry.iter())
}
}

impl<F: Copy> JointLookup<SingleLookup<F>> {
/// Reduce linear combinations in the lookup entries to a single value, resolving local
/// positions using the given function.
pub fn reduce<K, G: Fn(LocalPosition) -> K>(&self, eval: &G) -> JointLookup<K>
where
K: Zero,
K: Mul<F, Output = K>,
{
JointLookup {
table_id: self.table_id,
entry: self.entry.iter().map(|s| s.evaluate(eval)).collect(),
}
res
}

/// Evaluate the combined value of a joint-lookup, resolving local positions using the given
/// function.
pub fn evaluate<K, G: Fn(LocalPosition) -> K>(&self, joint_combiner: K, eval: &G) -> K
where
K: Zero + One + Clone,
K: Mul<F, Output = K>,
{
self.reduce(eval).evaluate(joint_combiner)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

honestly I find this function really hard to parse.

  • we have a JointLookup<SingleLookup<F>> for some F: Zero + One + Add + Mul + Clone
  • we want a K for some K: Zero + One + Add + Mul + Mul<F, Output = K> + Clone
  • in the mean time, we have some eval function that takes a LocalPosition and return a K

it's just too generic to really grasp what is really going on IMO. Is there a way to make this clearer?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if we can just have code duplication for the different instantiations of these functions, this would make the code much clearer imo and we wouldn't need to pass a closure around (IIUC)

}
}

Expand Down Expand Up @@ -158,7 +189,7 @@ pub enum GateType {
pub struct LookupInfo<F> {
/// A single lookup constraint is a vector of lookup constraints to be applied at a row.
/// This is a vector of all the kinds of lookup constraints in this configuration.
pub kinds: Vec<Vec<JointLookup<F>>>,
pub kinds: Vec<Vec<JointLookupSpec<F>>>,
/// A map from the kind of gate (and whether it is the current row or next row) to the lookup
/// constraint (given as an index into `kinds`) that should be applied there, if any.
pub kinds_map: HashMap<(GateType, CurrOrNext), usize>,
Expand All @@ -170,10 +201,10 @@ pub struct LookupInfo<F> {
/// The maximum joint size of any joint lookup in a constraint in `kinds`. This can be computed from `kinds`.
pub max_joint_size: u32,
/// An empty vector.
empty: Vec<JointLookup<F>>,
empty: Vec<JointLookupSpec<F>>,
}

fn max_lookups_per_row<F>(kinds: &[Vec<JointLookup<F>>]) -> usize {
fn max_lookups_per_row<F>(kinds: &[Vec<JointLookupSpec<F>>]) -> usize {
kinds.iter().fold(0, |acc, x| std::cmp::max(x.len(), acc))
}

Expand Down Expand Up @@ -286,7 +317,7 @@ impl<F: FftField> LookupInfo<F> {
}

/// For each row in the circuit, which lookup-constraints should be enforced at that row.
pub fn by_row<'a>(&'a self, gates: &[CircuitGate<F>]) -> Vec<&'a Vec<JointLookup<F>>> {
pub fn by_row<'a>(&'a self, gates: &[CircuitGate<F>]) -> Vec<&'a Vec<JointLookupSpec<F>>> {
let mut kinds = vec![&self.empty; gates.len() + 1];
for i in 0..gates.len() {
let typ = gates[i].typ;
Expand Down Expand Up @@ -327,7 +358,7 @@ impl GateType {
///
/// See circuits/kimchi/src/polynomials/chacha.rs for an explanation of
/// how these work.
pub fn lookup_kinds<F: Field>() -> (Vec<Vec<JointLookup<F>>>, Vec<GatesLookupSpec>) {
pub fn lookup_kinds<F: Field>() -> (Vec<Vec<JointLookupSpec<F>>>, Vec<GatesLookupSpec>) {
let curr_row = |column| LocalPosition {
row: CurrOrNext::Curr,
column,
Expand Down
45 changes: 12 additions & 33 deletions kimchi/src/circuits/polynomials/lookup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,8 @@

use crate::{
circuits::{
expr::{prologue::*, Column, ConstantExpr, Variable},
gate::{
CircuitGate, CurrOrNext, JointLookup, LocalPosition, LookupInfo, LookupsUsed,
SingleLookup,
},
expr::{prologue::*, Column, ConstantExpr},
gate::{CircuitGate, CurrOrNext, JointLookupSpec, LocalPosition, LookupInfo, LookupsUsed},
wires::COLUMNS,
},
error::ProofError,
Expand All @@ -142,29 +139,6 @@ use CurrOrNext::*;
/// Number of constraints produced by the argument.
pub const CONSTRAINTS: u32 = 7;

// TODO: Update for multiple tables
fn single_lookup<F: FftField>(s: &SingleLookup<F>) -> E<F> {
// Combine the linear combination.
s.value
.iter()
.map(|(c, pos)| {
E::literal(*c)
* E::Cell(Variable {
col: Column::Witness(pos.column),
row: pos.row,
})
})
.fold(E::zero(), |acc, e| acc + e)
}

fn joint_lookup<F: FftField>(j: &JointLookup<F>) -> E<F> {
j.entry
.iter()
.enumerate()
.map(|(i, s)| E::constant(ConstantExpr::JointCombiner.pow(i as u64)) * single_lookup(s))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so if I understand correctly this is replaced by the logic in combine_table_entry

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly, yes

.fold(E::zero(), |acc, x| acc + x)
}

struct AdjacentPairs<A, I: Iterator<Item = A>> {
prev_second_component: Option<A>,
i: I,
Expand Down Expand Up @@ -356,7 +330,7 @@ pub trait Entry {

fn evaluate(
p: &Self::Params,
j: &JointLookup<Self::Field>,
j: &JointLookupSpec<Self::Field>,
witness: &[Vec<Self::Field>; COLUMNS],
row: usize,
) -> Self;
Expand All @@ -370,7 +344,7 @@ impl<F: Field> Entry for CombinedEntry<F> {

fn evaluate(
joint_combiner: &F,
j: &JointLookup<F>,
j: &JointLookupSpec<F>,
witness: &[Vec<F>; COLUMNS],
row: usize,
) -> CombinedEntry<F> {
Expand All @@ -395,7 +369,7 @@ impl<F: Field> Entry for UncombinedEntry<F> {

fn evaluate(
_: &(),
j: &JointLookup<F>,
j: &JointLookupSpec<F>,
witness: &[Vec<F>; COLUMNS],
row: usize,
) -> UncombinedEntry<F> {
Expand Down Expand Up @@ -661,16 +635,21 @@ pub fn constraints<F: FftField>(configuration: &LookupConfiguration<F>, d1: D<F>
.collect()
};

let eval = |pos: LocalPosition| witness(pos.column, pos.row);

// This is set up so that on rows that have lookups, chunk will be equal
// to the product over all lookups `f` in that row of `gamma + f`
// and
// on non-lookup rows, will be equal to 1.
let f_term = |spec: &Vec<_>| {
let f_term = |spec: &Vec<JointLookupSpec<_>>| {
assert!(spec.len() <= lookup_info.max_per_row);
let padding = complements_with_beta_term[lookup_info.max_per_row - spec.len()].clone();

spec.iter()
.map(|j| E::Constant(ConstantExpr::Gamma) + joint_lookup(j))
.map(|j| {
E::Constant(ConstantExpr::Gamma)
+ j.evaluate(E::constant(ConstantExpr::JointCombiner), &eval)
})
.fold(E::Constant(padding), |acc: E<F>, x| acc * x)
};
let f_chunk = lookup_info
Expand Down