@@ -129,15 +129,23 @@ impl Header {
129
129
if descr. starts_with ( '>' ) {
130
130
return Err ( TchError :: FileFormat ( format ! ( "little-endian descr {}" , descr) ) ) ;
131
131
}
132
+ // the only supported types in tensor are:
133
+ // float64, float32, float16,
134
+ // complex64, complex128,
135
+ // int64, int32, int16, int8,
136
+ // uint8, and bool.
132
137
match descr. trim_matches ( |c : char | c == '=' || c == '<' || c == '|' ) {
133
- "f2" => Kind :: Half ,
134
- "f4" => Kind :: Float ,
135
- "f8" => Kind :: Double ,
136
- "i4" => Kind :: Int ,
137
- "i8" => Kind :: Int64 ,
138
- "i2" => Kind :: Int16 ,
139
- "i1" => Kind :: Int8 ,
140
- "u1" => Kind :: Uint8 ,
138
+ "e" | "f2" => Kind :: Half ,
139
+ "f" | "f4" => Kind :: Float ,
140
+ "d" | "f8" => Kind :: Double ,
141
+ "i" | "i4" => Kind :: Int ,
142
+ "q" | "i8" => Kind :: Int64 ,
143
+ "h" | "i2" => Kind :: Int16 ,
144
+ "b" | "i1" => Kind :: Int8 ,
145
+ "B" | "u1" => Kind :: Uint8 ,
146
+ "?" | "b1" => Kind :: Bool ,
147
+ "F" | "F4" => Kind :: ComplexFloat ,
148
+ "D" | "F8" => Kind :: ComplexDouble ,
141
149
descr => {
142
150
return Err ( TchError :: FileFormat ( format ! ( "unrecognized descr {}" , descr) ) )
143
151
}
0 commit comments