1
+ /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
+
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
14
+ ==============================================================================*/
15
+
16
+ #include < ctime>
17
+ #include < iostream>
18
+ #include < orc/Exceptions.hh>
19
+ #include < orc/OrcFile.hh>
20
+ #include < orc/Reader.hh>
21
+ #include < orc/Type.hh>
22
+
23
+ #include " orc/orc-config.hh"
24
+ #include " tensorflow/core/lib/io/buffered_inputstream.h"
25
+ #include " tensorflow_io/core/kernels/io_interface.h"
26
+ #include " tensorflow_io/core/kernels/io_stream.h"
27
+
28
+ namespace tensorflow {
29
+ namespace data {
30
+
31
+ class ORCReadable : public IOReadableInterface {
32
+ public:
33
+ ORCReadable (Env* env) : env_(env) {}
34
+ ~ORCReadable () {}
35
+ Status Init (const std::vector<string>& input,
36
+ const std::vector<string>& metadata, const void * memory_data,
37
+ const int64 memory_size) override {
38
+ if (input.size () > 1 ) {
39
+ return errors::InvalidArgument (" more than 1 filename is not supported" );
40
+ }
41
+ const string& filename = input[0 ];
42
+ // read packet data
43
+ orc::RowReaderOptions row_reader_opts;
44
+ orc::ReaderOptions reader_opts;
45
+ std::unique_ptr<orc::Reader> reader =
46
+ orc::createReader (orc::readFile (filename), reader_opts);
47
+
48
+ row_reader_ = reader->createRowReader (row_reader_opts);
49
+ LOG (INFO) << " ORC file schema:" << reader->getType ().toString ();
50
+
51
+ // Parse columns. We assume the orc record file is a flat array
52
+ auto row_count = reader->getNumberOfRows ();
53
+ for (uint64_t i = 0 ; i < reader->getType ().getSubtypeCount (); ++i) {
54
+ auto field_name = reader->getType ().getFieldName (i);
55
+ auto subtype = reader->getType ().getSubtype (i);
56
+ DataType dtype;
57
+ switch (static_cast <int64_t >(subtype->getKind ())) {
58
+ case orc::SHORT:
59
+ dtype = DT_INT16;
60
+ break ;
61
+ case orc::INT:
62
+ dtype = DT_INT32;
63
+ break ;
64
+ case orc::LONG:
65
+ dtype = DT_INT64;
66
+ break ;
67
+ case orc::STRING:
68
+ dtype = DT_STRING;
69
+ break ;
70
+ case orc::DOUBLE:
71
+ dtype = DT_DOUBLE;
72
+ break ;
73
+ case orc::FLOAT:
74
+ dtype = DT_FLOAT;
75
+ break ;
76
+ default :
77
+ return errors::InvalidArgument (" data type is not supported: " ,
78
+ subtype->toString ());
79
+ }
80
+ columns_.push_back (field_name);
81
+ shapes_.push_back (TensorShape ({static_cast <int64>(row_count)}));
82
+ dtypes_.push_back (dtype);
83
+ columns_index_[field_name] = i;
84
+ tensors_.emplace_back (
85
+ Tensor (dtype, TensorShape ({static_cast <int64>(row_count)})));
86
+ }
87
+ // Fill in the values
88
+ std::unique_ptr<orc::ColumnVectorBatch> batch =
89
+ row_reader_->createRowBatch (10 );
90
+ auto * fields = dynamic_cast <orc::StructVectorBatch*>(batch.get ());
91
+ int64_t record_index = 0 ;
92
+ // Template type conversions between ORC and TensorFlow DT
93
+ #define PROCESS_TYPE (VTYPE, VDTYPE, TDTYPE ) \
94
+ { \
95
+ auto * col = dynamic_cast <VTYPE>(fields->fields [column_index]); \
96
+ VDTYPE* buffer1 = col->data .data (); \
97
+ tensors_[column_index].flat <TDTYPE>()(record_index) = (TDTYPE)buffer1[r]; \
98
+ }
99
+ while (row_reader_->next (*batch)) {
100
+ for (uint32_t r = 0 ; r < batch->numElements ; ++r) {
101
+ for (size_t column_index = 0 ; column_index < columns_.size ();
102
+ column_index++) {
103
+ switch (dtypes_[column_index]) {
104
+ case DT_DOUBLE:
105
+ PROCESS_TYPE (orc::DoubleVectorBatch*, double , double );
106
+ break ;
107
+ case DT_FLOAT:
108
+ PROCESS_TYPE (orc::DoubleVectorBatch*, double , float );
109
+ break ;
110
+ case DT_INT16:
111
+ PROCESS_TYPE (orc::LongVectorBatch*, int64, int16);
112
+ break ;
113
+ case DT_INT32:
114
+ PROCESS_TYPE (orc::LongVectorBatch*, int64, int32);
115
+ break ;
116
+ case DT_INT64:
117
+ PROCESS_TYPE (orc::LongVectorBatch*, int64, int64);
118
+ break ;
119
+ case DT_STRING: {
120
+ auto * string_col = dynamic_cast <orc::StringVectorBatch*>(
121
+ fields->fields [column_index]);
122
+ char ** buffer = string_col->data .data ();
123
+ int64_t * lengths = string_col->length .data ();
124
+ tensors_[column_index].flat <tstring>()(record_index) =
125
+ std::string (buffer[r], lengths[r]);
126
+ break ;
127
+ }
128
+ default :
129
+ return errors::InvalidArgument (
130
+ " data type is not supported: " ,
131
+ DataTypeString (dtypes_[column_index]));
132
+ }
133
+ }
134
+ record_index++;
135
+ }
136
+ }
137
+
138
+ return Status::OK ();
139
+ }
140
+
141
+ Status Read (const int64 start, const int64 stop, const string& component,
142
+ int64* record_read, Tensor* value, Tensor* label) override {
143
+ if (columns_index_.find (component) == columns_index_.end ()) {
144
+ return errors::InvalidArgument (" component " , component, " is invalid" );
145
+ }
146
+ int64 column_index = columns_index_[component];
147
+
148
+ (*record_read) = 0 ;
149
+ if (start >= shapes_[column_index].dim_size (0 )) {
150
+ return Status::OK ();
151
+ }
152
+ const string& column = component;
153
+ int64 element_start = start < shapes_[column_index].dim_size (0 )
154
+ ? start
155
+ : shapes_[column_index].dim_size (0 );
156
+ int64 element_stop = stop < shapes_[column_index].dim_size (0 )
157
+ ? stop
158
+ : shapes_[column_index].dim_size (0 );
159
+ if (element_start > element_stop) {
160
+ return errors::InvalidArgument (" dataset " , column,
161
+ " selection is out of boundary" );
162
+ }
163
+ if (element_start == element_stop) {
164
+ return Status::OK ();
165
+ }
166
+
167
+ #define PROCESS_VALUE (VTYPE ) \
168
+ { \
169
+ value->flat <VTYPE>().data ()[i] = \
170
+ tensors_[column_index].flat <VTYPE>().data ()[i]; \
171
+ }
172
+ for (int i = element_start; i < element_stop; i++) {
173
+ switch (dtypes_[column_index]) {
174
+ case DT_DOUBLE:
175
+ PROCESS_VALUE (double );
176
+ break ;
177
+ case DT_FLOAT:
178
+ PROCESS_VALUE (float );
179
+ break ;
180
+ case DT_INT16:
181
+ PROCESS_VALUE (int16);
182
+ break ;
183
+ case DT_INT32:
184
+ PROCESS_VALUE (int32);
185
+ break ;
186
+ case DT_INT64:
187
+ PROCESS_VALUE (int64);
188
+ break ;
189
+ case DT_STRING: {
190
+ PROCESS_VALUE (tstring);
191
+ break ;
192
+ }
193
+ default :
194
+ return errors::InvalidArgument (" data type is not supported: " ,
195
+ DataTypeString (dtypes_[column_index]));
196
+ }
197
+ }
198
+ (*record_read) = element_stop - element_start;
199
+
200
+ return Status::OK ();
201
+ }
202
+
203
+ Status Components (std::vector<string>* components) override {
204
+ components->clear ();
205
+ for (size_t i = 0 ; i < columns_.size (); i++) {
206
+ components->push_back (columns_[i]);
207
+ }
208
+ return Status::OK ();
209
+ }
210
+
211
+ Status Spec (const string& component, PartialTensorShape* shape,
212
+ DataType* dtype, bool label) override {
213
+ if (columns_index_.find (component) == columns_index_.end ()) {
214
+ return errors::InvalidArgument (" component " , component, " is invalid" );
215
+ }
216
+ int64 column_index = columns_index_[component];
217
+ *shape = shapes_[column_index];
218
+ *dtype = dtypes_[column_index];
219
+ return Status::OK ();
220
+ }
221
+
222
+ string DebugString () const override {
223
+ mutex_lock l (mu_);
224
+ return strings::StrCat (" ORCReadable" );
225
+ }
226
+
227
+ private:
228
+ mutable mutex mu_;
229
+ Env* env_ TF_GUARDED_BY (mu_);
230
+ std::unique_ptr<SizedRandomAccessFile> file_ TF_GUARDED_BY (mu_);
231
+ std::unique_ptr<orc::RowReader> row_reader_ TF_GUARDED_BY (mu_);
232
+ std::vector<Tensor> tensors_;
233
+
234
+ std::vector<DataType> dtypes_;
235
+ std::vector<TensorShape> shapes_;
236
+ std::vector<string> columns_;
237
+ std::unordered_map<string, int64> columns_index_;
238
+ };
239
+ REGISTER_KERNEL_BUILDER (Name(" IO>ORCReadableInit" ).Device(DEVICE_CPU),
240
+ IOInterfaceInitOp<ORCReadable>);
241
+ REGISTER_KERNEL_BUILDER (Name(" IO>ORCReadableSpec" ).Device(DEVICE_CPU),
242
+ IOInterfaceSpecOp<ORCReadable>);
243
+ REGISTER_KERNEL_BUILDER (Name(" IO>ORCReadableRead" ).Device(DEVICE_CPU),
244
+ IOReadableReadOp<ORCReadable>);
245
+ } // namespace data
246
+ } // namespace tensorflow
0 commit comments