Skip to content

Commit 706fbdf

Browse files
committed
feat: typeless AddressMap with typed APIs (#1559)
Note: this PR is not targeting `main`. I've used `TODO` and `TEMP` to mark places in the code that will need to be cleaned up before merging to `main`.

This begins the refactor of online memory to allow different host types in different address spaces. It is going to touch a lot of APIs. The focus is on stabilizing APIs - currently this PR will not improve performance.

Tests will not all pass because I have intentionally disabled some logging required for trace generation. Only execution tests will pass (or run the execute benchmark).

In future PR(s):
- [ ] make `Memory` trait for execution read/write API
- [ ] better handling of type conversions for memory image
- [ ] replace the underlying memory implementation with other implementations like mmap

Towards INT-3743

Even with wasteful conversions, execution is faster:
Before: https://github.com/openvm-org/openvm/actions/runs/14318675080
After: https://github.com/openvm-org/openvm/actions/runs/14371335248?pr=1559
1 parent e4e180c commit 706fbdf

File tree

48 files changed

+811
-709
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+811
-709
lines changed

crates/toolchain/instructions/src/exe.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ use serde::{Deserialize, Serialize};
55

66
use crate::program::Program;
77

8-
/// Memory image is a map from (address space, address) to word.
9-
pub type MemoryImage<F> = BTreeMap<(u32, u32), F>;
8+
// TODO[jpw]: delete this
9+
/// Memory image is a map from (address space, address * size_of<CellType>) to u8.
10+
pub type SparseMemoryImage = BTreeMap<(u32, u32), u8>;
1011
/// Stores the starting address, end address, and name of each function in a set of functions.
1112
pub type FnBounds = BTreeMap<u32, FnBound>;
1213

@@ -22,7 +23,7 @@ pub struct VmExe<F> {
2223
/// Start address of pc.
2324
pub pc_start: u32,
2425
/// Initial memory image.
25-
pub init_memory: MemoryImage<F>,
26+
pub init_memory: SparseMemoryImage,
2627
/// Starting + ending bounds for each function.
2728
pub fn_bounds: FnBounds,
2829
}
@@ -40,7 +41,7 @@ impl<F> VmExe<F> {
4041
self.pc_start = pc_start;
4142
self
4243
}
43-
pub fn with_init_memory(mut self, init_memory: MemoryImage<F>) -> Self {
44+
pub fn with_init_memory(mut self, init_memory: SparseMemoryImage) -> Self {
4445
self.init_memory = init_memory;
4546
self
4647
}

crates/toolchain/transpiler/src/util.rs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::collections::BTreeMap;
22

33
use openvm_instructions::{
4-
exe::MemoryImage,
4+
exe::SparseMemoryImage,
55
instruction::Instruction,
66
riscv::{RV32_MEMORY_AS, RV32_REGISTER_NUM_LIMBS},
77
utils::isize_to_field,
@@ -165,17 +165,14 @@ pub fn nop<F: PrimeField32>() -> Instruction<F> {
165165
}
166166
}
167167

168-
/// Converts our memory image (u32 -> [u8; 4]) into Vm memory image ((as, address) -> word)
169-
pub fn elf_memory_image_to_openvm_memory_image<F: PrimeField32>(
168+
/// Converts our memory image (u32 -> [u8; 4]) into Vm memory image ((as=2, address) -> byte)
169+
pub fn elf_memory_image_to_openvm_memory_image(
170170
memory_image: BTreeMap<u32, u32>,
171-
) -> MemoryImage<F> {
172-
let mut result = MemoryImage::new();
171+
) -> SparseMemoryImage {
172+
let mut result = SparseMemoryImage::new();
173173
for (addr, word) in memory_image {
174174
for (i, byte) in word.to_le_bytes().into_iter().enumerate() {
175-
result.insert(
176-
(RV32_MEMORY_AS, addr + i as u32),
177-
F::from_canonical_u8(byte),
178-
);
175+
result.insert((RV32_MEMORY_AS, addr + i as u32), byte);
179176
}
180177
}
181178
result

crates/vm/src/arch/extensions.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -788,7 +788,7 @@ impl<F: PrimeField32, E, P> VmChipComplex<F, E, P> {
788788
self.base.program_chip.set_program(program);
789789
}
790790

791-
pub(crate) fn set_initial_memory(&mut self, memory: MemoryImage<F>) {
791+
pub(crate) fn set_initial_memory(&mut self, memory: MemoryImage) {
792792
self.base.memory_controller.set_initial_memory(memory);
793793
}
794794

crates/vm/src/arch/segment.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ where
145145
{
146146
pub chip_complex: VmChipComplex<F, VC::Executor, VC::Periphery>,
147147
/// Memory image after segment was executed. Not used in trace generation.
148-
pub final_memory: Option<MemoryImage<F>>,
148+
pub final_memory: Option<MemoryImage>,
149149

150150
pub since_last_segment_check: usize,
151151
pub trace_height_constraints: Vec<LinearConstraint>,
@@ -168,7 +168,7 @@ impl<F: PrimeField32, VC: VmConfig<F>> ExecutionSegment<F, VC> {
168168
config: &VC,
169169
program: Program<F>,
170170
init_streams: Streams<F>,
171-
initial_memory: Option<MemoryImage<F>>,
171+
initial_memory: Option<MemoryImage>,
172172
trace_height_constraints: Vec<LinearConstraint>,
173173
#[allow(unused_variables)] fn_bounds: FnBounds,
174174
) -> Self {

crates/vm/src/arch/vm.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ pub enum GenerationError {
4747
}
4848

4949
/// VM memory state for continuations.
50-
pub type VmMemoryState<F> = MemoryImage<F>;
5150
5251
#[derive(Clone, Default, Debug)]
5352
pub struct Streams<F> {
@@ -95,19 +94,19 @@ pub enum ExitCode {
9594
pub struct VmExecutorResult<SC: StarkGenericConfig> {
9695
pub per_segment: Vec<ProofInput<SC>>,
9796
/// When VM is running on persistent mode, public values are stored in a special memory space.
98-
pub final_memory: Option<VmMemoryState<Val<SC>>>,
97+
pub final_memory: Option<MemoryImage>,
9998
}
10099

101100
pub struct VmExecutorNextSegmentState<F: PrimeField32> {
102-
pub memory: MemoryImage<F>,
101+
pub memory: MemoryImage,
103102
pub input: Streams<F>,
104103
pub pc: u32,
105104
#[cfg(feature = "bench-metrics")]
106105
pub metrics: VmMetrics,
107106
}
108107

109108
impl<F: PrimeField32> VmExecutorNextSegmentState<F> {
110-
pub fn new(memory: MemoryImage<F>, input: impl Into<Streams<F>>, pc: u32) -> Self {
109+
pub fn new(memory: MemoryImage, input: impl Into<Streams<F>>, pc: u32) -> Self {
111110
Self {
112111
memory,
113112
input: input.into(),
@@ -170,12 +169,13 @@ where
170169
let mem_config = self.config.system().memory_config;
171170
let exe = exe.into();
172171
let mut segment_results = vec![];
173-
let memory = AddressMap::from_iter(
172+
let memory = AddressMap::from_sparse(
174173
mem_config.as_offset,
175174
1 << mem_config.as_height,
176175
1 << mem_config.pointer_max_bits,
177176
exe.init_memory.clone(),
178177
);
178+
179179
let pc = exe.pc_start;
180180
let mut state = VmExecutorNextSegmentState::new(memory, input, pc);
181181
let mut segment_idx = 0;
@@ -272,7 +272,7 @@ where
272272
&self,
273273
exe: impl Into<VmExe<F>>,
274274
input: impl Into<Streams<F>>,
275-
) -> Result<Option<VmMemoryState<F>>, ExecutionError> {
275+
) -> Result<Option<MemoryImage>, ExecutionError> {
276276
let mut last = None;
277277
self.execute_and_then(
278278
exe,
@@ -581,7 +581,7 @@ where
581581
&self,
582582
exe: impl Into<VmExe<F>>,
583583
input: impl Into<Streams<F>>,
584-
) -> Result<Option<VmMemoryState<F>>, ExecutionError> {
584+
) -> Result<Option<MemoryImage>, ExecutionError> {
585585
self.executor.execute(exe, input)
586586
}
587587

crates/vm/src/system/memory/controller/interface.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ pub enum MemoryInterface<F> {
1313
Persistent {
1414
boundary_chip: PersistentBoundaryChip<F, CHUNK>,
1515
merkle_chip: MemoryMerkleChip<CHUNK, F>,
16-
initial_memory: MemoryImage<F>,
16+
initial_memory: MemoryImage,
1717
},
1818
}
1919

crates/vm/src/system/memory/controller/mod.rs

Lines changed: 52 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ use std::{
33
collections::BTreeMap,
44
iter,
55
marker::PhantomData,
6-
mem,
76
sync::{Arc, Mutex},
87
};
98

@@ -62,7 +61,7 @@ pub const BOUNDARY_AIR_OFFSET: usize = 0;
6261
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
6362
pub struct RecordId(pub usize);
6463

65-
pub type MemoryImage<F> = AddressMap<F, PAGE_SIZE>;
64+
pub type MemoryImage = AddressMap<PAGE_SIZE>;
6665

6766
#[repr(C)]
6867
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
@@ -98,7 +97,7 @@ pub struct MemoryController<F> {
9897
// Store separately to avoid smart pointer reference each time
9998
range_checker_bus: VariableRangeCheckerBus,
10099
// addr_space -> Memory data structure
101-
memory: Memory<F>,
100+
memory: Memory,
102101
/// A reference to the `OfflineMemory`. Will be populated after `finalize()`.
103102
offline_memory: Arc<Mutex<OfflineMemory<F>>>,
104103
pub access_adapters: AccessAdapterInventory<F>,
@@ -314,7 +313,7 @@ impl<F: PrimeField32> MemoryController<F> {
314313
}
315314
}
316315

317-
pub fn memory_image(&self) -> &MemoryImage<F> {
316+
pub fn memory_image(&self) -> &MemoryImage {
318317
&self.memory.data
319318
}
320319

@@ -344,7 +343,7 @@ impl<F: PrimeField32> MemoryController<F> {
344343
}
345344
}
346345

347-
pub fn set_initial_memory(&mut self, memory: MemoryImage<F>) {
346+
pub fn set_initial_memory(&mut self, memory: MemoryImage) {
348347
if self.timestamp() > INITIAL_TIMESTAMP + 1 {
349348
panic!("Cannot set initial memory after first timestamp");
350349
}
@@ -379,58 +378,67 @@ impl<F: PrimeField32> MemoryController<F> {
379378
(record_id, data)
380379
}
381380

382-
pub fn read<const N: usize>(&mut self, address_space: F, pointer: F) -> (RecordId, [F; N]) {
381+
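// TEMP[jpw]: This function is temporarily left as a safe `fn` (not `unsafe fn`) during the refactor,
// even though callers must uphold the contract in the `# Safety` section below.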
382+
/// # Safety
383+
/// The type `T` must be stack-allocated `repr(C)` or `repr(transparent)`, and it must be the
384+
/// exact type used to represent a single memory cell in address space `address_space`. For
385+
/// standard usage, `T` is either `u8` or `F` where `F` is the base field of the ZK backend.
386+
pub fn read<T: Copy, const N: usize>(
387+
&mut self,
388+
address_space: F,
389+
pointer: F,
390+
) -> (RecordId, [T; N]) {
383391
let address_space_u32 = address_space.as_canonical_u32();
384392
let ptr_u32 = pointer.as_canonical_u32();
385393
assert!(
386394
address_space == F::ZERO || ptr_u32 < (1 << self.mem_config.pointer_max_bits),
387395
"memory out of bounds: {ptr_u32:?}",
388396
);
389397

390-
let (record_id, values) = self.memory.read::<N>(address_space_u32, ptr_u32);
398+
let (record_id, values) = unsafe { self.memory.read::<T, N>(address_space_u32, ptr_u32) };
391399

392400
(record_id, values)
393401
}
394402

395403
/// Reads a single cell directly from memory without updating internal state.
396404
///
397405
/// Any value returned is unconstrained.
398-
pub fn unsafe_read_cell(&self, addr_space: F, ptr: F) -> F {
399-
self.unsafe_read::<1>(addr_space, ptr)[0]
406+
pub fn unsafe_read_cell<T: Copy>(&self, addr_space: F, ptr: F) -> T {
407+
self.unsafe_read::<T, 1>(addr_space, ptr)[0]
400408
}
401409

402410
/// Reads `N` consecutive cells starting at `ptr` directly from memory without updating internal state.
403411
///
404412
/// Any value returned is unconstrained.
405-
pub fn unsafe_read<const N: usize>(&self, addr_space: F, ptr: F) -> [F; N] {
413+
pub fn unsafe_read<T: Copy, const N: usize>(&self, addr_space: F, ptr: F) -> [T; N] {
406414
let addr_space = addr_space.as_canonical_u32();
407415
let ptr = ptr.as_canonical_u32();
408-
array::from_fn(|i| self.memory.get(addr_space, ptr + i as u32))
416+
unsafe { array::from_fn(|i| self.memory.get::<T>(addr_space, ptr + i as u32)) }
409417
}
410418

411419
/// Writes `data` to the given cell.
412420
///
413421
/// Returns the `RecordId` and previous data.
414-
pub fn write_cell(&mut self, address_space: F, pointer: F, data: F) -> (RecordId, F) {
415-
let (record_id, [data]) = self.write(address_space, pointer, [data]);
422+
pub fn write_cell<T: Copy>(&mut self, address_space: F, pointer: F, data: T) -> (RecordId, T) {
423+
let (record_id, [data]) = self.write(address_space, pointer, &[data]);
416424
(record_id, data)
417425
}
418426

419-
pub fn write<const N: usize>(
427+
pub fn write<T: Copy, const N: usize>(
420428
&mut self,
421429
address_space: F,
422430
pointer: F,
423-
data: [F; N],
424-
) -> (RecordId, [F; N]) {
425-
assert_ne!(address_space, F::ZERO);
431+
data: &[T; N],
432+
) -> (RecordId, [T; N]) {
433+
debug_assert_ne!(address_space, F::ZERO);
426434
let address_space_u32 = address_space.as_canonical_u32();
427435
let ptr_u32 = pointer.as_canonical_u32();
428436
assert!(
429437
ptr_u32 < (1 << self.mem_config.pointer_max_bits),
430438
"memory out of bounds: {ptr_u32:?}",
431439
);
432440

433-
self.memory.write(address_space_u32, ptr_u32, data)
441+
unsafe { self.memory.write::<T, N>(address_space_u32, ptr_u32, data) }
434442
}
435443

436444
pub fn aux_cols_factory(&self) -> MemoryAuxColsFactory<F> {
@@ -455,26 +463,27 @@ impl<F: PrimeField32> MemoryController<F> {
455463
}
456464

457465
fn replay_access_log(&mut self) {
458-
let log = mem::take(&mut self.memory.log);
459-
if log.is_empty() {
460-
// Online memory logs may be empty, but offline memory may be replayed from external
461-
// sources. In these cases, we skip the calls to replay access logs because
462-
// `set_log_capacity` would panic.
463-
tracing::debug!("skipping replay_access_log");
464-
return;
465-
}
466-
467-
let mut offline_memory = self.offline_memory.lock().unwrap();
468-
offline_memory.set_log_capacity(log.len());
469-
470-
for entry in log {
471-
Self::replay_access(
472-
entry,
473-
&mut offline_memory,
474-
&mut self.interface_chip,
475-
&mut self.access_adapters,
476-
);
477-
}
466+
unimplemented!();
467+
// let log = mem::take(&mut self.memory.log);
468+
// if log.is_empty() {
469+
// // Online memory logs may be empty, but offline memory may be replayed from external
470+
// // sources. In these cases, we skip the calls to replay access logs because
471+
// // `set_log_capacity` would panic.
472+
// tracing::debug!("skipping replay_access_log");
473+
// return;
474+
// }
475+
476+
// let mut offline_memory = self.offline_memory.lock().unwrap();
477+
// offline_memory.set_log_capacity(log.len());
478+
479+
// for entry in log {
480+
// Self::replay_access(
481+
// entry,
482+
// &mut offline_memory,
483+
// &mut self.interface_chip,
484+
// &mut self.access_adapters,
485+
// );
486+
// }
478487
}
479488

480489
/// Low-level API to replay a single memory access log entry and populate the [OfflineMemory],
@@ -704,13 +713,13 @@ impl<F: PrimeField32> MemoryController<F> {
704713
pub fn offline_memory(&self) -> Arc<Mutex<OfflineMemory<F>>> {
705714
self.offline_memory.clone()
706715
}
707-
pub fn get_memory_logs(&self) -> &Vec<MemoryLogEntry<F>> {
716+
pub fn get_memory_logs(&self) -> &Vec<MemoryLogEntry<u8>> {
708717
&self.memory.log
709718
}
710-
pub fn set_memory_logs(&mut self, logs: Vec<MemoryLogEntry<F>>) {
719+
pub fn set_memory_logs(&mut self, logs: Vec<MemoryLogEntry<u8>>) {
711720
self.memory.log = logs;
712721
}
713-
pub fn take_memory_logs(&mut self) -> Vec<MemoryLogEntry<F>> {
722+
pub fn take_memory_logs(&mut self) -> Vec<MemoryLogEntry<u8>> {
714723
std::mem::take(&mut self.memory.log)
715724
}
716725
}
@@ -857,9 +866,9 @@ mod tests {
857866

858867
if rng.gen_bool(0.5) {
859868
let data = F::from_canonical_u32(rng.gen_range(0..1 << 30));
860-
memory_controller.write(address_space, pointer, [data]);
869+
memory_controller.write(address_space, pointer, &[data]);
861870
} else {
862-
memory_controller.read::<1>(address_space, pointer);
871+
memory_controller.read::<F, 1>(address_space, pointer);
863872
}
864873
}
865874
assert!(memory_controller

0 commit comments

Comments
 (0)