Skip to content

Commit 717d025

Browse files
committed
Use failure in tvm_runtime
1 parent 1412c2f commit 717d025

File tree

10 files changed

+114
-112
lines changed

10 files changed

+114
-112
lines changed

rust/runtime/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ sgx = ["nom/alloc"]
1515

1616
[dependencies]
1717
bounded-spsc-queue = "0.4.0"
18-
error-chain = { version = "0.12.0", default-features = false }
18+
failure = "0.1.5"
1919
itertools = "0.7.8"
2020
lazy_static = "1.1.0"
2121
ndarray="0.12.1"

rust/runtime/src/allocator.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use alloc::alloc::{self, Layout};
33
#[cfg(not(target_env = "sgx"))]
44
use std::alloc::{self, Layout};
55

6-
use crate::errors::*;
6+
use failure::Error;
77

88
const DEFAULT_ALIGN_BYTES: usize = 4;
99

@@ -15,7 +15,7 @@ pub struct Allocation {
1515

1616
impl Allocation {
1717
/// Allocates a chunk of memory of `size` bytes with optional alignment.
18-
pub fn new(size: usize, align: Option<usize>) -> Result<Self> {
18+
pub fn new(size: usize, align: Option<usize>) -> Result<Self, Error> {
1919
let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES);
2020
let layout = Layout::from_size_align(size, alignment)?;
2121
let ptr = unsafe { alloc::alloc(layout.clone()) };

rust/runtime/src/array.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::{convert::TryFrom, mem, os::raw::c_void, ptr, slice};
22

3+
use failure::Error;
34
use ndarray;
45
use tvm_common::{
56
array::{DataType, TVMContext},
@@ -9,7 +10,7 @@ use tvm_common::{
910
},
1011
};
1112

12-
use crate::{allocator::Allocation, errors::*};
13+
use crate::allocator::Allocation;
1314

1415
/// A `Storage` is a container which holds `Tensor` data.
1516
#[derive(PartialEq)]
@@ -22,7 +23,7 @@ pub enum Storage<'a> {
2223
}
2324

2425
impl<'a> Storage<'a> {
25-
pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>> {
26+
pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>, Error> {
2627
Ok(Storage::Owned(Allocation::new(size, align)?))
2728
}
2829

@@ -258,7 +259,7 @@ macro_rules! impl_ndarray_try_from_tensor {
258259
($type:ty, $dtype:expr) => {
259260
impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> {
260261
type Error = Error;
261-
fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>> {
262+
fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>, Error> {
262263
ensure!(
263264
tensor.dtype == $dtype,
264265
"Cannot convert Tensor with dtype {:?} to ndarray",

rust/runtime/src/errors.rs

Lines changed: 47 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,49 @@
1-
#[cfg(target_env = "sgx")]
2-
use alloc::alloc;
3-
#[cfg(not(target_env = "sgx"))]
4-
use std::alloc;
5-
use std::num;
6-
7-
use ndarray;
8-
use serde_json;
9-
10-
error_chain! {
11-
errors {
12-
GraphFormatError(msg: String) {
13-
description("unable to load graph")
14-
display("could not load graph json: {}", msg)
15-
}
16-
17-
LoadGraphParamsError(msg: String) {
18-
description("unable to load graph params")
19-
display("could not load graph params: {}", msg)
20-
}
21-
}
22-
foreign_links {
23-
Alloc(alloc::AllocErr);
24-
GraphDeserialize(serde_json::Error);
25-
ParseInt(num::ParseIntError);
26-
ShapeError(ndarray::ShapeError);
27-
CommonError(tvm_common::errors::Error);
28-
}
1+
#[derive(Debug, Fail)]
2+
pub enum GraphFormatError {
3+
#[fail(display = "Could not parse graph json")]
4+
Parse(#[fail(cause)] failure::Error),
5+
#[fail(display = "Could not parse graph params")]
6+
Params,
7+
#[fail(display = "{} is missing attr: {}", 0, 1)]
8+
MissingAttr(String, String),
9+
#[fail(display = "Missing field: {}", 0)]
10+
MissingField(&'static str),
11+
#[fail(display = "Invalid DLType: {}", 0)]
12+
InvalidDLType(String),
2913
}
3014

31-
impl From<alloc::LayoutErr> for Error {
32-
fn from(_err: alloc::LayoutErr) -> Error {
33-
Error::from_kind(ErrorKind::Msg("Layout error".to_string()))
34-
}
35-
}
15+
// #[cfg(target_env = "sgx")]
16+
// use alloc::alloc;
17+
// #[cfg(not(target_env = "sgx"))]
18+
// use std::alloc;
19+
// use std::num;
20+
//
21+
// use ndarray;
22+
// use serde_json;
23+
24+
// error_chain! {
25+
// errors {
26+
// GraphFormatError(msg: String) {
27+
// description("unable to load graph")
28+
// display("could not load graph json: {}", msg)
29+
// }
30+
//
31+
// LoadGraphParamsError(msg: String) {
32+
// description("unable to load graph params")
33+
// display("could not load graph params: {}", msg)
34+
// }
35+
// }
36+
// foreign_links {
37+
// Alloc(alloc::AllocErr);
38+
// GraphDeserialize(serde_json::Error);
39+
// ParseInt(num::ParseIntError);
40+
// ShapeError(ndarray::ShapeError);
41+
// CommonError(tvm_common::errors::Error);
42+
// }
43+
// }
44+
//
45+
// impl From<alloc::LayoutErr> for Error {
46+
// fn from(_err: alloc::LayoutErr) -> Error {
47+
// Error::from_kind(ErrorKind::Msg("Layout error".to_string()))
48+
// }
49+
// }

rust/runtime/src/graph.rs

Lines changed: 43 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
11
use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str};
22

3+
use failure::Error;
34
use nom::{alpha1, digit1, le_i32, le_i64, le_u16, le_u32, le_u64, le_u8, types::CompleteStr};
45
use serde;
56
use serde_json;
6-
7-
use crate::{
8-
errors::{Error, ErrorKind, Result},
9-
Module, Storage, Tensor,
10-
};
117
use tvm_common::{
128
array::{DataType, TVMContext},
139
ffi::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, DLTensor},
1410
TVMArgValue,
1511
};
1612

13+
use crate::{errors::GraphFormatError, Module, Storage, Tensor};
14+
1715
// @see `kTVMNDArrayMagic` in `ndarray.h`
1816
const _NDARRAY_MAGIC: u64 = 0xDD5E40F096B4A13F;
1917
// @see `kTVMNDArrayListMagic` in `graph_runtime.h`
@@ -44,28 +42,26 @@ pub struct Entry {
4442
}
4543

4644
impl Graph {
47-
fn entry_index(&self, entry: &Entry) -> Result<usize> {
45+
fn entry_index(&self, entry: &Entry) -> Result<usize, GraphFormatError> {
4846
self.node_row_ptr
4947
.as_ref()
5048
.map(|nrp| nrp[entry.id] + entry.index)
51-
.ok_or("Missing node_row_ptr.".into())
49+
.ok_or_else(|| GraphFormatError::MissingField("node_row_ptr"))
5250
}
5351

5452
/// Attempt to deserialize a JSON attribute to a type `T`.
55-
fn get_attr<T: serde::de::DeserializeOwned>(&self, attr: &str) -> Result<T> {
53+
fn get_attr<T: serde::de::DeserializeOwned>(&self, attr: &str) -> Result<T, GraphFormatError> {
5654
Ok(serde_json::from_value::<T>(
5755
self.attrs
5856
.as_ref()
59-
.ok_or(ErrorKind::GraphFormatError(
60-
"Missing graph attrs".to_string(),
61-
))?
57+
.ok_or(GraphFormatError::MissingField("attrs"))?
6258
.get(attr)
63-
.ok_or(ErrorKind::GraphFormatError(format!(
64-
"Missing {} attr",
65-
attr
66-
)))?
59+
.ok_or_else(|| {
60+
GraphFormatError::MissingAttr("graph".to_string(), attr.to_string())
61+
})?
6762
.to_owned(),
68-
)?)
63+
)
64+
.map_err(|err| GraphFormatError::Parse(err.into()))?)
6965
}
7066
}
7167

@@ -84,47 +80,39 @@ struct NodeAttrs {
8480
flatten_data: bool,
8581
}
8682

83+
macro_rules! get_node_attr {
84+
($node:expr, $attrs:ident, $attr:literal) => {
85+
$attrs
86+
.get($attr)
87+
.ok_or_else(|| GraphFormatError::MissingAttr($node.to_owned(), $attr.to_owned()))
88+
};
89+
}
90+
8791
impl Node {
88-
fn parse_attrs(&self) -> Result<NodeAttrs> {
92+
fn parse_attrs(&self) -> Result<NodeAttrs, Error> {
8993
let attrs = self
9094
.attrs
9195
.as_ref()
92-
.ok_or(format!("Missing node.attrs for `{}`", self.name))?;
93-
let func_name = attrs
94-
.get("func_name")
95-
.ok_or(format!("Node `{}` is missing attrs.func_name", self.name))?
96-
.to_string();
97-
let num_outputs = attrs
98-
.get("num_outputs")
99-
.ok_or(format!("Node `{}` is missing attrs.num_outputs", self.name))?
100-
.parse::<usize>()?;
101-
let flatten_data = attrs
102-
.get("flatten_data")
103-
.ok_or(format!(
104-
"Node `{}` is missing attrs.flatten_data",
105-
self.name
106-
))?
107-
.parse::<u8>()?
108-
== 1;
96+
.ok_or_else(|| GraphFormatError::MissingAttr(self.name.clone(), "attrs".to_owned()))?;
10997
Ok(NodeAttrs {
110-
func_name,
111-
num_outputs,
112-
flatten_data,
98+
func_name: get_node_attr!(self.name, attrs, "func_name")?.to_owned(),
99+
num_outputs: get_node_attr!(self.name, attrs, "num_outputs")?.parse::<usize>()?,
100+
flatten_data: get_node_attr!(self.name, attrs, "flatten_data")?.parse::<u8>()? == 1,
113101
})
114102
}
115103
}
116104

117105
impl<'a> TryFrom<&'a String> for Graph {
118106
type Error = Error;
119-
fn try_from(graph_json: &String) -> Result<Self> {
107+
fn try_from(graph_json: &String) -> Result<Self, self::Error> {
120108
let graph = serde_json::from_str(graph_json)?;
121109
Ok(graph)
122110
}
123111
}
124112

125113
impl<'a> TryFrom<&'a str> for Graph {
126114
type Error = Error;
127-
fn try_from(graph_json: &'a str) -> Result<Self> {
115+
fn try_from(graph_json: &'a str) -> Result<Self, Self::Error> {
128116
let graph = serde_json::from_str(graph_json)?;
129117
Ok(graph)
130118
}
@@ -164,7 +152,7 @@ pub struct GraphExecutor<'m, 't> {
164152
unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {}
165153

166154
impl<'m, 't> GraphExecutor<'m, 't> {
167-
pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self> {
155+
pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self, Error> {
168156
let tensors = Self::setup_storages(&graph)?;
169157
Ok(GraphExecutor {
170158
op_execs: Self::setup_op_execs(&graph, lib, &tensors)?,
@@ -181,7 +169,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
181169
}
182170

183171
/// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output.
184-
fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>> {
172+
fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>, Error> {
185173
let storage_ids = graph.get_attr::<(String, Vec<usize>)>("storage_id")?.1;
186174
let shapes = graph.get_attr::<(String, Vec<Vec<i64>>)>("shape")?.1;
187175
let dtypes = graph
@@ -192,13 +180,10 @@ impl<'m, 't> GraphExecutor<'m, 't> {
192180
if let Ok((_, dtype)) = tvm_str_to_type(CompleteStr(dltype)) {
193181
Ok(dtype)
194182
} else {
195-
Err(ErrorKind::GraphFormatError(
196-
format!("Invalid dltype: {}", dltype).to_string(),
197-
)
198-
.into())
183+
Err(GraphFormatError::InvalidDLType(dltype.to_string()))
199184
}
200185
})
201-
.collect::<Result<Vec<DataType>>>()?;
186+
.collect::<Result<Vec<DataType>, GraphFormatError>>()?;
202187

203188
let align = dtypes.iter().map(|dtype| dtype.bits() as usize).max();
204189
let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1];
@@ -211,7 +196,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
211196
let mut storages: Vec<Storage> = storage_num_bytes
212197
.into_iter()
213198
.map(|nbytes| Storage::new(nbytes, align))
214-
.collect::<Result<Vec<Storage>>>()?;
199+
.collect::<Result<Vec<Storage>, Error>>()?;
215200

216201
let tensors = izip!(storage_ids, shapes, dtypes)
217202
.map(|(storage_id, shape, dtype)| {
@@ -236,7 +221,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
236221
graph: &Graph,
237222
lib: &'m M,
238223
tensors: &Vec<Tensor<'t>>,
239-
) -> Result<Vec<Box<Fn() + 'm>>> {
224+
) -> Result<Vec<Box<Fn() + 'm>>, Error> {
240225
ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr.");
241226
let node_row_ptr = graph.node_row_ptr.as_ref().unwrap();
242227

@@ -254,9 +239,10 @@ impl<'m, 't> GraphExecutor<'m, 't> {
254239
continue;
255240
}
256241

257-
let func = lib
258-
.get_function(&attrs.func_name)
259-
.ok_or(format!("Missing function {}", attrs.func_name))?;
242+
let func = lib.get_function(&attrs.func_name).ok_or(format_err!(
243+
"Library is missing function {}",
244+
attrs.func_name
245+
))?;
260246
let arg_indices = node
261247
.inputs
262248
.iter()
@@ -272,7 +258,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
272258
DLTensor::from(tensor)
273259
})
274260
})
275-
.collect::<Result<Vec<DLTensor>>>()
261+
.collect::<Result<Vec<DLTensor>, Error>>()
276262
.unwrap();
277263
let op: Box<Fn()> = box move || {
278264
let args = dl_tensors
@@ -436,17 +422,15 @@ named!(
436422
);
437423

438424
/// Loads a param dict saved using `nnvm.compiler.save_param_dict`.
439-
pub fn load_param_dict(bytes: &[u8]) -> Result<HashMap<String, Tensor>> {
425+
pub fn load_param_dict(bytes: &[u8]) -> Result<HashMap<String, Tensor>, GraphFormatError> {
440426
if let Ok((remaining_bytes, param_dict)) = parse_param_dict(bytes) {
441-
if remaining_bytes.len() > 0 {
442-
bail!(ErrorKind::LoadGraphParamsError("extra input".to_string()))
443-
} else {
427+
if remaining_bytes.len() == 0 {
444428
Ok(param_dict)
429+
} else {
430+
Err(GraphFormatError::Params)
445431
}
446432
} else {
447-
bail!(ErrorKind::LoadGraphParamsError(
448-
"invalid parameters file".to_string()
449-
))
433+
Err(GraphFormatError::Params)
450434
}
451435
}
452436

rust/runtime/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ extern crate bounded_spsc_queue;
2525
#[cfg(target_env = "sgx")]
2626
extern crate core;
2727
#[macro_use]
28-
extern crate error_chain;
28+
extern crate failure;
2929
#[macro_use]
3030
extern crate itertools;
3131
#[macro_use]

0 commit comments

Comments
 (0)