Skip to content

Commit a78054f

Browse files
committed
cleanup all allocations in WASM
1 parent f2d220a commit a78054f

File tree

6 files changed

+163
-100
lines changed

6 files changed

+163
-100
lines changed

wasmtime-jni/src/ty/byte_slice.rs

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use log::debug;
55
use wasmtime::{Val, ValType, WeakStore};
66
pub use wasmtime_jni_exports::WasmSlice;
77

8-
use crate::ty::{Abi, ComplexTy, ReturnAbi, WasmAlloc};
8+
use crate::ty::{Abi, ComplexTy, ReturnAbi, WasmAlloc, WasmSliceWrapper};
99

1010
// pub fn greet(name: &str) -> String {
1111
// format!("Hello, {}!", name)
@@ -145,17 +145,17 @@ impl ReturnAbi for WasmSlice {
145145

146146
/// Place the values in the argument list, if there was an allocation, the pointer is returned
147147
#[allow(unused)]
148-
fn return_or_store_to_arg(
148+
fn return_or_store_to_arg<'w>(
149149
args: &mut Vec<Val>,
150-
wasm_alloc: Option<&mut WasmAlloc>,
151-
) -> Result<Option<i32>, Error> {
150+
wasm_alloc: Option<&'w WasmAlloc>,
151+
) -> Result<Option<WasmSliceWrapper<'w>>, Error> {
152152
// create a place in memory for the slice to be returned
153-
let ptr = wasm_alloc
153+
let slice = wasm_alloc
154154
.ok_or_else(|| anyhow!("WasmAlloc not supplied"))?
155155
.alloc::<Self>()?;
156156

157-
args.push(Val::from(ptr));
158-
Ok(Some(ptr))
157+
args.push(Val::from(slice.ptr));
158+
Ok(Some(slice))
159159
}
160160

161161
fn get_return_by_ref_arg(mut args: impl Iterator<Item = Val>) -> Option<i32> {
@@ -165,15 +165,14 @@ impl ReturnAbi for WasmSlice {
165165
/// Load from the returned value, or from the passed in pointer to the return by ref parameter
166166
fn return_or_load_or_from_args(
167167
_ret: Option<&Val>,
168-
mut ret_by_ref_ptr: Option<i32>,
169-
wasm_alloc: Option<&mut WasmAlloc>,
168+
mut ret_by_ref_ptr: Option<WasmSliceWrapper<'_>>,
169+
_wasm_alloc: Option<&WasmAlloc>,
170170
) -> Result<Self, anyhow::Error> {
171171
let ptr = ret_by_ref_ptr
172172
.take()
173173
.ok_or_else(|| anyhow!("No pointer was supplied"))?;
174-
let wasm_alloc = wasm_alloc.ok_or_else(|| anyhow!("WasmAlloc was not supplied"))?;
174+
let wasm_slice = unsafe { ptr.obj_as_mut() };
175175

176-
let wasm_slice = unsafe { wasm_alloc.obj_as_mut(ptr) };
177176
debug!("read {:?}", wasm_slice);
178177
Ok(*wasm_slice)
179178
}

wasmtime-jni/src/ty/complex_ty.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use anyhow::{anyhow, ensure, Error};
22
use wasmtime::{Val, ValType, WeakStore};
33

4-
use crate::ty::WasmAlloc;
4+
use crate::ty::{WasmAlloc, WasmSliceWrapper};
55

66
pub(crate) trait ComplexTy {
77
type Abi: Abi;
@@ -38,16 +38,16 @@ pub(crate) trait ReturnAbi: Abi {
3838

3939
/// Place the values in the argument list
4040
#[allow(unused)]
41-
fn return_or_store_to_arg(
41+
fn return_or_store_to_arg<'w>(
4242
args: &mut Vec<Val>,
43-
wasm_alloc: Option<&mut WasmAlloc>,
44-
) -> Result<Option<i32>, Error>;
43+
wasm_alloc: Option<&'w WasmAlloc>,
44+
) -> Result<Option<WasmSliceWrapper<'w>>, Error>;
4545

4646
/// Load from the argument list
4747
fn return_or_load_or_from_args(
4848
ret: Option<&Val>,
49-
ret_by_ref_ptr: Option<i32>,
50-
wasm_alloc: Option<&mut WasmAlloc>,
49+
ret_by_ref_ptr: Option<WasmSliceWrapper<'_>>,
50+
wasm_alloc: Option<&WasmAlloc>,
5151
) -> Result<Self, anyhow::Error>;
5252
}
5353

@@ -60,10 +60,10 @@ impl<T: Abi + IntoValType + FromVal + MatchesValType> ReturnAbi for T {
6060

6161
/// Place the values in the argument list
6262
#[allow(unused)]
63-
fn return_or_store_to_arg(
63+
fn return_or_store_to_arg<'w>(
6464
args: &mut Vec<Val>,
65-
wasm_alloc: Option<&mut WasmAlloc>,
66-
) -> Result<Option<i32>, Error> {
65+
wasm_alloc: Option<&'w WasmAlloc>,
66+
) -> Result<Option<WasmSliceWrapper<'w>>, Error> {
6767
Ok(None)
6868
}
6969

@@ -74,8 +74,8 @@ impl<T: Abi + IntoValType + FromVal + MatchesValType> ReturnAbi for T {
7474
/// Load from the argument list
7575
fn return_or_load_or_from_args(
7676
mut ret: Option<&Val>,
77-
_ret_by_ref_ptr: Option<i32>,
78-
_wasm_alloc: Option<&mut WasmAlloc>,
77+
_ret_by_ref_ptr: Option<WasmSliceWrapper<'_>>,
78+
_wasm_alloc: Option<&WasmAlloc>,
7979
) -> Result<Self, anyhow::Error> {
8080
ret.take()
8181
.cloned()

wasmtime-jni/src/ty/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ mod wasm_alloc;
44

55
pub(crate) use byte_slice::WasmSlice;
66
pub(crate) use complex_ty::{Abi, ComplexTy, ReturnAbi};
7-
pub(crate) use wasm_alloc::WasmAlloc;
7+
pub(crate) use wasm_alloc::{WasmAlloc, WasmSliceWrapper};

wasmtime-jni/src/ty/wasm_alloc.rs

Lines changed: 58 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
use std::mem;
2+
use std::ops::Deref;
23

34
use anyhow::{anyhow, ensure, Context, Error};
4-
use log::debug;
5+
use log::{debug, warn};
56
use wasmtime::{Caller, Extern, Func, Instance, Memory, Val};
67
use wasmtime_jni_exports::{ALLOC_EXPORT, DEALLOC_EXPORT, MEMORY_EXPORT};
78

@@ -49,14 +50,14 @@ impl WasmAlloc {
4950
}
5051

5152
/// Safety, the returned array is uninitialized
52-
pub unsafe fn as_mut(&mut self, wasm_slice: WasmSlice) -> &mut [u8] {
53+
pub unsafe fn as_mut(&self, wasm_slice: WasmSlice) -> &mut [u8] {
5354
debug!("data ptr: {}", wasm_slice.ptr);
5455

5556
&mut self.memory.data_unchecked_mut()[wasm_slice.ptr as usize..][..wasm_slice.len as usize]
5657
}
5758

5859
/// Allocates size bytes in the Wasm Memory context, returns the offset into the Memory region
59-
pub unsafe fn alloc_size(&self, size: usize) -> Result<WasmSlice, Error> {
60+
pub unsafe fn alloc_size(&self, size: usize) -> Result<WasmSliceWrapper<'_>, Error> {
6061
let len = size as i32;
6162
let ptr = self
6263
.alloc
@@ -67,11 +68,12 @@ impl WasmAlloc {
6768

6869
debug!("Allocated offset {} len {}", ptr, len);
6970

70-
Ok(WasmSlice { ptr, len })
71+
let wasm_slice = WasmSlice { ptr, len };
72+
Ok(WasmSliceWrapper::new(self, wasm_slice))
7173
}
7274

7375
/// Allocates the bytes from the src bytes
74-
pub fn alloc_bytes(&mut self, src: &[u8]) -> Result<WasmSlice, Error> {
76+
pub fn alloc_bytes(&self, src: &[u8]) -> Result<WasmSliceWrapper<'_>, Error> {
7577
let mem_base = self.memory.data_ptr() as usize;
7678
let mem_size = self.memory.size() as usize * MEM_SEGMENT_SIZE;
7779
ensure!(
@@ -83,7 +85,7 @@ impl WasmAlloc {
8385

8486
// get target memor location and then copy into the function
8587
let wasm_slice = unsafe { self.alloc_size(src.len())? };
86-
let mem_bytes = unsafe { self.as_mut(wasm_slice) };
88+
let mem_bytes = unsafe { self.as_mut(wasm_slice.wasm_slice) };
8789
mem_bytes.copy_from_slice(src);
8890

8991
debug!(
@@ -107,23 +109,23 @@ impl WasmAlloc {
107109
Ok(())
108110
}
109111

110-
pub fn alloc<T: Sized>(&mut self) -> Result<i32, Error> {
112+
pub fn alloc<T: Sized>(&self) -> Result<WasmSliceWrapper<'_>, Error> {
111113
let wasm_slice = unsafe { self.alloc_size(mem::size_of::<T>())? };
112114

113115
// zero out the memory...
114-
for b in unsafe { self.as_mut(wasm_slice) } {
116+
for b in unsafe { wasm_slice.as_mut() } {
115117
*b = 0;
116118
}
117119

118120
debug!(
119121
"stored {} at {:x?}",
120122
std::any::type_name::<T>(),
121-
wasm_slice.ptr
123+
wasm_slice.wasm_slice().ptr
122124
);
123-
Ok(wasm_slice.ptr)
125+
Ok(wasm_slice)
124126
}
125127

126-
pub unsafe fn obj_as_mut<T: Sized>(&mut self, ptr: i32) -> &mut T {
128+
pub unsafe fn obj_as_mut<T: Sized>(&self, ptr: i32) -> &mut T {
127129
debug_assert!(ptr > 0);
128130
let ptr_to_mem = self.memory.data_ptr().add(ptr as usize);
129131
debug!("dereffing {:x?} from offset {:x?}", ptr_to_mem, ptr);
@@ -140,3 +142,48 @@ impl WasmAlloc {
140142
self.dealloc_bytes(wasm_slice)
141143
}
142144
}
145+
146+
/// This is use to free memory after a function call
147+
pub(crate) struct WasmSliceWrapper<'w> {
148+
wasm_alloc: &'w WasmAlloc,
149+
wasm_slice: WasmSlice,
150+
}
151+
152+
impl<'w> WasmSliceWrapper<'w> {
153+
pub fn new(wasm_alloc: &'w WasmAlloc, wasm_slice: WasmSlice) -> Self {
154+
Self {
155+
wasm_alloc,
156+
wasm_slice,
157+
}
158+
}
159+
160+
/// Safety, the returned array is uninitialized
161+
pub unsafe fn as_mut(&self) -> &mut [u8] {
162+
self.wasm_alloc.as_mut(self.wasm_slice)
163+
}
164+
165+
pub unsafe fn obj_as_mut<T: Sized>(&self) -> &mut T {
166+
self.wasm_alloc.obj_as_mut(self.wasm_slice.ptr)
167+
}
168+
169+
/// Copy out the WasmSlice, careful, the lifetime of this is really tied to the memory lifetime backing the WasmAlloc
170+
pub fn wasm_slice(&self) -> WasmSlice {
171+
self.wasm_slice
172+
}
173+
}
174+
175+
impl<'w> Deref for WasmSliceWrapper<'w> {
176+
type Target = WasmSlice;
177+
178+
fn deref(&self) -> &WasmSlice {
179+
&self.wasm_slice
180+
}
181+
}
182+
183+
impl<'w> Drop for WasmSliceWrapper<'w> {
184+
fn drop(&mut self) {
185+
if let Err(err) = self.wasm_alloc.dealloc_bytes(self.wasm_slice) {
186+
warn!("Error deallocating bytes: {}", err);
187+
}
188+
}
189+
}

wasmtime-jni/src/wasm_function.rs

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ pub extern "system" fn Java_net_bluejekyll_wasmtime_WasmFunction_createFunc<'j>(
8181
let func = move |caller: Caller, inputs: &[Val], outputs: &mut [Val]| -> Result<(), Trap> {
8282
let java_args = &java_args;
8383
let java_ret = &java_ret;
84-
let mut wasm_alloc = WasmAlloc::from_caller(&caller);
84+
let wasm_alloc = WasmAlloc::from_caller(&caller);
8585

8686
debug!(
8787
"Calling Java method args {} and return {} with WASM {} inputs and {} outputs",
@@ -144,7 +144,7 @@ pub extern "system" fn Java_net_bluejekyll_wasmtime_WasmFunction_createFunc<'j>(
144144
for (i, java_arg) in java_args.iter().enumerate() {
145145
let jvalue = unsafe {
146146
java_arg
147-
.load_from_args(&env, &mut input_iter, wasm_alloc.as_mut())
147+
.load_from_args(&env, &mut input_iter, wasm_alloc.as_ref())
148148
.with_context(|| format!("Failed to get Java arg from: {}", java_arg))?
149149
};
150150

@@ -213,8 +213,7 @@ pub extern "system" fn Java_net_bluejekyll_wasmtime_WasmFunction_createFunc<'j>(
213213
debug!("allocating space and associating bytes for return by ref");
214214
let ptr = ret_by_ref_ptr
215215
.ok_or_else(|| anyhow!("expected return by ref argument pointer"))?;
216-
let mut wasm_alloc =
217-
wasm_alloc.ok_or_else(|| anyhow!("WasmAlloc is required"))?;
216+
let wasm_alloc = wasm_alloc.ok_or_else(|| anyhow!("WasmAlloc is required"))?;
218217

219218
let bytes = env
220219
.get_direct_buffer_address(val)
@@ -224,15 +223,14 @@ pub extern "system" fn Java_net_bluejekyll_wasmtime_WasmFunction_createFunc<'j>(
224223
unsafe {
225224
let bytes = wasm_alloc.alloc_bytes(bytes)?;
226225
let ret_by_ref_loc = wasm_alloc.obj_as_mut::<WasmSlice>(ptr);
227-
*ret_by_ref_loc = bytes;
226+
*ret_by_ref_loc = bytes.wasm_slice();
228227
}
229228
}
230229
(Some(WasmVal::ByteArray { jarray, .. }), None) => {
231230
debug!("allocating space and associating bytes for return by ref");
232231
let ptr = ret_by_ref_ptr
233232
.ok_or_else(|| anyhow!("expected return by ref argument pointer"))?;
234-
let mut wasm_alloc =
235-
wasm_alloc.ok_or_else(|| anyhow!("WasmAlloc is required"))?;
233+
let wasm_alloc = wasm_alloc.ok_or_else(|| anyhow!("WasmAlloc is required"))?;
236234

237235
let len = env
238236
.get_array_length(jarray)
@@ -256,15 +254,14 @@ pub extern "system" fn Java_net_bluejekyll_wasmtime_WasmFunction_createFunc<'j>(
256254
// get mutable reference to the return by ref pointer and then store
257255
unsafe {
258256
let ret_by_ref_loc = wasm_alloc.obj_as_mut::<WasmSlice>(ptr);
259-
*ret_by_ref_loc = bytes;
257+
*ret_by_ref_loc = bytes.wasm_slice();
260258
}
261259
}
262260
(Some(WasmVal::String(string)), None) => {
263261
debug!("allocating space and associating string for return by ref");
264262
let ptr = ret_by_ref_ptr
265263
.ok_or_else(|| anyhow!("expected return by ref argument pointer"))?;
266-
let mut wasm_alloc =
267-
wasm_alloc.ok_or_else(|| anyhow!("WasmAlloc is required"))?;
264+
let wasm_alloc = wasm_alloc.ok_or_else(|| anyhow!("WasmAlloc is required"))?;
268265

269266
let jstr = env
270267
.get_string(string)
@@ -278,7 +275,7 @@ pub extern "system" fn Java_net_bluejekyll_wasmtime_WasmFunction_createFunc<'j>(
278275
// get mutable reference to the return by ref pointer and then store
279276
unsafe {
280277
let ret_by_ref_loc = wasm_alloc.obj_as_mut::<WasmSlice>(ptr);
281-
*ret_by_ref_loc = wasm_slice;
278+
*ret_by_ref_loc = wasm_slice.wasm_slice();
282279
}
283280
}
284281
(Some(WasmVal::ByteBuffer(_val)), Some(_result)) => {
@@ -366,12 +363,16 @@ pub extern "system" fn Java_net_bluejekyll_wasmtime_WasmFunction_callNtv<'j>(
366363
let len = usize::try_from(len)?;
367364
let mut wasm_args = Vec::with_capacity(len);
368365

369-
let mut wasm_alloc = if !instance.is_null() {
366+
let wasm_alloc = if !instance.is_null() {
370367
WasmAlloc::from_instance(&instance)
371368
} else {
372369
None
373370
};
374371

372+
// let droppers will cleanup allocated memory in the WASM module after the function call
373+
// or should the callee drop?? hmm...
374+
let mut wasm_droppers = Vec::with_capacity(len);
375+
375376
// we need to convert all the parameters to WASM vals for the call
376377
debug!("got {} args for function", len);
377378
for i in 0..len {
@@ -383,7 +384,11 @@ pub extern "system" fn Java_net_bluejekyll_wasmtime_WasmFunction_callNtv<'j>(
383384
.with_context(|| format!("failed to convert argument at index: {}", i))?;
384385

385386
debug!("adding arg: {}", val.ty());
386-
val.store_to_args(env, &mut wasm_args, wasm_alloc.as_mut())?;
387+
if let Some(dropper) =
388+
val.store_to_args(env, &mut wasm_args, wasm_alloc.as_ref())?
389+
{
390+
wasm_droppers.push(dropper);
391+
}
387392
}
388393

389394
// now we may need to add a return_by_ref parameter
@@ -393,7 +398,7 @@ pub extern "system" fn Java_net_bluejekyll_wasmtime_WasmFunction_callNtv<'j>(
393398
let maybe_ret_by_ref = if let Some(wasm_return_ty) = &wasm_return_ty {
394399
wasm_return_ty
395400
.clone()
396-
.return_or_store_to_arg(&mut wasm_args, wasm_alloc.as_mut())?
401+
.return_or_store_to_arg(&mut wasm_args, wasm_alloc.as_ref())?
397402
} else {
398403
None
399404
};
@@ -417,7 +422,7 @@ pub extern "system" fn Java_net_bluejekyll_wasmtime_WasmFunction_callNtv<'j>(
417422
wasm_return_ty,
418423
val.get(0),
419424
maybe_ret_by_ref,
420-
wasm_alloc.as_mut(),
425+
wasm_alloc.as_ref(),
421426
)?
422427
}
423428
} else {

0 commit comments

Comments
 (0)