Skip to content

Commit 4bbe0da

Browse files
authored
add more dbtype support for tensor (#468)
1 parent 75ed900 commit 4bbe0da

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

src/tensor/npy.rs

+16-8
Original file line numberDiff line numberDiff line change
@@ -129,15 +129,23 @@ impl Header {
129129
if descr.starts_with('>') {
130130
return Err(TchError::FileFormat(format!("little-endian descr {}", descr)));
131131
}
132+
// the only supported types in tensor are:
133+
// float64, float32, float16,
134+
// complex64, complex128,
135+
// int64, int32, int16, int8,
136+
// uint8, and bool.
132137
match descr.trim_matches(|c: char| c == '=' || c == '<' || c == '|') {
133-
"f2" => Kind::Half,
134-
"f4" => Kind::Float,
135-
"f8" => Kind::Double,
136-
"i4" => Kind::Int,
137-
"i8" => Kind::Int64,
138-
"i2" => Kind::Int16,
139-
"i1" => Kind::Int8,
140-
"u1" => Kind::Uint8,
138+
"e" | "f2" => Kind::Half,
139+
"f" | "f4" => Kind::Float,
140+
"d" | "f8" => Kind::Double,
141+
"i" | "i4" => Kind::Int,
142+
"q" | "i8" => Kind::Int64,
143+
"h" | "i2" => Kind::Int16,
144+
"b" | "i1" => Kind::Int8,
145+
"B" | "u1" => Kind::Uint8,
146+
"?" | "b1" => Kind::Bool,
147+
"F" | "F4" => Kind::ComplexFloat,
148+
"D" | "F8" => Kind::ComplexDouble,
141149
descr => {
142150
return Err(TchError::FileFormat(format!("unrecognized descr {}", descr)))
143151
}

0 commit comments

Comments
 (0)