@@ -221,21 +221,31 @@ impl<'a> Tensor<'a> {
221
221
}
222
222
}
223
223
224
- impl < ' a , ' t > TryFrom < & ' a Tensor < ' t > > for ndarray:: ArrayD < f32 > {
225
- type Error = Error ;
226
- fn try_from ( tensor : & ' a Tensor ) -> Result < ndarray:: ArrayD < f32 > > {
227
- ensure ! (
228
- tensor. dtype == DTYPE_FLOAT32 ,
229
- "Cannot convert Tensor with dtype {:?} to ndarray" ,
230
- tensor. dtype
231
- ) ;
232
- Ok ( ndarray:: Array :: from_shape_vec (
233
- tensor. shape . iter ( ) . map ( |s| * s as usize ) . collect :: < Vec < usize > > ( ) ,
234
- tensor. to_vec :: < f32 > ( ) ,
235
- ) ?)
236
- }
224
+ /// Conversions to `ndarray::Array` from `Tensor`, if the types match.
225
+ macro_rules! impl_ndarray_try_from_tensor {
226
+ ( $type: ty, $dtype: expr) => {
227
+ impl <' a, ' t> TryFrom <& ' a Tensor <' t>> for ndarray:: ArrayD <$type> {
228
+ type Error = Error ;
229
+ fn try_from( tensor: & ' a Tensor ) -> Result <ndarray:: ArrayD <$type>> {
230
+ ensure!(
231
+ tensor. dtype == $dtype,
232
+ "Cannot convert Tensor with dtype {:?} to ndarray" ,
233
+ tensor. dtype
234
+ ) ;
235
+ Ok ( ndarray:: Array :: from_shape_vec(
236
+ tensor. shape. iter( ) . map( |s| * s as usize ) . collect:: <Vec <usize >>( ) ,
237
+ tensor. to_vec:: <$type>( ) ,
238
+ ) ?)
239
+ }
240
+ }
241
+ } ;
237
242
}
238
243
244
+ impl_ndarray_try_from_tensor ! ( i32 , DTYPE_INT32 ) ;
245
+ impl_ndarray_try_from_tensor ! ( u32 , DTYPE_UINT32 ) ;
246
+ impl_ndarray_try_from_tensor ! ( f32 , DTYPE_FLOAT32 ) ;
247
+ impl_ndarray_try_from_tensor ! ( f64 , DTYPE_FLOAT64 ) ;
248
+
239
249
impl DLTensor {
240
250
pub ( super ) fn from_tensor < ' a > ( tensor : & ' a Tensor , flatten : bool ) -> Self {
241
251
assert ! ( !flatten || tensor. is_contiguous( ) ) ;
@@ -299,12 +309,6 @@ impl DataType {
299
309
}
300
310
}
301
311
302
- const DTYPE_FLOAT32 : DataType = DataType {
303
- code : DLDataTypeCode_kDLFloat as usize ,
304
- bits : 32 ,
305
- lanes : 1 ,
306
- } ;
307
-
308
312
impl < ' a > From < & ' a DataType > for DLDataType {
309
313
fn from ( dtype : & ' a DataType ) -> Self {
310
314
Self {
@@ -315,6 +319,22 @@ impl<'a> From<&'a DataType> for DLDataType {
315
319
}
316
320
}
317
321
322
+ macro_rules! make_dtype_const {
323
+ ( $name: ident, $code: ident, $bits: expr, $lanes: expr) => {
324
+ const $name: DataType = DataType {
325
+ code: $code as usize ,
326
+ bits: $bits,
327
+ lanes: $lanes,
328
+ } ;
329
+ }
330
+ }
331
+
332
+ make_dtype_const ! ( DTYPE_INT32 , DLDataTypeCode_kDLInt , 32 , 1 ) ;
333
+ make_dtype_const ! ( DTYPE_UINT32 , DLDataTypeCode_kDLUInt , 32 , 1 ) ;
334
+ make_dtype_const ! ( DTYPE_FLOAT16 , DLDataTypeCode_kDLFloat , 16 , 1 ) ;
335
+ make_dtype_const ! ( DTYPE_FLOAT32 , DLDataTypeCode_kDLFloat , 32 , 1 ) ;
336
+ make_dtype_const ! ( DTYPE_FLOAT64 , DLDataTypeCode_kDLFloat , 64 , 1 ) ;
337
+
318
338
impl Default for DLContext {
319
339
fn default ( ) -> Self {
320
340
DLContext {
0 commit comments