1
1
use std:: { cmp, collections:: HashMap , convert:: TryFrom , iter:: FromIterator , mem, str} ;
2
2
3
+ use failure:: Error ;
3
4
use nom:: { alpha1, digit1, le_i32, le_i64, le_u16, le_u32, le_u64, le_u8, types:: CompleteStr } ;
4
5
use serde;
5
6
use serde_json;
6
-
7
- use crate :: {
8
- errors:: { Error , ErrorKind , Result } ,
9
- Module , Storage , Tensor ,
10
- } ;
11
7
use tvm_common:: {
12
8
array:: { DataType , TVMContext } ,
13
9
ffi:: { DLDataTypeCode_kDLFloat , DLDataTypeCode_kDLInt , DLDataTypeCode_kDLUInt , DLTensor } ,
14
10
TVMArgValue ,
15
11
} ;
16
12
13
+ use crate :: { errors:: GraphFormatError , Module , Storage , Tensor } ;
14
+
17
15
// @see `kTVMNDArrayMagic` in `ndarray.h`
18
16
const _NDARRAY_MAGIC: u64 = 0xDD5E40F096B4A13F ;
19
17
// @see `kTVMNDArrayListMagic` in `graph_runtime.h`
@@ -44,28 +42,26 @@ pub struct Entry {
44
42
}
45
43
46
44
impl Graph {
47
- fn entry_index ( & self , entry : & Entry ) -> Result < usize > {
45
+ fn entry_index ( & self , entry : & Entry ) -> Result < usize , GraphFormatError > {
48
46
self . node_row_ptr
49
47
. as_ref ( )
50
48
. map ( |nrp| nrp[ entry. id ] + entry. index )
51
- . ok_or ( "Missing node_row_ptr." . into ( ) )
49
+ . ok_or_else ( || GraphFormatError :: MissingField ( " node_row_ptr" ) )
52
50
}
53
51
54
52
/// Attempt to deserialize a JSON attribute to a type `T`.
55
- fn get_attr < T : serde:: de:: DeserializeOwned > ( & self , attr : & str ) -> Result < T > {
53
+ fn get_attr < T : serde:: de:: DeserializeOwned > ( & self , attr : & str ) -> Result < T , GraphFormatError > {
56
54
Ok ( serde_json:: from_value :: < T > (
57
55
self . attrs
58
56
. as_ref ( )
59
- . ok_or ( ErrorKind :: GraphFormatError (
60
- "Missing graph attrs" . to_string ( ) ,
61
- ) ) ?
57
+ . ok_or ( GraphFormatError :: MissingField ( "attrs" ) ) ?
62
58
. get ( attr)
63
- . ok_or ( ErrorKind :: GraphFormatError ( format ! (
64
- "Missing {} attr" ,
65
- attr
66
- ) ) ) ?
59
+ . ok_or_else ( || {
60
+ GraphFormatError :: MissingAttr ( "graph" . to_string ( ) , attr. to_string ( ) )
61
+ } ) ?
67
62
. to_owned ( ) ,
68
- ) ?)
63
+ )
64
+ . map_err ( |err| GraphFormatError :: Parse ( err. into ( ) ) ) ?)
69
65
}
70
66
}
71
67
@@ -84,47 +80,39 @@ struct NodeAttrs {
84
80
flatten_data : bool ,
85
81
}
86
82
83
+ macro_rules! get_node_attr {
84
+ ( $node: expr, $attrs: ident, $attr: literal) => {
85
+ $attrs
86
+ . get( $attr)
87
+ . ok_or_else( || GraphFormatError :: MissingAttr ( $node. to_owned( ) , $attr. to_owned( ) ) )
88
+ } ;
89
+ }
90
+
87
91
impl Node {
88
- fn parse_attrs ( & self ) -> Result < NodeAttrs > {
92
+ fn parse_attrs ( & self ) -> Result < NodeAttrs , Error > {
89
93
let attrs = self
90
94
. attrs
91
95
. as_ref ( )
92
- . ok_or ( format ! ( "Missing node.attrs for `{}`" , self . name) ) ?;
93
- let func_name = attrs
94
- . get ( "func_name" )
95
- . ok_or ( format ! ( "Node `{}` is missing attrs.func_name" , self . name) ) ?
96
- . to_string ( ) ;
97
- let num_outputs = attrs
98
- . get ( "num_outputs" )
99
- . ok_or ( format ! ( "Node `{}` is missing attrs.num_outputs" , self . name) ) ?
100
- . parse :: < usize > ( ) ?;
101
- let flatten_data = attrs
102
- . get ( "flatten_data" )
103
- . ok_or ( format ! (
104
- "Node `{}` is missing attrs.flatten_data" ,
105
- self . name
106
- ) ) ?
107
- . parse :: < u8 > ( ) ?
108
- == 1 ;
96
+ . ok_or_else ( || GraphFormatError :: MissingAttr ( self . name . clone ( ) , "attrs" . to_owned ( ) ) ) ?;
109
97
Ok ( NodeAttrs {
110
- func_name,
111
- num_outputs,
112
- flatten_data,
98
+ func_name : get_node_attr ! ( self . name , attrs , "func_name" ) ? . to_owned ( ) ,
99
+ num_outputs : get_node_attr ! ( self . name , attrs , "num_outputs" ) ? . parse :: < usize > ( ) ? ,
100
+ flatten_data : get_node_attr ! ( self . name , attrs , "flatten_data" ) ? . parse :: < u8 > ( ) ? == 1 ,
113
101
} )
114
102
}
115
103
}
116
104
117
105
impl < ' a > TryFrom < & ' a String > for Graph {
118
106
type Error = Error ;
119
- fn try_from ( graph_json : & String ) -> Result < Self > {
107
+ fn try_from ( graph_json : & String ) -> Result < Self , self :: Error > {
120
108
let graph = serde_json:: from_str ( graph_json) ?;
121
109
Ok ( graph)
122
110
}
123
111
}
124
112
125
113
impl < ' a > TryFrom < & ' a str > for Graph {
126
114
type Error = Error ;
127
- fn try_from ( graph_json : & ' a str ) -> Result < Self > {
115
+ fn try_from ( graph_json : & ' a str ) -> Result < Self , Self :: Error > {
128
116
let graph = serde_json:: from_str ( graph_json) ?;
129
117
Ok ( graph)
130
118
}
@@ -164,7 +152,7 @@ pub struct GraphExecutor<'m, 't> {
164
152
unsafe impl < ' m , ' t > Send for GraphExecutor < ' m , ' t > { }
165
153
166
154
impl < ' m , ' t > GraphExecutor < ' m , ' t > {
167
- pub fn new < M : ' m + Module > ( graph : Graph , lib : & ' m M ) -> Result < Self > {
155
+ pub fn new < M : ' m + Module > ( graph : Graph , lib : & ' m M ) -> Result < Self , Error > {
168
156
let tensors = Self :: setup_storages ( & graph) ?;
169
157
Ok ( GraphExecutor {
170
158
op_execs : Self :: setup_op_execs ( & graph, lib, & tensors) ?,
@@ -181,7 +169,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
181
169
}
182
170
183
171
/// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output.
184
- fn setup_storages < ' a > ( graph : & ' a Graph ) -> Result < Vec < Tensor < ' t > > > {
172
+ fn setup_storages < ' a > ( graph : & ' a Graph ) -> Result < Vec < Tensor < ' t > > , Error > {
185
173
let storage_ids = graph. get_attr :: < ( String , Vec < usize > ) > ( "storage_id" ) ?. 1 ;
186
174
let shapes = graph. get_attr :: < ( String , Vec < Vec < i64 > > ) > ( "shape" ) ?. 1 ;
187
175
let dtypes = graph
@@ -192,13 +180,10 @@ impl<'m, 't> GraphExecutor<'m, 't> {
192
180
if let Ok ( ( _, dtype) ) = tvm_str_to_type ( CompleteStr ( dltype) ) {
193
181
Ok ( dtype)
194
182
} else {
195
- Err ( ErrorKind :: GraphFormatError (
196
- format ! ( "Invalid dltype: {}" , dltype) . to_string ( ) ,
197
- )
198
- . into ( ) )
183
+ Err ( GraphFormatError :: InvalidDLType ( dltype. to_string ( ) ) )
199
184
}
200
185
} )
201
- . collect :: < Result < Vec < DataType > > > ( ) ?;
186
+ . collect :: < Result < Vec < DataType > , GraphFormatError > > ( ) ?;
202
187
203
188
let align = dtypes. iter ( ) . map ( |dtype| dtype. bits ( ) as usize ) . max ( ) ;
204
189
let mut storage_num_bytes = vec ! [ 0usize ; * storage_ids. iter( ) . max( ) . unwrap_or( & 1 ) + 1 ] ;
@@ -211,7 +196,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
211
196
let mut storages: Vec < Storage > = storage_num_bytes
212
197
. into_iter ( )
213
198
. map ( |nbytes| Storage :: new ( nbytes, align) )
214
- . collect :: < Result < Vec < Storage > > > ( ) ?;
199
+ . collect :: < Result < Vec < Storage > , Error > > ( ) ?;
215
200
216
201
let tensors = izip ! ( storage_ids, shapes, dtypes)
217
202
. map ( |( storage_id, shape, dtype) | {
@@ -236,7 +221,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
236
221
graph : & Graph ,
237
222
lib : & ' m M ,
238
223
tensors : & Vec < Tensor < ' t > > ,
239
- ) -> Result < Vec < Box < Fn ( ) + ' m > > > {
224
+ ) -> Result < Vec < Box < Fn ( ) + ' m > > , Error > {
240
225
ensure ! ( graph. node_row_ptr. is_some( ) , "Missing node_row_ptr." ) ;
241
226
let node_row_ptr = graph. node_row_ptr . as_ref ( ) . unwrap ( ) ;
242
227
@@ -254,9 +239,10 @@ impl<'m, 't> GraphExecutor<'m, 't> {
254
239
continue ;
255
240
}
256
241
257
- let func = lib
258
- . get_function ( & attrs. func_name )
259
- . ok_or ( format ! ( "Missing function {}" , attrs. func_name) ) ?;
242
+ let func = lib. get_function ( & attrs. func_name ) . ok_or ( format_err ! (
243
+ "Library is missing function {}" ,
244
+ attrs. func_name
245
+ ) ) ?;
260
246
let arg_indices = node
261
247
. inputs
262
248
. iter ( )
@@ -272,7 +258,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
272
258
DLTensor :: from ( tensor)
273
259
} )
274
260
} )
275
- . collect :: < Result < Vec < DLTensor > > > ( )
261
+ . collect :: < Result < Vec < DLTensor > , Error > > ( )
276
262
. unwrap ( ) ;
277
263
let op: Box < Fn ( ) > = box move || {
278
264
let args = dl_tensors
@@ -436,17 +422,15 @@ named!(
436
422
) ;
437
423
438
424
/// Loads a param dict saved using `nnvm.compiler.save_param_dict`.
439
- pub fn load_param_dict ( bytes : & [ u8 ] ) -> Result < HashMap < String , Tensor > > {
425
+ pub fn load_param_dict ( bytes : & [ u8 ] ) -> Result < HashMap < String , Tensor > , GraphFormatError > {
440
426
if let Ok ( ( remaining_bytes, param_dict) ) = parse_param_dict ( bytes) {
441
- if remaining_bytes. len ( ) > 0 {
442
- bail ! ( ErrorKind :: LoadGraphParamsError ( "extra input" . to_string( ) ) )
443
- } else {
427
+ if remaining_bytes. len ( ) == 0 {
444
428
Ok ( param_dict)
429
+ } else {
430
+ Err ( GraphFormatError :: Params )
445
431
}
446
432
} else {
447
- bail ! ( ErrorKind :: LoadGraphParamsError (
448
- "invalid parameters file" . to_string( )
449
- ) )
433
+ Err ( GraphFormatError :: Params )
450
434
}
451
435
}
452
436
0 commit comments