Skip to content

Commit f163ee1

Browse files
oliverhuyongtang
authored and committed
Implement ORC dataset reader (tensorflow#1383)
* Implement ORC dataset reader * support double, float and string types * add sample keras unit tests * reset unintended changes * add more datatypes * fix type in macOS, add test * address comments * fix a typo in float conversion
1 parent bbbde45 commit f163ee1

File tree

7 files changed

+538
-13
lines changed

7 files changed

+538
-13
lines changed

tensorflow_io/core/BUILD

+17-13
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,23 @@ cc_library(
368368
alwayslink = 1,
369369
)
370370

371+
# ORC reader op library: builds the ORC kernel and op-registration sources
# listed in `srcs` against liborc, the shared dataset ops target, and the
# TensorFlow framework library / headers listed in `deps`.
cc_library(
372+
name = "orc_ops",
373+
srcs = [
374+
"kernels/orc/orc_kernels.cc",
375+
"ops/orc_ops.cc",
376+
],
377+
copts = tf_io_copts(),
378+
linkstatic = True,
379+
deps = [
380+
"//tensorflow_io/core:dataset_ops",
381+
"@liborc",
382+
"@local_config_tf//:libtensorflow_framework",
383+
"@local_config_tf//:tf_header_lib",
384+
],
385+
alwayslink = 1,
386+
)
387+
371388
cc_library(
372389
name = "text_ops",
373390
srcs = [
@@ -531,19 +548,6 @@ cc_library(
531548
alwayslink = 1,
532549
)
533550

534-
cc_library(
535-
name = "orc_ops",
536-
srcs = [
537-
],
538-
copts = tf_io_copts(),
539-
linkstatic = True,
540-
deps = [
541-
"//tensorflow_io/core:dataset_ops",
542-
"@liborc",
543-
],
544-
alwayslink = 1,
545-
)
546-
547551
cc_library(
548552
name = "numpy_ops",
549553
srcs = [
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include <ctime>
17+
#include <iostream>
18+
#include <orc/Exceptions.hh>
19+
#include <orc/OrcFile.hh>
20+
#include <orc/Reader.hh>
21+
#include <orc/Type.hh>
22+
23+
#include "orc/orc-config.hh"
24+
#include "tensorflow/core/lib/io/buffered_inputstream.h"
25+
#include "tensorflow_io/core/kernels/io_interface.h"
26+
#include "tensorflow_io/core/kernels/io_stream.h"
27+
28+
namespace tensorflow {
29+
namespace data {
30+
31+
class ORCReadable : public IOReadableInterface {
32+
public:
33+
ORCReadable(Env* env) : env_(env) {}
34+
~ORCReadable() {}
35+
Status Init(const std::vector<string>& input,
36+
const std::vector<string>& metadata, const void* memory_data,
37+
const int64 memory_size) override {
38+
if (input.size() > 1) {
39+
return errors::InvalidArgument("more than 1 filename is not supported");
40+
}
41+
const string& filename = input[0];
42+
// read packet data
43+
orc::RowReaderOptions row_reader_opts;
44+
orc::ReaderOptions reader_opts;
45+
std::unique_ptr<orc::Reader> reader =
46+
orc::createReader(orc::readFile(filename), reader_opts);
47+
48+
row_reader_ = reader->createRowReader(row_reader_opts);
49+
LOG(INFO) << "ORC file schema:" << reader->getType().toString();
50+
51+
// Parse columns. We assume the orc record file is a flat array
52+
auto row_count = reader->getNumberOfRows();
53+
for (uint64_t i = 0; i < reader->getType().getSubtypeCount(); ++i) {
54+
auto field_name = reader->getType().getFieldName(i);
55+
auto subtype = reader->getType().getSubtype(i);
56+
DataType dtype;
57+
switch (static_cast<int64_t>(subtype->getKind())) {
58+
case orc::SHORT:
59+
dtype = DT_INT16;
60+
break;
61+
case orc::INT:
62+
dtype = DT_INT32;
63+
break;
64+
case orc::LONG:
65+
dtype = DT_INT64;
66+
break;
67+
case orc::STRING:
68+
dtype = DT_STRING;
69+
break;
70+
case orc::DOUBLE:
71+
dtype = DT_DOUBLE;
72+
break;
73+
case orc::FLOAT:
74+
dtype = DT_FLOAT;
75+
break;
76+
default:
77+
return errors::InvalidArgument("data type is not supported: ",
78+
subtype->toString());
79+
}
80+
columns_.push_back(field_name);
81+
shapes_.push_back(TensorShape({static_cast<int64>(row_count)}));
82+
dtypes_.push_back(dtype);
83+
columns_index_[field_name] = i;
84+
tensors_.emplace_back(
85+
Tensor(dtype, TensorShape({static_cast<int64>(row_count)})));
86+
}
87+
// Fill in the values
88+
std::unique_ptr<orc::ColumnVectorBatch> batch =
89+
row_reader_->createRowBatch(10);
90+
auto* fields = dynamic_cast<orc::StructVectorBatch*>(batch.get());
91+
int64_t record_index = 0;
92+
// Template type conversions between ORC and TensorFlow DT
93+
#define PROCESS_TYPE(VTYPE, VDTYPE, TDTYPE) \
94+
{ \
95+
auto* col = dynamic_cast<VTYPE>(fields->fields[column_index]); \
96+
VDTYPE* buffer1 = col->data.data(); \
97+
tensors_[column_index].flat<TDTYPE>()(record_index) = (TDTYPE)buffer1[r]; \
98+
}
99+
while (row_reader_->next(*batch)) {
100+
for (uint32_t r = 0; r < batch->numElements; ++r) {
101+
for (size_t column_index = 0; column_index < columns_.size();
102+
column_index++) {
103+
switch (dtypes_[column_index]) {
104+
case DT_DOUBLE:
105+
PROCESS_TYPE(orc::DoubleVectorBatch*, double, double);
106+
break;
107+
case DT_FLOAT:
108+
PROCESS_TYPE(orc::DoubleVectorBatch*, double, float);
109+
break;
110+
case DT_INT16:
111+
PROCESS_TYPE(orc::LongVectorBatch*, int64, int16);
112+
break;
113+
case DT_INT32:
114+
PROCESS_TYPE(orc::LongVectorBatch*, int64, int32);
115+
break;
116+
case DT_INT64:
117+
PROCESS_TYPE(orc::LongVectorBatch*, int64, int64);
118+
break;
119+
case DT_STRING: {
120+
auto* string_col = dynamic_cast<orc::StringVectorBatch*>(
121+
fields->fields[column_index]);
122+
char** buffer = string_col->data.data();
123+
int64_t* lengths = string_col->length.data();
124+
tensors_[column_index].flat<tstring>()(record_index) =
125+
std::string(buffer[r], lengths[r]);
126+
break;
127+
}
128+
default:
129+
return errors::InvalidArgument(
130+
"data type is not supported: ",
131+
DataTypeString(dtypes_[column_index]));
132+
}
133+
}
134+
record_index++;
135+
}
136+
}
137+
138+
return Status::OK();
139+
}
140+
141+
Status Read(const int64 start, const int64 stop, const string& component,
142+
int64* record_read, Tensor* value, Tensor* label) override {
143+
if (columns_index_.find(component) == columns_index_.end()) {
144+
return errors::InvalidArgument("component ", component, " is invalid");
145+
}
146+
int64 column_index = columns_index_[component];
147+
148+
(*record_read) = 0;
149+
if (start >= shapes_[column_index].dim_size(0)) {
150+
return Status::OK();
151+
}
152+
const string& column = component;
153+
int64 element_start = start < shapes_[column_index].dim_size(0)
154+
? start
155+
: shapes_[column_index].dim_size(0);
156+
int64 element_stop = stop < shapes_[column_index].dim_size(0)
157+
? stop
158+
: shapes_[column_index].dim_size(0);
159+
if (element_start > element_stop) {
160+
return errors::InvalidArgument("dataset ", column,
161+
" selection is out of boundary");
162+
}
163+
if (element_start == element_stop) {
164+
return Status::OK();
165+
}
166+
167+
#define PROCESS_VALUE(VTYPE) \
168+
{ \
169+
value->flat<VTYPE>().data()[i] = \
170+
tensors_[column_index].flat<VTYPE>().data()[i]; \
171+
}
172+
for (int i = element_start; i < element_stop; i++) {
173+
switch (dtypes_[column_index]) {
174+
case DT_DOUBLE:
175+
PROCESS_VALUE(double);
176+
break;
177+
case DT_FLOAT:
178+
PROCESS_VALUE(float);
179+
break;
180+
case DT_INT16:
181+
PROCESS_VALUE(int16);
182+
break;
183+
case DT_INT32:
184+
PROCESS_VALUE(int32);
185+
break;
186+
case DT_INT64:
187+
PROCESS_VALUE(int64);
188+
break;
189+
case DT_STRING: {
190+
PROCESS_VALUE(tstring);
191+
break;
192+
}
193+
default:
194+
return errors::InvalidArgument("data type is not supported: ",
195+
DataTypeString(dtypes_[column_index]));
196+
}
197+
}
198+
(*record_read) = element_stop - element_start;
199+
200+
return Status::OK();
201+
}
202+
203+
Status Components(std::vector<string>* components) override {
204+
components->clear();
205+
for (size_t i = 0; i < columns_.size(); i++) {
206+
components->push_back(columns_[i]);
207+
}
208+
return Status::OK();
209+
}
210+
211+
Status Spec(const string& component, PartialTensorShape* shape,
212+
DataType* dtype, bool label) override {
213+
if (columns_index_.find(component) == columns_index_.end()) {
214+
return errors::InvalidArgument("component ", component, " is invalid");
215+
}
216+
int64 column_index = columns_index_[component];
217+
*shape = shapes_[column_index];
218+
*dtype = dtypes_[column_index];
219+
return Status::OK();
220+
}
221+
222+
string DebugString() const override {
223+
mutex_lock l(mu_);
224+
return strings::StrCat("ORCReadable");
225+
}
226+
227+
private:
228+
mutable mutex mu_;
229+
Env* env_ TF_GUARDED_BY(mu_);
230+
std::unique_ptr<SizedRandomAccessFile> file_ TF_GUARDED_BY(mu_);
231+
std::unique_ptr<orc::RowReader> row_reader_ TF_GUARDED_BY(mu_);
232+
std::vector<Tensor> tensors_;
233+
234+
std::vector<DataType> dtypes_;
235+
std::vector<TensorShape> shapes_;
236+
std::vector<string> columns_;
237+
std::unordered_map<string, int64> columns_index_;
238+
};
239+
REGISTER_KERNEL_BUILDER(Name("IO>ORCReadableInit").Device(DEVICE_CPU),
240+
IOInterfaceInitOp<ORCReadable>);
241+
REGISTER_KERNEL_BUILDER(Name("IO>ORCReadableSpec").Device(DEVICE_CPU),
242+
IOInterfaceSpecOp<ORCReadable>);
243+
REGISTER_KERNEL_BUILDER(Name("IO>ORCReadableRead").Device(DEVICE_CPU),
244+
IOReadableReadOp<ORCReadable>);
245+
} // namespace data
246+
} // namespace tensorflow

tensorflow_io/core/ops/orc_ops.cc

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/core/framework/common_shape_fns.h"
17+
#include "tensorflow/core/framework/op.h"
18+
#include "tensorflow/core/framework/shape_inference.h"
19+
20+
namespace tensorflow {
21+
REGISTER_OP("IO>ORCReadableInit")
22+
.Input("input: string")
23+
.Output("resource: resource")
24+
.Output("components: string")
25+
.Attr("container: string = ''")
26+
.Attr("shared_name: string = ''")
27+
.SetShapeFn([](shape_inference::InferenceContext* c) {
28+
c->set_output(0, c->Scalar());
29+
c->set_output(1, c->MakeShape({}));
30+
return Status::OK();
31+
});
32+
33+
REGISTER_OP("IO>ORCReadableSpec")
34+
.Input("input: resource")
35+
.Output("shape: int64")
36+
.Output("dtype: int64")
37+
.Attr("component: string")
38+
.SetShapeFn([](shape_inference::InferenceContext* c) {
39+
c->set_output(0, c->MakeShape({c->UnknownDim()}));
40+
c->set_output(1, c->MakeShape({}));
41+
return Status::OK();
42+
});
43+
44+
REGISTER_OP("IO>ORCReadableRead")
45+
.Input("input: resource")
46+
.Input("start: int64")
47+
.Input("stop: int64")
48+
.Output("value: dtype")
49+
.Attr("component: string")
50+
.Attr("shape: shape")
51+
.Attr("dtype: type")
52+
.SetShapeFn([](shape_inference::InferenceContext* c) {
53+
PartialTensorShape shape;
54+
TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
55+
shape_inference::ShapeHandle entry;
56+
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &entry));
57+
c->set_output(0, entry);
58+
return Status::OK();
59+
});
60+
} // namespace tensorflow

tensorflow_io/core/python/ops/io_dataset.py

+16
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from tensorflow_io.core.python.ops import parquet_dataset_ops
2727
from tensorflow_io.core.python.ops import pcap_dataset_ops
2828
from tensorflow_io.core.python.ops import mnist_dataset_ops
29+
from tensorflow_io.core.python.ops import orc_dataset_ops
2930

3031

3132
class IODataset(io_dataset_ops._IODataset): # pylint: disable=protected-access
@@ -308,6 +309,21 @@ def from_pcap(cls, filename, **kwargs):
308309
with tf.name_scope(kwargs.get("name", "IOFromPcap")):
309310
return pcap_dataset_ops.PcapIODataset(filename, internal=True, **kwargs)
310311

312+
@classmethod
def from_orc(cls, filename, **kwargs):
    """Creates an `IODataset` from an ORC file.

    Args:
      filename: A string, the filename of an ORC file.
      name: Optional name for the enclosing name scope
        (defaults to "IOFromORC").

    Returns:
      A `IODataset`.

    """
    # All keyword arguments, including `name`, are forwarded unchanged to
    # the underlying ORC dataset.
    scope_name = kwargs.get("name", "IOFromORC")
    with tf.name_scope(scope_name):
        dataset = orc_dataset_ops.ORCIODataset(filename, internal=True, **kwargs)
    return dataset
326+
311327

312328
class StreamIODataset(
313329
io_dataset_ops._StreamIODataset

0 commit comments

Comments
 (0)