Skip to content

Porting Reverse_V2 operator from TFLite #3123

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions python/tflite_micro/python_ops_resolver.cc
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update copyright year.

Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ PythonOpsResolver::PythonOpsResolver() {
AddReshape();
AddResizeBilinear();
AddResizeNearestNeighbor();
AddReverseV2();
AddRfft();
AddRound();
AddRsqrt();
Expand Down
14 changes: 14 additions & 0 deletions tensorflow/lite/micro/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ tflm_kernel_cc_library(
"reshape_common.cc",
"resize_bilinear.cc",
"resize_nearest_neighbor.cc",
"reverse.cc",
"round.cc",
"select.cc",
"shape.cc",
Expand Down Expand Up @@ -1224,6 +1225,19 @@ tflm_cc_test(
],
)

tflm_cc_test(
name = "reverse_test",
srcs = [
"reverse_test.cc",
],
deps = [
":kernel_runner",
"//tensorflow/lite/c:common",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs "//tensorflow/lite/micro:op_resolvers", after this line, in order to have dependency on micro_ops.

"//tensorflow/lite/micro:test_helpers",
"//tensorflow/lite/micro/testing:micro_test",
],
)

tflm_cc_test(
name = "round_test",
srcs = [
Expand Down
1 change: 1 addition & 0 deletions tensorflow/lite/micro/kernels/Makefile.inc
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update copyright year.

Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/reduce_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/reshape_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/resize_bilinear_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/resize_nearest_neighbor_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/reverse_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/round_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/select_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/shape_test.cc \
Expand Down
1 change: 1 addition & 0 deletions tensorflow/lite/micro/kernels/micro_ops.h
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update copyright year.

Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ TFLMRegistration Register_RELU6();
TFLMRegistration Register_RESHAPE();
TFLMRegistration Register_RESIZE_BILINEAR();
TFLMRegistration Register_RESIZE_NEAREST_NEIGHBOR();
TFLMRegistration Register_REVERSE_V2();
TFLMRegistration Register_ROUND();
TFLMRegistration Register_RSQRT();
TFLMRegistration Register_SELECT_V2();
Expand Down
201 changes: 201 additions & 0 deletions tensorflow/lite/micro/kernels/reverse.cc
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add kTfLiteInt64 and kTfLiteBool support, as per TfLite.

Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <stdint.h>

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add <cstdlib> for qsort.
Add <cstring> for memcpy.

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_log.h"
#include "tensorflow/lite/micro/micro_utils.h"

namespace tflite {

constexpr int kMaxDimensions = 8;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be in the anonymous namespace.

Also, this should be set to the value of RuntimeShape::kMaxSmallSize, as RuntimeShape only supports 6 dimensions.


namespace {

constexpr int kInputTensor = 0;
constexpr int kAxisTensor = 1;
constexpr int kOutputTensor = 0;

int comp(const void* a, const void* b) { return (*(int*)a - *(int*)b); }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use static_cast.


TfLiteStatus ReverseV2Prepare(TfLiteContext* context, TfLiteNode* node) {
MicroContext* micro_context = GetMicroContext(context);

TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

// Ensure inputs and outputs exist.
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* axis =
micro_context->AllocateTempInputTensor(node, kAxisTensor);
TF_LITE_ENSURE(context, axis != nullptr);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(axis), 1);
TF_LITE_ENSURE(context, NumDimensions(input) <= 8);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use kMaxDimensions here.

TF_LITE_ENSURE(context, NumDimensions(input) >= NumElements(axis));

if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
input->type != kTfLiteUInt8 && input->type != kTfLiteInt8 &&
input->type != kTfLiteInt16) {
MicroPrintf("Type '%s' is not supported by reverse.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}

if (axis->type != kTfLiteInt32) {
MicroPrintf("Axis Type '%s' is not supported by reverse.",
TfLiteTypeGetName(axis->type));
return kTfLiteError;
}
// The value type and output type must match.
TF_LITE_ENSURE_EQ(context, input->type, output->type);

micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(axis);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}

template <typename T>
void ReverseImpl(int32_t* axes, int num_axes, const RuntimeShape& input_shape,
Comment on lines +79 to +80
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method should not be implemented. Instead please use tflite::reference_ops::Reverse().

const T* input_data, T* output_data) {
bool is_upper = (axes[num_axes - 1] == input_shape.DimensionsCount() - 1);
bool is_lower = (axes[0] == 0);
int rank = input_shape.DimensionsCount();
if (is_upper && is_lower) {
std::reverse_copy(input_data, input_data + input_shape.FlatSize(),
output_data);
return;
} else {
int32_t min_dim = axes[0];
int32_t max_dim = axes[num_axes - 1];
int upper_size = 1;
for (int i = 0; i < min_dim; ++i) {
upper_size *= input_shape.Dims(i);
}
int lower_size = 1;
for (int i = max_dim + 1; i < rank; ++i) {
lower_size *= input_shape.Dims(i);
}
int middle_size = 1;
for (int i = min_dim; i <= max_dim; ++i) {
middle_size *= input_shape.Dims(i);
}

if (lower_size > 1) {
for (int i = 0; i < upper_size; ++i) {
for (int j = 0; j < middle_size; ++j) {
T* src = (T*)input_data + (i * (middle_size) + j) * lower_size;
T* dst = (T*)output_data +
(i * (middle_size) + (middle_size - j - 1)) * lower_size;
memcpy(dst, src, lower_size * sizeof(T));
}
}
} else {
for (int i = 0; i < upper_size; ++i) {
std::reverse_copy(input_data + i * (middle_size),
input_data + i * middle_size + middle_size,
output_data + i * (middle_size));
}
}
}
}

TfLiteStatus ReverseV2Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* input =
micro::GetEvalInput(context, node, kInputTensor);
const TfLiteEvalTensor* axis =
micro::GetEvalInput(context, node, kAxisTensor);
TfLiteEvalTensor* output = micro::GetEvalOutput(context, node, kOutputTensor);

TF_LITE_ENSURE_EQ(context, axis->type, kTfLiteInt32);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check is not required (already done in Prepare phase).

const int num_axes = static_cast<int>(ElementCount(*axis->dims));
TF_LITE_ENSURE(context, num_axes <= 8);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check is not required (already done in Prepare phase).


int32_t axes_data[kMaxDimensions];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use of std::array here is perfectly fine, as the allocation is done statically.

std::memcpy(axes_data, axis->data.i32, sizeof(int32_t) * num_axes);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be axis->data.data. Use of any other member of the union is deprecated.

const int rank = tflite::micro::GetTensorShape(input).DimensionsCount();
for (int i = 0; i < num_axes; ++i) {
if (axes_data[i] < 0) {
axes_data[i] += rank;
}
TF_LITE_ENSURE(context, axes_data[i] >= 0 && axes_data[i] < rank);
}
qsort(axes_data, num_axes, sizeof(int32_t), comp);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use std:: here.
FYI: qsort in the worst case has more complexity than std::sort in modern C++ standard libraries. But since we are talking about (at most) 6 elements to sort, I don't think it matters which one we use.


bool is_contiguous = true;
for (int i = 1; i < num_axes; ++i) {
if (axes_data[i - 1] + 1 != axes_data[i]) {
is_contiguous = false;
break;
}
}
if (!is_contiguous) {
MicroPrintf("Non-contiguous `axes` not supported");
return kTfLiteError;
}

switch (output->type) {
case kTfLiteFloat32:
ReverseImpl<float>(axes_data, num_axes,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<float>(input),
tflite::micro::GetTensorData<float>(output));
break;
case kTfLiteInt32:
ReverseImpl<int32_t>(axes_data, num_axes,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int32_t>(input),
tflite::micro::GetTensorData<int32_t>(output));
break;
case kTfLiteInt16:
ReverseImpl<int16_t>(axes_data, num_axes,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorData<int16_t>(output));
break;
case kTfLiteInt8:
case kTfLiteUInt8:
ReverseImpl<uint8_t>(axes_data, num_axes,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<uint8_t>(input),
tflite::micro::GetTensorData<uint8_t>(output));
Comment on lines +160 to +182
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use reference_ops::Reverse()

break;
default:
MicroPrintf(
"Reverse currently supports float32, int16, "
"int8 and uint8 for output, got %d.",
Comment on lines +186 to +187
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please change this to "Type '%s' is not supported", just to simplify the error message.

TfLiteTypeGetName(output->type));
return kTfLiteError;
}

return kTfLiteOk;
}

} // namespace

TFLMRegistration Register_REVERSE_V2() {
return tflite::micro::RegisterOp(nullptr, ReverseV2Prepare, ReverseV2Eval);
}

} // namespace tflite
Loading
Loading