Skip to content

Commit 05151ef

Browse files
committed
Address code review comments
1 parent 9387ddc commit 05151ef

File tree

3 files changed

+49
-20
lines changed

3 files changed

+49
-20
lines changed

rust/src/runtime/array.rs

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -221,21 +221,31 @@ impl<'a> Tensor<'a> {
221221
}
222222
}
223223

224-
impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<f32> {
225-
type Error = Error;
226-
fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<f32>> {
227-
ensure!(
228-
tensor.dtype == DTYPE_FLOAT32,
229-
"Cannot convert Tensor with dtype {:?} to ndarray",
230-
tensor.dtype
231-
);
232-
Ok(ndarray::Array::from_shape_vec(
233-
tensor.shape.iter().map(|s| *s as usize).collect::<Vec<usize>>(),
234-
tensor.to_vec::<f32>(),
235-
)?)
236-
}
224+
/// Conversions to `ndarray::Array` from `Tensor`, if the types match.
225+
macro_rules! impl_ndarray_try_from_tensor {
226+
($type:ty, $dtype:expr) => {
227+
impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> {
228+
type Error = Error;
229+
fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>> {
230+
ensure!(
231+
tensor.dtype == $dtype,
232+
"Cannot convert Tensor with dtype {:?} to ndarray",
233+
tensor.dtype
234+
);
235+
Ok(ndarray::Array::from_shape_vec(
236+
tensor.shape.iter().map(|s| *s as usize).collect::<Vec<usize>>(),
237+
tensor.to_vec::<$type>(),
238+
)?)
239+
}
240+
}
241+
};
237242
}
238243

244+
impl_ndarray_try_from_tensor!(i32, DTYPE_INT32);
245+
impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32);
246+
impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32);
247+
impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64);
248+
239249
impl DLTensor {
240250
pub(super) fn from_tensor<'a>(tensor: &'a Tensor, flatten: bool) -> Self {
241251
assert!(!flatten || tensor.is_contiguous());
@@ -299,12 +309,6 @@ impl DataType {
299309
}
300310
}
301311

302-
const DTYPE_FLOAT32: DataType = DataType {
303-
code: DLDataTypeCode_kDLFloat as usize,
304-
bits: 32,
305-
lanes: 1,
306-
};
307-
308312
impl<'a> From<&'a DataType> for DLDataType {
309313
fn from(dtype: &'a DataType) -> Self {
310314
Self {
@@ -315,6 +319,22 @@ impl<'a> From<&'a DataType> for DLDataType {
315319
}
316320
}
317321

322+
macro_rules! make_dtype_const {
323+
($name: ident, $code: ident, $bits: expr, $lanes: expr) => {
324+
const $name: DataType = DataType {
325+
code: $code as usize,
326+
bits: $bits,
327+
lanes: $lanes,
328+
};
329+
}
330+
}
331+
332+
make_dtype_const!(DTYPE_INT32, DLDataTypeCode_kDLInt, 32, 1);
333+
make_dtype_const!(DTYPE_UINT32, DLDataTypeCode_kDLUInt, 32, 1);
334+
make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1);
335+
make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1);
336+
make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1);
337+
318338
impl Default for DLContext {
319339
fn default() -> Self {
320340
DLContext {

rust/src/runtime/module.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ pub trait Module {
99
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc>;
1010
}
1111

12-
pub struct SystemLibModule {}
12+
pub struct SystemLibModule;
1313

1414
lazy_static! {
1515
static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, BackendPackedCFunc>> =

rust/src/runtime/packed_func.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,18 @@ macro_rules! impl_boxed_ret_value {
200200
};
201201
}
202202

203+
impl_prim_ret_value!(i8, 0);
204+
impl_prim_ret_value!(u8, 1);
205+
impl_prim_ret_value!(i16, 0);
206+
impl_prim_ret_value!(u16, 1);
203207
impl_prim_ret_value!(i32, 0);
204208
impl_prim_ret_value!(u32, 1);
205209
impl_prim_ret_value!(f32, 2);
210+
impl_prim_ret_value!(i64, 0);
211+
impl_prim_ret_value!(u64, 1);
212+
impl_prim_ret_value!(f64, 2);
213+
impl_prim_ret_value!(isize, 0);
214+
impl_prim_ret_value!(usize, 1);
206215
impl_boxed_ret_value!(String, 11);
207216

208217
// @see `WrapPackedFunc` in `llvm_module.cc`.

0 commit comments

Comments
 (0)