-
Notifications
You must be signed in to change notification settings - Fork 172
feat: Add Incremental Flag to Proof Generation #2337
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9f10539
88d2581
c23d888
bd11940
1071c7f
c096c91
2b873eb
e035bdd
1a5593a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -64,11 +64,21 @@ export class ProofGenerator { | |||||||||||
*/ | ||||||||||||
private useQuadraticVoting?: boolean; | ||||||||||||
|
||||||||||||
/** | ||||||||||||
* Whether to use incremental proof generation | ||||||||||||
*/ | ||||||||||||
private incremental?: boolean; | ||||||||||||
|
||||||||||||
/** | ||||||||||||
* Path to the rapidsnark binary | ||||||||||||
*/ | ||||||||||||
private rapidsnark?: string; | ||||||||||||
|
||||||||||||
/** | ||||||||||||
* Whether to use incremental mode | ||||||||||||
*/ | ||||||||||||
private incremental?: boolean; | ||||||||||||
|
||||||||||||
/** | ||||||||||||
* Get maci state from local file or from contract | ||||||||||||
* | ||||||||||||
|
@@ -163,6 +173,7 @@ export class ProofGenerator { | |||||||||||
outputDir, | ||||||||||||
tallyOutputFile, | ||||||||||||
useQuadraticVoting, | ||||||||||||
incremental, | ||||||||||||
}: IProofGeneratorParams) { | ||||||||||||
this.poll = poll; | ||||||||||||
this.maciContractAddress = maciContractAddress; | ||||||||||||
|
@@ -173,6 +184,7 @@ export class ProofGenerator { | |||||||||||
this.tally = tally; | ||||||||||||
this.rapidsnark = rapidsnark; | ||||||||||||
this.useQuadraticVoting = useQuadraticVoting; | ||||||||||||
this.incremental = incremental; | ||||||||||||
} | ||||||||||||
|
||||||||||||
/** | ||||||||||||
|
@@ -185,35 +197,61 @@ export class ProofGenerator { | |||||||||||
performance.mark("mp-proofs-start"); | ||||||||||||
|
||||||||||||
const { messageBatchSize } = this.poll.batchSizes; | ||||||||||||
const numMessages = this.poll.messages.length; | ||||||||||||
let totalMessageBatches = numMessages <= messageBatchSize ? 1 : Math.floor(numMessages / messageBatchSize); | ||||||||||||
|
||||||||||||
const numMessages: number = this.poll.messages.length; | ||||||||||||
let totalMessageBatches: number = numMessages <= messageBatchSize ? 1 : Math.floor(numMessages / messageBatchSize); | ||||||||||||
if (numMessages > messageBatchSize && numMessages % messageBatchSize > 0) { | ||||||||||||
totalMessageBatches += 1; | ||||||||||||
} | ||||||||||||
|
||||||||||||
try { | ||||||||||||
let mpCircuitInputs: CircuitInputs; | ||||||||||||
const inputs: CircuitInputs[] = []; | ||||||||||||
const proofs: Proof[] = []; | ||||||||||||
|
||||||||||||
// while we have unprocessed messages, process them | ||||||||||||
while (this.poll.hasUnprocessedMessages()) { | ||||||||||||
// process messages in batches | ||||||||||||
const circuitInputs = this.poll.processMessages( | ||||||||||||
BigInt(this.poll.pollId), | ||||||||||||
this.useQuadraticVoting, | ||||||||||||
) as unknown as CircuitInputs; | ||||||||||||
|
||||||||||||
// generate the proof for this batch | ||||||||||||
inputs.push(circuitInputs); | ||||||||||||
|
||||||||||||
logMagenta({ text: info(`Progress: ${this.poll.numBatchesProcessed} / ${totalMessageBatches}`) }); | ||||||||||||
const batchIndex = this.poll.numBatchesProcessed; | ||||||||||||
const proofPath = path.join(this.outputDir, `process_${batchIndex}.json`); | ||||||||||||
|
||||||||||||
let shouldGenerateNewProof = true; | ||||||||||||
|
||||||||||||
// Check if proof exists and incremental flag is set | ||||||||||||
if (this.incremental) { | ||||||||||||
try { | ||||||||||||
// Use a synchronous approach to avoid await in loop | ||||||||||||
const exists = fs.existsSync(proofPath); | ||||||||||||
if (exists) { | ||||||||||||
// Read file synchronously to avoid await in loop | ||||||||||||
const existingProof = JSON.parse(fs.readFileSync(proofPath, "utf8")) as Proof; | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Still sync read here (should be async) |
||||||||||||
proofs.push(existingProof); | ||||||||||||
this.poll.numBatchesProcessed += 1; | ||||||||||||
shouldGenerateNewProof = false; | ||||||||||||
} | ||||||||||||
} catch (error) { | ||||||||||||
logMagenta({ text: info(`Error reading existing proof at ${proofPath}, regenerating...`) }); | ||||||||||||
} | ||||||||||||
} | ||||||||||||
|
||||||||||||
if (shouldGenerateNewProof) { | ||||||||||||
// process messages in batches with the incremental flag | ||||||||||||
const circuitInputs = this.poll.processMessages( | ||||||||||||
BigInt(this.poll.pollId), | ||||||||||||
this.useQuadraticVoting, | ||||||||||||
this.incremental | ||||||||||||
); | ||||||||||||
|
||||||||||||
// generate the proof for this batch | ||||||||||||
inputs.push(circuitInputs as unknown as CircuitInputs); | ||||||||||||
|
||||||||||||
logMagenta({ text: info(`Progress: ${this.poll.numBatchesProcessed} / ${totalMessageBatches}`) }); | ||||||||||||
} | ||||||||||||
} | ||||||||||||
|
||||||||||||
logMagenta({ text: info("Wait until proof generation is finished") }); | ||||||||||||
|
||||||||||||
const processZkey = await extractVk(this.mp.zkey, false); | ||||||||||||
|
||||||||||||
const proofs = await Promise.all( | ||||||||||||
const newProofs = await Promise.all( | ||||||||||||
inputs.map((circuitInputs, index) => | ||||||||||||
this.generateProofs(circuitInputs, this.mp, `process_${index}.json`, processZkey).then((data) => { | ||||||||||||
options?.onBatchComplete?.({ current: index, total: totalMessageBatches, proofs: data }); | ||||||||||||
|
@@ -222,6 +260,9 @@ export class ProofGenerator { | |||||||||||
), | ||||||||||||
).then((data) => data.reduce((acc, x) => acc.concat(x), [])); | ||||||||||||
|
||||||||||||
// Combine existing proofs with new ones | ||||||||||||
const allProofs = [...proofs, ...newProofs]; | ||||||||||||
|
||||||||||||
logGreen({ text: success("Proof generation is finished") }); | ||||||||||||
|
||||||||||||
// cleanup threads | ||||||||||||
|
@@ -230,9 +271,9 @@ export class ProofGenerator { | |||||||||||
performance.mark("mp-proofs-end"); | ||||||||||||
performance.measure("Generate message processor proofs", "mp-proofs-start", "mp-proofs-end"); | ||||||||||||
|
||||||||||||
options?.onComplete?.(proofs); | ||||||||||||
options?.onComplete?.(allProofs); | ||||||||||||
|
||||||||||||
return proofs; | ||||||||||||
return allProofs; | ||||||||||||
} catch (error) { | ||||||||||||
options?.onFail?.(error as Error); | ||||||||||||
|
||||||||||||
|
@@ -256,31 +297,68 @@ export class ProofGenerator { | |||||||||||
performance.mark("tally-proofs-start"); | ||||||||||||
|
||||||||||||
const { tallyBatchSize } = this.poll.batchSizes; | ||||||||||||
const numStateLeaves = this.poll.pollStateLeaves.length; | ||||||||||||
let totalTallyBatches = numStateLeaves <= tallyBatchSize ? 1 : Math.floor(numStateLeaves / tallyBatchSize); | ||||||||||||
const numStateLeaves: number = this.poll.pollStateLeaves.length; | ||||||||||||
let totalTallyBatches: number = numStateLeaves <= tallyBatchSize ? 1 : Math.floor(numStateLeaves / tallyBatchSize); | ||||||||||||
if (numStateLeaves > tallyBatchSize && numStateLeaves % tallyBatchSize > 0) { | ||||||||||||
totalTallyBatches += 1; | ||||||||||||
} | ||||||||||||
|
||||||||||||
try { | ||||||||||||
let tallyCircuitInputs: CircuitInputs; | ||||||||||||
let tallyCircuitInputs: CircuitInputs | undefined; | ||||||||||||
const inputs: CircuitInputs[] = []; | ||||||||||||
const proofs: Proof[] = []; | ||||||||||||
|
||||||||||||
while (this.poll.hasUntalliedBallots()) { | ||||||||||||
tallyCircuitInputs = (this.useQuadraticVoting | ||||||||||||
? this.poll.tallyVotes() | ||||||||||||
: this.poll.tallyVotesNonQv()) as unknown as CircuitInputs; | ||||||||||||
// Load existing salts if in incremental mode | ||||||||||||
if (this.incremental) { | ||||||||||||
await this.loadSalts(); | ||||||||||||
} | ||||||||||||
|
||||||||||||
inputs.push(tallyCircuitInputs); | ||||||||||||
while (this.poll.hasUntalliedBallots()) { | ||||||||||||
const batchIndex = this.poll.numBatchesTallied; | ||||||||||||
const proofPath = path.join(this.outputDir, `tally_${batchIndex}.json`); | ||||||||||||
|
||||||||||||
let shouldGenerateNewProof = true; | ||||||||||||
|
||||||||||||
// Check if proof exists and incremental flag is set | ||||||||||||
if (this.incremental) { | ||||||||||||
Comment on lines
+317
to
+323
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This code is duplicated in proof generation for mp, makes sense to reuse it. |
||||||||||||
try { | ||||||||||||
// Use a synchronous approach to avoid await in loop | ||||||||||||
const exists = fs.existsSync(proofPath); | ||||||||||||
if (exists) { | ||||||||||||
// Read file synchronously to avoid await in loop | ||||||||||||
const existingProof = JSON.parse(fs.readFileSync(proofPath, "utf8")) as Proof; | ||||||||||||
proofs.push(existingProof); | ||||||||||||
this.poll.numBatchesTallied += 1; | ||||||||||||
shouldGenerateNewProof = false; | ||||||||||||
} | ||||||||||||
} catch (error) { | ||||||||||||
logMagenta({ text: info(`Error reading existing proof at ${proofPath}, regenerating...`) }); | ||||||||||||
} | ||||||||||||
} | ||||||||||||
|
||||||||||||
if (shouldGenerateNewProof) { | ||||||||||||
// Use the updated interface with the incremental flag | ||||||||||||
const circuitInputs = this.useQuadraticVoting | ||||||||||||
? this.poll.tallyVotes(this.incremental) | ||||||||||||
: this.poll.tallyVotesNonQv(); | ||||||||||||
|
||||||||||||
tallyCircuitInputs = circuitInputs as unknown as CircuitInputs; | ||||||||||||
inputs.push(tallyCircuitInputs); | ||||||||||||
|
||||||||||||
logMagenta({ text: info(`Progress: ${this.poll.numBatchesTallied} / ${totalTallyBatches}`) }); | ||||||||||||
} | ||||||||||||
} | ||||||||||||
|
||||||||||||
logMagenta({ text: info(`Progress: ${this.poll.numBatchesTallied} / ${totalTallyBatches}`) }); | ||||||||||||
// Make sure we have at least one set of circuit inputs for validation | ||||||||||||
if (!tallyCircuitInputs && inputs.length > 0) { | ||||||||||||
tallyCircuitInputs = inputs[inputs.length - 1]; | ||||||||||||
} | ||||||||||||
|
||||||||||||
logMagenta({ text: info("Wait until proof generation is finished") }); | ||||||||||||
|
||||||||||||
const tallyVk = await extractVk(this.tally.zkey, false); | ||||||||||||
|
||||||||||||
const proofs = await Promise.all( | ||||||||||||
const newProofs = await Promise.all( | ||||||||||||
inputs.map((circuitInputs, index) => | ||||||||||||
this.generateProofs(circuitInputs, this.tally, `tally_${index}.json`, tallyVk).then((data) => { | ||||||||||||
options?.onBatchComplete?.({ current: index, total: totalTallyBatches, proofs: data }); | ||||||||||||
|
@@ -289,13 +367,24 @@ export class ProofGenerator { | |||||||||||
), | ||||||||||||
).then((data) => data.reduce((acc, x) => acc.concat(x), [])); | ||||||||||||
|
||||||||||||
// Combine existing proofs with new ones | ||||||||||||
const allProofs = [...proofs, ...newProofs]; | ||||||||||||
|
||||||||||||
logGreen({ text: success("Proof generation is finished") }); | ||||||||||||
|
||||||||||||
// cleanup threads | ||||||||||||
await cleanThreads(); | ||||||||||||
|
||||||||||||
// verify the results | ||||||||||||
// Compute newResultsCommitment | ||||||||||||
// If no circuit inputs were generated (all from cache), we need to get the tally results from the Poll | ||||||||||||
if (!tallyCircuitInputs) { | ||||||||||||
const statePath = path.join(this.outputDir, `poll_state.json`); | ||||||||||||
if (fs.existsSync(statePath)) { | ||||||||||||
await this.poll.loadState(statePath); | ||||||||||||
} | ||||||||||||
} | ||||||||||||
|
||||||||||||
// For validation purposes, we need to compute the various commitments | ||||||||||||
// The rest of the code remains the same | ||||||||||||
const newResultsCommitment = genTreeCommitment( | ||||||||||||
this.poll.tallyResult, | ||||||||||||
BigInt(asHex(tallyCircuitInputs!.newResultsRootSalt as BigNumberish)), | ||||||||||||
|
@@ -371,9 +460,9 @@ export class ProofGenerator { | |||||||||||
performance.mark("tally-proofs-end"); | ||||||||||||
performance.measure("Generate tally proofs", "tally-proofs-start", "tally-proofs-end"); | ||||||||||||
|
||||||||||||
options?.onComplete?.(proofs, tallyFileData); | ||||||||||||
options?.onComplete?.(allProofs, tallyFileData); | ||||||||||||
|
||||||||||||
return { proofs, tallyData: tallyFileData }; | ||||||||||||
return { proofs: allProofs, tallyData: tallyFileData }; | ||||||||||||
} catch (error) { | ||||||||||||
options?.onFail?.(error as Error); | ||||||||||||
|
||||||||||||
|
@@ -432,4 +521,101 @@ export class ProofGenerator { | |||||||||||
|
||||||||||||
return proofs; | ||||||||||||
} | ||||||||||||
|
||||||||||||
/** | ||||||||||||
* Generate a proof for processing messages | ||||||||||||
* @param pollId - The ID of the poll | ||||||||||||
* @param incremental - Whether to use incremental proof generation | ||||||||||||
* @returns The proof for processing messages | ||||||||||||
*/ | ||||||||||||
async generateProcessMessagesProof(pollId: bigint, incremental = false): Promise<Proof> { | ||||||||||||
const poll = this.poll as unknown as MaciState; | ||||||||||||
const foundPoll = poll.polls.get(pollId); | ||||||||||||
if (!foundPoll) { | ||||||||||||
Comment on lines
+533
to
+534
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
throw new Error(`Poll ${pollId} not found`); | ||||||||||||
} | ||||||||||||
|
||||||||||||
// Load state if incremental mode is enabled | ||||||||||||
if (incremental) { | ||||||||||||
const statePath = path.join(this.outputDir, `poll_${pollId}_state.json`); | ||||||||||||
if (fs.existsSync(statePath)) { | ||||||||||||
Comment on lines
+540
to
+541
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
await foundPoll.loadState(statePath); | ||||||||||||
} | ||||||||||||
} | ||||||||||||
|
||||||||||||
// Process messages | ||||||||||||
foundPoll.processMessages(pollId, this.useQuadraticVoting, incremental); | ||||||||||||
|
||||||||||||
// Save state if incremental mode is enabled | ||||||||||||
if (incremental) { | ||||||||||||
const statePath = path.join(this.outputDir, `poll_${pollId}_state.json`); | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
await foundPoll.saveState(statePath); | ||||||||||||
} | ||||||||||||
|
||||||||||||
// Generate proof | ||||||||||||
const proofPath = path.join(this.outputDir, `processMessages_${pollId}.json`); | ||||||||||||
if (incremental && fs.existsSync(proofPath)) { | ||||||||||||
Comment on lines
+556
to
+557
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
const existingProof = JSON.parse(await fs.promises.readFile(proofPath, "utf8")) as Proof; | ||||||||||||
return existingProof; | ||||||||||||
} | ||||||||||||
|
||||||||||||
const circuitInputs = foundPoll.processMessages(pollId, this.useQuadraticVoting, incremental); | ||||||||||||
const proofs = await this.generateProofs( | ||||||||||||
circuitInputs as unknown as CircuitInputs, | ||||||||||||
this.mp, | ||||||||||||
`processMessages_${pollId}.json`, | ||||||||||||
await extractVk(this.mp.zkey, false) | ||||||||||||
); | ||||||||||||
const proof = proofs[0]; | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
await fs.promises.writeFile(proofPath, JSON.stringify(proof, null, 2)); | ||||||||||||
return proof; | ||||||||||||
} | ||||||||||||
|
||||||||||||
/** | ||||||||||||
* Generate a proof for tallying votes | ||||||||||||
* @param pollId - The ID of the poll | ||||||||||||
* @param incremental - Whether to use incremental proof generation | ||||||||||||
* @returns The proof for tallying votes | ||||||||||||
*/ | ||||||||||||
async generateTallyProof(pollId: bigint, incremental = false): Promise<Proof> { | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The logic for tally proof is the same as for message processing, it makese sense to reuse it and do not duplicate. |
||||||||||||
const poll = this.poll as unknown as MaciState; | ||||||||||||
const foundPoll = poll.polls.get(pollId); | ||||||||||||
if (!foundPoll) { | ||||||||||||
throw new Error(`Poll ${pollId} not found`); | ||||||||||||
} | ||||||||||||
|
||||||||||||
// Load state if incremental mode is enabled | ||||||||||||
if (incremental) { | ||||||||||||
const statePath = path.join(this.outputDir, `poll_${pollId}_state.json`); | ||||||||||||
if (fs.existsSync(statePath)) { | ||||||||||||
await foundPoll.loadState(statePath); | ||||||||||||
} | ||||||||||||
} | ||||||||||||
|
||||||||||||
// Tally votes | ||||||||||||
const circuitInputs = foundPoll.tallyVotes(incremental); | ||||||||||||
|
||||||||||||
// Save state if incremental mode is enabled | ||||||||||||
if (incremental) { | ||||||||||||
const statePath = path.join(this.outputDir, `poll_${pollId}_state.json`); | ||||||||||||
await foundPoll.saveState(statePath); | ||||||||||||
} | ||||||||||||
|
||||||||||||
// Generate proof | ||||||||||||
const proofPath = path.join(this.outputDir, `tally_${pollId}.json`); | ||||||||||||
if (incremental && fs.existsSync(proofPath)) { | ||||||||||||
const existingProof = JSON.parse(await fs.promises.readFile(proofPath, "utf8")) as Proof; | ||||||||||||
return existingProof; | ||||||||||||
} | ||||||||||||
|
||||||||||||
const proofs = await this.generateProofs( | ||||||||||||
circuitInputs as unknown as CircuitInputs, | ||||||||||||
this.tally, | ||||||||||||
`tally_${pollId}.json`, | ||||||||||||
await extractVk(this.tally.zkey, false) | ||||||||||||
); | ||||||||||||
const proof = proofs[0]; | ||||||||||||
await fs.promises.writeFile(proofPath, JSON.stringify(proof, null, 2)); | ||||||||||||
return proof; | ||||||||||||
} | ||||||||||||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.