Skip to content

Commit 88aff7c

Browse files
committed
Updates
1 parent a0d3d41 commit 88aff7c

File tree

6 files changed

+35
-9
lines changed

6 files changed

+35
-9
lines changed

rust/src/runtime/array.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ pub struct Tensor<'a> {
129129
pub(super) size: usize,
130130
}
131131

132+
unsafe impl<'a> Send for Tensor<'a> {}
133+
132134
impl<'a> Tensor<'a> {
133135
pub fn shape(&self) -> Vec<i64> {
134136
self.shape.clone()

rust/src/runtime/graph.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ pub struct GraphExecutor<'m, 't> {
156156
tensors: Vec<Tensor<'t>>,
157157
}
158158

159+
unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {}
160+
159161
impl<'m, 't> GraphExecutor<'m, 't> {
160162
pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self> {
161163
let tensors = Self::setup_storages(&graph)?;
@@ -189,7 +191,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
189191
}
190192
}).collect::<Result<Vec<DataType>>>()?;
191193

192-
let align = dtypes.iter().map(|dtype| dtype.bits as usize >> 3).max();
194+
let align = dtypes.iter().map(|dtype| dtype.bits as usize).max();
193195
let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1];
194196
for (i, &storage_id) in storage_ids.iter().enumerate() {
195197
let dtype_size = dtypes[i].bits * dtypes[i].lanes >> 3;

rust/src/runtime/packed_func.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,8 @@ pub struct TVMRetValue {
149149
type_code: i64,
150150
}
151151

152+
#[cfg(target_env = "sgx")]
152153
impl TVMRetValue {
153-
#[cfg(target_env = "sgx")]
154154
pub(crate) fn from_tvm_value(value: TVMValue, type_code: i64) -> Self {
155155
unsafe {
156156
Self {
@@ -166,6 +166,25 @@ impl TVMRetValue {
166166
}
167167
}
168168
}
169+
170+
pub fn into_tvm_value(self) -> (TVMValue, i64) {
171+
let val = match self.type_code {
172+
0 | 1 => TVMValue {
173+
v_int64: self.prim_value.clone() as i64,
174+
},
175+
2 => TVMValue {
176+
v_float64: self.prim_value.clone() as f64,
177+
},
178+
3 | 7 | 8 | 9 | 10 => TVMValue {
179+
v_handle: Box::into_raw(self.box_value) as *mut c_void,
180+
},
181+
11 | 12 => TVMValue {
182+
v_str: Box::into_raw(self.box_value) as *const _,
183+
},
184+
_ => unreachable!(),
185+
};
186+
(val, self.type_code)
187+
}
169188
}
170189

171190
impl Default for TVMRetValue {

rust/src/runtime/sgx.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,14 @@ macro_rules! ocall_packed {
6969
}
7070
}
7171

72+
pub fn shutdown() {
73+
if env!("TVM_NUM_THREADS") != "0" {
74+
sgx_join_threads()
75+
}
76+
}
77+
7278
impl Drop for SystemLibModule {
7379
fn drop(&mut self) {
74-
if env!("TVM_NUM_THREADS") != "0" {
75-
sgx_join_threads()
76-
}
80+
shutdown()
7781
}
7882
}

rust/src/runtime/threading.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ pub extern "C" fn TVMBackendParallelLaunch(
257257
}
258258

259259
#[cfg(target_env = "sgx")]
260-
pub(crate) fn sgx_join_threads() -> () {
260+
pub(crate) fn sgx_join_threads() {
261261
extern "C" fn poison_pill(
262262
_task_id: usize,
263263
_penv: *const TVMParallelGroupEnv,

rust/src/runtime/workspace.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,9 @@ impl WorkspacePool {
4242
if !ws_size >= size {
4343
return cur_ws_idx;
4444
}
45-
cur_ws_idx.and_then(|cur_idx| {
45+
cur_ws_idx.or(Some(idx)).and_then(|cur_idx| {
4646
let cur_size = self.workspaces[cur_idx].size();
47-
Some(match ws_size < cur_size {
48-
// is already ok
47+
Some(match ws_size <= cur_size {
4948
true => idx,
5049
false => cur_idx,
5150
})

0 commit comments

Comments
 (0)