Skip to content

Commit e2970b2

Browse files
ehsanmoknhynes
authored andcommitted
[RUST][FRONTEND] Add rust frontend v0.1 (#2292)
1 parent 18b2eba commit e2970b2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

80 files changed

+5642
-2275
lines changed

rust/.rustfmt.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
max_width = 100
22
hard_tabs = false
3-
tab_spaces = 2
3+
tab_spaces = 4
44
newline_style = "Auto"
55
use_small_heuristics = "Default"
66
indent_style = "Block"
@@ -38,7 +38,7 @@ trailing_comma = "Vertical"
3838
match_block_trailing_comma = false
3939
blank_lines_upper_bound = 1
4040
blank_lines_lower_bound = 0
41-
edition = "2015"
41+
edition = "2018"
4242
merge_derives = true
4343
use_try_shorthand = true
4444
use_field_init_shorthand = false
@@ -50,8 +50,8 @@ unstable_features = false
5050
disable_all_formatting = false
5151
skip_children = false
5252
hide_parse_errors = false
53-
error_on_line_overflow = false
54-
error_on_unformatted = false
53+
error_on_line_overflow = true
54+
error_on_unformatted = true
5555
report_todo = "Never"
5656
report_fixme = "Never"
5757
ignore = []

rust/Cargo.toml

Lines changed: 11 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,11 @@
1-
[package]
2-
name = "tvm"
3-
version = "0.1.0"
4-
license = "Apache-2.0"
5-
description = "TVM Rust runtime"
6-
repository = "https://github.com/dmlc/tvm"
7-
readme = "README.md"
8-
keywords = ["tvm", "nnvm"]
9-
categories = ["api-bindings", "science"]
10-
authors = ["TVM Contributors"]
11-
12-
[features]
13-
default = ["nom/std"]
14-
sgx = ["nom/alloc"]
15-
16-
[dependencies]
17-
bounded-spsc-queue = "0.4.0"
18-
error-chain = { version = "0.12.0", default-features = false }
19-
itertools = "0.7.8"
20-
lazy_static = "1.1.0"
21-
ndarray = "0.11.2"
22-
nom = {version = "4.0.0", default-features = false }
23-
serde = "1.0.59"
24-
serde_derive = "1.0.79"
25-
serde_json = "1.0.17"
26-
27-
[target.'cfg(not(target_env = "sgx"))'.dependencies]
28-
num_cpus = "1.8.0"
1+
[workspace]
2+
members = [
3+
"common",
4+
"runtime",
5+
"runtime/tests/test_tvm_basic",
6+
"runtime/tests/test_nnvm",
7+
"frontend",
8+
"frontend/tests/basics",
9+
"frontend/tests/callback",
10+
"frontend/examples/resnet"
11+
]

rust/common/.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
target
2+
**/*.rs.bk
3+
Cargo.lock
4+
/tvm-sys/src/bindgen.rs

rust/common/Cargo.toml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
[package]
2+
name = "tvm-common"
3+
version = "0.1.0"
4+
authors = ["TVM Contributors"]
5+
license = "Apache-2.0"
6+
7+
[features]
8+
runtime = []
9+
frontend = ["tvm-sys"]
10+
11+
[dependencies]
12+
error-chain = { version = "0.12.0", default-features = false }
13+
tvm-sys = { version = "0.1.0", path = "tvm-sys", optional = true }
File renamed without changes.

rust/common/src/errors.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//! Error types for `TVMArgValue` and `TVMRetValue` conversions.
2+
3+
error_chain! {
4+
errors {
5+
TryFromTVMArgValueError(expected: String, actual: String) {
6+
description("mismatched types while converting from TVMArgValue")
7+
display("expected `{}` but given `{}`", expected, actual)
8+
}
9+
10+
TryFromTVMRetValueError(expected: String, actual: String) {
11+
description("mismatched types while downcasting TVMRetValue")
12+
display("invalid downcast: expected `{}` but given `{}`", expected, actual)
13+
}
14+
}
15+
}

rust/common/src/lib.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
//! This crate contains the refactored basic components required
2+
//! for `runtime` and `frontend` TVM crates.
3+
4+
#![crate_name = "tvm_common"]
5+
#![recursion_limit = "1024"]
6+
#![allow(non_camel_case_types, unused_imports)]
7+
#![feature(box_syntax, try_from)]
8+
9+
#[macro_use]
10+
extern crate error_chain;
11+
12+
/// Unified ffi module for both runtime and frontend crates.
13+
pub mod ffi {
14+
#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, unused)]
15+
16+
#[cfg(feature = "frontend")]
17+
pub extern crate tvm_sys as ts;
18+
19+
#[cfg(feature = "runtime")]
20+
pub mod runtime {
21+
use std::os::raw::{c_char, c_int, c_void};
22+
23+
include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs"));
24+
25+
pub type BackendPackedCFunc = extern "C" fn(
26+
args: *const TVMValue,
27+
type_codes: *const c_int,
28+
num_args: c_int,
29+
) -> c_int;
30+
}
31+
}
32+
33+
pub mod errors;
34+
pub mod ty;
35+
pub mod value;
36+
37+
pub use errors::*;
38+
pub use ty::TVMTypeCode;
39+
pub use value::{TVMArgValue, TVMRetValue, TVMValue};

rust/common/src/ty.rs

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
//! This module containes `TVMTypeCode` and `TVMType` with some conversion methods.
2+
//!
3+
//! # Example
4+
//!
5+
//! ```
6+
//! let dtype = TVMType::from("float");
7+
//! println!("dtype is: {}", dtype);
8+
//! ```
9+
10+
use std::{
11+
ffi::{CStr, CString},
12+
fmt::{self, Display, Formatter},
13+
};
14+
15+
/// TVM type codes.
16+
#[repr(u32)]
17+
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
18+
pub enum TVMTypeCode {
19+
kDLInt = 0,
20+
kDLUInt = 1,
21+
kDLFloat = 2,
22+
kHandle = 3,
23+
kNull = 4,
24+
kTVMType = 5,
25+
kTVMContext = 6,
26+
kArrayHandle = 7,
27+
kNodeHandle = 8,
28+
kModuleHandle = 9,
29+
kFuncHandle = 10,
30+
kStr = 11,
31+
kBytes = 12,
32+
kNDArrayContainer = 13,
33+
}
34+
35+
impl Default for TVMTypeCode {
36+
fn default() -> Self {
37+
TVMTypeCode::kDLInt
38+
}
39+
}
40+
41+
impl From<TVMTypeCode> for i64 {
42+
fn from(arg: TVMTypeCode) -> i64 {
43+
match arg {
44+
TVMTypeCode::kDLInt => 0,
45+
TVMTypeCode::kDLUInt => 1,
46+
TVMTypeCode::kDLFloat => 2,
47+
TVMTypeCode::kHandle => 3,
48+
TVMTypeCode::kNull => 4,
49+
TVMTypeCode::kTVMType => 5,
50+
TVMTypeCode::kTVMContext => 6,
51+
TVMTypeCode::kArrayHandle => 7,
52+
TVMTypeCode::kNodeHandle => 8,
53+
TVMTypeCode::kModuleHandle => 9,
54+
TVMTypeCode::kFuncHandle => 10,
55+
TVMTypeCode::kStr => 11,
56+
TVMTypeCode::kBytes => 12,
57+
TVMTypeCode::kNDArrayContainer => 13,
58+
}
59+
}
60+
}
61+
62+
impl Into<TVMTypeCode> for i64 {
63+
fn into(self) -> TVMTypeCode {
64+
match self {
65+
0 => TVMTypeCode::kDLInt,
66+
1 => TVMTypeCode::kDLUInt,
67+
2 => TVMTypeCode::kDLFloat,
68+
3 => TVMTypeCode::kHandle,
69+
4 => TVMTypeCode::kNull,
70+
5 => TVMTypeCode::kTVMType,
71+
6 => TVMTypeCode::kTVMContext,
72+
7 => TVMTypeCode::kArrayHandle,
73+
8 => TVMTypeCode::kNodeHandle,
74+
9 => TVMTypeCode::kModuleHandle,
75+
10 => TVMTypeCode::kFuncHandle,
76+
11 => TVMTypeCode::kStr,
77+
12 => TVMTypeCode::kBytes,
78+
13 => TVMTypeCode::kNDArrayContainer,
79+
_ => unreachable!(),
80+
}
81+
}
82+
}
83+
84+
impl Display for TVMTypeCode {
85+
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
86+
write!(
87+
f,
88+
"{}",
89+
match self {
90+
TVMTypeCode::kDLInt => "int",
91+
TVMTypeCode::kDLUInt => "uint",
92+
TVMTypeCode::kDLFloat => "float",
93+
TVMTypeCode::kHandle => "handle",
94+
TVMTypeCode::kNull => "null",
95+
TVMTypeCode::kTVMType => "TVM type",
96+
TVMTypeCode::kTVMContext => "TVM context",
97+
TVMTypeCode::kArrayHandle => "Array handle",
98+
TVMTypeCode::kNodeHandle => "Node handle",
99+
TVMTypeCode::kModuleHandle => "Module handle",
100+
TVMTypeCode::kFuncHandle => "Function handle",
101+
TVMTypeCode::kStr => "string",
102+
TVMTypeCode::kBytes => "bytes",
103+
TVMTypeCode::kNDArrayContainer => "ndarray container",
104+
}
105+
)
106+
}
107+
}
108+
109+
macro_rules! impl_prim_type {
110+
($type:ty, $variant:ident) => {
111+
impl<'a> From<&'a $type> for TVMTypeCode {
112+
fn from(_arg: &$type) -> Self {
113+
TVMTypeCode::$variant
114+
}
115+
}
116+
117+
impl<'a> From<&'a mut $type> for TVMTypeCode {
118+
fn from(_arg: &mut $type) -> Self {
119+
TVMTypeCode::$variant
120+
}
121+
}
122+
};
123+
}
124+
125+
impl_prim_type!(usize, kDLInt);
126+
impl_prim_type!(i64, kDLInt);
127+
impl_prim_type!(i32, kDLInt);
128+
impl_prim_type!(i16, kDLInt);
129+
impl_prim_type!(i8, kDLInt);
130+
131+
impl_prim_type!(u64, kDLUInt);
132+
impl_prim_type!(u32, kDLUInt);
133+
impl_prim_type!(u16, kDLUInt);
134+
impl_prim_type!(u8, kDLUInt);
135+
136+
impl_prim_type!(f64, kDLFloat);
137+
impl_prim_type!(f32, kDLFloat);
138+
139+
impl_prim_type!(str, kStr);
140+
impl_prim_type!(CStr, kStr);
141+
impl_prim_type!(String, kStr);
142+
impl_prim_type!(CString, kStr);
143+
144+
impl_prim_type!([u8], kBytes);

0 commit comments

Comments
 (0)