Skip to content

feat: Add Incremental Flag to Proof Generation #2337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion packages/cli/ts/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -941,9 +941,15 @@ program
.option(
"-u, --use-quadratic-voting <useQuadraticVoting>",
"whether to use quadratic voting",
(value) => value === "true",
(value: string) => value === "true",
true,
)
.option(
"-i, --incremental <incremental>",
"whether to reuse existing proofs and salts (only regenerate missing ones)",
(value: string) => value === "true",
false,
)
.option(
"-b, --ipfs-message-backup-files <ipfsMessageBackupFiles>",
"Backup files for ipfs messages (name format: ipfsHash1.json, ipfsHash2.json, ..., ipfsHashN.json)",
Expand Down Expand Up @@ -974,6 +980,7 @@ program
processWitnessdat,
wasm,
rapidsnark,
incremental,
}) => {
try {
banner(quiet);
Expand Down Expand Up @@ -1012,6 +1019,7 @@ program
processDatFile: processWitnessdat,
useWasm: wasm,
rapidsnark,
incremental,
});
} catch (error) {
program.error((error as Error).message, { exitCode: 1 });
Expand Down
246 changes: 216 additions & 30 deletions packages/contracts/tasks/helpers/ProofGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,21 @@ export class ProofGenerator {
*/
private useQuadraticVoting?: boolean;

/**
* Whether to use incremental proof generation
*/
private incremental?: boolean;

/**
* Path to the rapidsnark binary
*/
private rapidsnark?: string;

/**
* Whether to use incremental mode
*/
private incremental?: boolean;

/**
* Get maci state from local file or from contract
*
Expand Down Expand Up @@ -163,6 +173,7 @@ export class ProofGenerator {
outputDir,
tallyOutputFile,
useQuadraticVoting,
incremental,
}: IProofGeneratorParams) {
this.poll = poll;
this.maciContractAddress = maciContractAddress;
Expand All @@ -173,6 +184,7 @@ export class ProofGenerator {
this.tally = tally;
this.rapidsnark = rapidsnark;
this.useQuadraticVoting = useQuadraticVoting;
this.incremental = incremental;
}

/**
Expand All @@ -185,35 +197,61 @@ export class ProofGenerator {
performance.mark("mp-proofs-start");

const { messageBatchSize } = this.poll.batchSizes;
const numMessages = this.poll.messages.length;
let totalMessageBatches = numMessages <= messageBatchSize ? 1 : Math.floor(numMessages / messageBatchSize);

const numMessages: number = this.poll.messages.length;
let totalMessageBatches: number = numMessages <= messageBatchSize ? 1 : Math.floor(numMessages / messageBatchSize);
if (numMessages > messageBatchSize && numMessages % messageBatchSize > 0) {
totalMessageBatches += 1;
}

try {
let mpCircuitInputs: CircuitInputs;
const inputs: CircuitInputs[] = [];
const proofs: Proof[] = [];

// while we have unprocessed messages, process them
while (this.poll.hasUnprocessedMessages()) {
// process messages in batches
const circuitInputs = this.poll.processMessages(
BigInt(this.poll.pollId),
this.useQuadraticVoting,
) as unknown as CircuitInputs;

// generate the proof for this batch
inputs.push(circuitInputs);

logMagenta({ text: info(`Progress: ${this.poll.numBatchesProcessed} / ${totalMessageBatches}`) });
const batchIndex = this.poll.numBatchesProcessed;
const proofPath = path.join(this.outputDir, `process_${batchIndex}.json`);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const proofPath = path.join(this.outputDir, `process_${batchIndex}.json`);
const proofPath = path.resolve(this.outputDir, `process_${batchIndex}.json`);


let shouldGenerateNewProof = true;

// Check if proof exists and incremental flag is set
if (this.incremental) {
try {
// Use a synchronous approach to avoid await in loop
const exists = fs.existsSync(proofPath);
if (exists) {
// Read file synchronously to avoid await in loop
const existingProof = JSON.parse(fs.readFileSync(proofPath, "utf8")) as Proof;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still sync read here (should be async)

proofs.push(existingProof);
this.poll.numBatchesProcessed += 1;
shouldGenerateNewProof = false;
}
} catch (error) {
logMagenta({ text: info(`Error reading existing proof at ${proofPath}, regenerating...`) });
}
}

if (shouldGenerateNewProof) {
// process messages in batches with the incremental flag
const circuitInputs = this.poll.processMessages(
BigInt(this.poll.pollId),
this.useQuadraticVoting,
this.incremental
);

// generate the proof for this batch
inputs.push(circuitInputs as unknown as CircuitInputs);

logMagenta({ text: info(`Progress: ${this.poll.numBatchesProcessed} / ${totalMessageBatches}`) });
}
}

logMagenta({ text: info("Wait until proof generation is finished") });

const processZkey = await extractVk(this.mp.zkey, false);

const proofs = await Promise.all(
const newProofs = await Promise.all(
inputs.map((circuitInputs, index) =>
this.generateProofs(circuitInputs, this.mp, `process_${index}.json`, processZkey).then((data) => {
options?.onBatchComplete?.({ current: index, total: totalMessageBatches, proofs: data });
Expand All @@ -222,6 +260,9 @@ export class ProofGenerator {
),
).then((data) => data.reduce((acc, x) => acc.concat(x), []));

// Combine existing proofs with new ones
const allProofs = [...proofs, ...newProofs];

logGreen({ text: success("Proof generation is finished") });

// cleanup threads
Expand All @@ -230,9 +271,9 @@ export class ProofGenerator {
performance.mark("mp-proofs-end");
performance.measure("Generate message processor proofs", "mp-proofs-start", "mp-proofs-end");

options?.onComplete?.(proofs);
options?.onComplete?.(allProofs);

return proofs;
return allProofs;
} catch (error) {
options?.onFail?.(error as Error);

Expand All @@ -256,31 +297,68 @@ export class ProofGenerator {
performance.mark("tally-proofs-start");

const { tallyBatchSize } = this.poll.batchSizes;
const numStateLeaves = this.poll.pollStateLeaves.length;
let totalTallyBatches = numStateLeaves <= tallyBatchSize ? 1 : Math.floor(numStateLeaves / tallyBatchSize);
const numStateLeaves: number = this.poll.pollStateLeaves.length;
let totalTallyBatches: number = numStateLeaves <= tallyBatchSize ? 1 : Math.floor(numStateLeaves / tallyBatchSize);
if (numStateLeaves > tallyBatchSize && numStateLeaves % tallyBatchSize > 0) {
totalTallyBatches += 1;
}

try {
let tallyCircuitInputs: CircuitInputs;
let tallyCircuitInputs: CircuitInputs | undefined;
const inputs: CircuitInputs[] = [];
const proofs: Proof[] = [];

while (this.poll.hasUntalliedBallots()) {
tallyCircuitInputs = (this.useQuadraticVoting
? this.poll.tallyVotes()
: this.poll.tallyVotesNonQv()) as unknown as CircuitInputs;
// Load existing salts if in incremental mode
if (this.incremental) {
await this.loadSalts();
}

inputs.push(tallyCircuitInputs);
while (this.poll.hasUntalliedBallots()) {
const batchIndex = this.poll.numBatchesTallied;
const proofPath = path.join(this.outputDir, `tally_${batchIndex}.json`);

let shouldGenerateNewProof = true;

// Check if proof exists and incremental flag is set
if (this.incremental) {
Comment on lines +317 to +323
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is duplicated in proof generation for mp, makes sense to reuse it.

try {
// Use a synchronous approach to avoid await in loop
const exists = fs.existsSync(proofPath);
if (exists) {
// Read file synchronously to avoid await in loop
const existingProof = JSON.parse(fs.readFileSync(proofPath, "utf8")) as Proof;
proofs.push(existingProof);
this.poll.numBatchesTallied += 1;
shouldGenerateNewProof = false;
}
} catch (error) {
logMagenta({ text: info(`Error reading existing proof at ${proofPath}, regenerating...`) });
}
}

if (shouldGenerateNewProof) {
// Use the updated interface with the incremental flag
const circuitInputs = this.useQuadraticVoting
? this.poll.tallyVotes(this.incremental)
: this.poll.tallyVotesNonQv();

tallyCircuitInputs = circuitInputs as unknown as CircuitInputs;
inputs.push(tallyCircuitInputs);

logMagenta({ text: info(`Progress: ${this.poll.numBatchesTallied} / ${totalTallyBatches}`) });
}
}

logMagenta({ text: info(`Progress: ${this.poll.numBatchesTallied} / ${totalTallyBatches}`) });
// Make sure we have at least one set of circuit inputs for validation
if (!tallyCircuitInputs && inputs.length > 0) {
tallyCircuitInputs = inputs[inputs.length - 1];
}

logMagenta({ text: info("Wait until proof generation is finished") });

const tallyVk = await extractVk(this.tally.zkey, false);

const proofs = await Promise.all(
const newProofs = await Promise.all(
inputs.map((circuitInputs, index) =>
this.generateProofs(circuitInputs, this.tally, `tally_${index}.json`, tallyVk).then((data) => {
options?.onBatchComplete?.({ current: index, total: totalTallyBatches, proofs: data });
Expand All @@ -289,13 +367,24 @@ export class ProofGenerator {
),
).then((data) => data.reduce((acc, x) => acc.concat(x), []));

// Combine existing proofs with new ones
const allProofs = [...proofs, ...newProofs];

logGreen({ text: success("Proof generation is finished") });

// cleanup threads
await cleanThreads();

// verify the results
// Compute newResultsCommitment
// If no circuit inputs were generated (all from cache), we need to get the tally results from the Poll
if (!tallyCircuitInputs) {
const statePath = path.join(this.outputDir, `poll_state.json`);
if (fs.existsSync(statePath)) {
await this.poll.loadState(statePath);
}
}

// For validation purposes, we need to compute the various commitments
// The rest of the code remains the same
const newResultsCommitment = genTreeCommitment(
this.poll.tallyResult,
BigInt(asHex(tallyCircuitInputs!.newResultsRootSalt as BigNumberish)),
Expand Down Expand Up @@ -371,9 +460,9 @@ export class ProofGenerator {
performance.mark("tally-proofs-end");
performance.measure("Generate tally proofs", "tally-proofs-start", "tally-proofs-end");

options?.onComplete?.(proofs, tallyFileData);
options?.onComplete?.(allProofs, tallyFileData);

return { proofs, tallyData: tallyFileData };
return { proofs: allProofs, tallyData: tallyFileData };
} catch (error) {
options?.onFail?.(error as Error);

Expand Down Expand Up @@ -432,4 +521,101 @@ export class ProofGenerator {

return proofs;
}

/**
* Generate a proof for processing messages
* @param pollId - The ID of the poll
* @param incremental - Whether to use incremental proof generation
* @returns The proof for processing messages
*/
async generateProcessMessagesProof(pollId: bigint, incremental = false): Promise<Proof> {
const poll = this.poll as unknown as MaciState;
const foundPoll = poll.polls.get(pollId);
if (!foundPoll) {
Comment on lines +533 to +534
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const foundPoll = poll.polls.get(pollId);
if (!foundPoll) {
const foundPoll = poll.polls.get(pollId);
if (!foundPoll) {

throw new Error(`Poll ${pollId} not found`);
}

// Load state if incremental mode is enabled
if (incremental) {
const statePath = path.join(this.outputDir, `poll_${pollId}_state.json`);
if (fs.existsSync(statePath)) {
Comment on lines +540 to +541
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const statePath = path.join(this.outputDir, `poll_${pollId}_state.json`);
if (fs.existsSync(statePath)) {
const statePath = path.resolve(this.outputDir, `poll_${pollId}_state.json`);
if (fs.existsSync(statePath)) {

await foundPoll.loadState(statePath);
}
}

// Process messages
foundPoll.processMessages(pollId, this.useQuadraticVoting, incremental);

// Save state if incremental mode is enabled
if (incremental) {
const statePath = path.join(this.outputDir, `poll_${pollId}_state.json`);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const statePath = path.join(this.outputDir, `poll_${pollId}_state.json`);
const statePath = path.resolve(this.outputDir, `poll_${pollId}_state.json`);

await foundPoll.saveState(statePath);
}

// Generate proof
const proofPath = path.join(this.outputDir, `processMessages_${pollId}.json`);
if (incremental && fs.existsSync(proofPath)) {
Comment on lines +556 to +557
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const proofPath = path.join(this.outputDir, `processMessages_${pollId}.json`);
if (incremental && fs.existsSync(proofPath)) {
const proofPath = path.resolve(this.outputDir, `processMessages_${pollId}.json`);
if (incremental && fs.existsSync(proofPath)) {

const existingProof = JSON.parse(await fs.promises.readFile(proofPath, "utf8")) as Proof;
return existingProof;
}

const circuitInputs = foundPoll.processMessages(pollId, this.useQuadraticVoting, incremental);
const proofs = await this.generateProofs(
circuitInputs as unknown as CircuitInputs,
this.mp,
`processMessages_${pollId}.json`,
await extractVk(this.mp.zkey, false)
);
const proof = proofs[0];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const proof = proofs[0];
const [proof] = proofs;

await fs.promises.writeFile(proofPath, JSON.stringify(proof, null, 2));
return proof;
}

/**
* Generate a proof for tallying votes
* @param pollId - The ID of the poll
* @param incremental - Whether to use incremental proof generation
* @returns The proof for tallying votes
*/
async generateTallyProof(pollId: bigint, incremental = false): Promise<Proof> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic for tally proof is the same as for message processing, it makese sense to reuse it and do not duplicate.

const poll = this.poll as unknown as MaciState;
const foundPoll = poll.polls.get(pollId);
if (!foundPoll) {
throw new Error(`Poll ${pollId} not found`);
}

// Load state if incremental mode is enabled
if (incremental) {
const statePath = path.join(this.outputDir, `poll_${pollId}_state.json`);
if (fs.existsSync(statePath)) {
await foundPoll.loadState(statePath);
}
}

// Tally votes
const circuitInputs = foundPoll.tallyVotes(incremental);

// Save state if incremental mode is enabled
if (incremental) {
const statePath = path.join(this.outputDir, `poll_${pollId}_state.json`);
await foundPoll.saveState(statePath);
}

// Generate proof
const proofPath = path.join(this.outputDir, `tally_${pollId}.json`);
if (incremental && fs.existsSync(proofPath)) {
const existingProof = JSON.parse(await fs.promises.readFile(proofPath, "utf8")) as Proof;
return existingProof;
}

const proofs = await this.generateProofs(
circuitInputs as unknown as CircuitInputs,
this.tally,
`tally_${pollId}.json`,
await extractVk(this.tally.zkey, false)
);
const proof = proofs[0];
await fs.promises.writeFile(proofPath, JSON.stringify(proof, null, 2));
return proof;
}
}
Loading
Loading