Skip to content

Commit 3200ccd

Browse files
authored
Automated sync from github.com/tensorflow/tensorflow (#3122)
BUG=automated sync from upstream NO_CHECK_TFLITE_FILES=automated sync from upstream
1 parent 40cf8cb commit 3200ccd

File tree

1 file changed

+80
-0
lines changed
  • tensorflow/lite/kernels/internal/reference

1 file changed

+80
-0
lines changed
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REVERSE_H_
16+
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REVERSE_H_
17+
18+
#include <algorithm>
19+
#include <array>
20+
#include <cstdint>
21+
22+
#include "ruy/profiler/instrumentation.h" // from @ruy
23+
#include "tensorflow/lite/kernels/internal/runtime_shape.h"
24+
25+
namespace tflite {
26+
namespace reference_ops {
27+
28+
template <typename Scalar>
29+
void Reverse(std::array<int32_t, 8>& axes, int num_axes,
30+
const RuntimeShape& input_shape, const Scalar* input_data,
31+
Scalar* output_data) {
32+
ruy::profiler::ScopeLabel label("Reverse");
33+
bool is_upper = (axes[num_axes - 1] == input_shape.DimensionsCount() - 1);
34+
bool is_lower = (axes[0] == 0);
35+
int rank = input_shape.DimensionsCount();
36+
if (is_upper && is_lower) {
37+
std::reverse_copy(input_data, input_data + input_shape.FlatSize(),
38+
output_data);
39+
return;
40+
} else {
41+
int32_t min_dim = axes[0];
42+
int32_t max_dim = axes[num_axes - 1];
43+
int upper_size = 1;
44+
for (int i = 0; i < min_dim; ++i) {
45+
upper_size *= input_shape.Dims(i);
46+
}
47+
int lower_size = 1;
48+
for (int i = max_dim + 1; i < rank; ++i) {
49+
lower_size *= input_shape.Dims(i);
50+
}
51+
int middle_size = 1;
52+
for (int i = min_dim; i <= max_dim; ++i) {
53+
middle_size *= input_shape.Dims(i);
54+
}
55+
56+
if (lower_size > 1) {
57+
for (int i = 0; i < upper_size; ++i) {
58+
for (int j = 0; j < middle_size; ++j) {
59+
Scalar* src =
60+
(Scalar*)input_data + (i * (middle_size) + j) * lower_size;
61+
Scalar* dst =
62+
(Scalar*)output_data +
63+
(i * (middle_size) + (middle_size - j - 1)) * lower_size;
64+
memcpy(dst, src, lower_size * sizeof(Scalar));
65+
}
66+
}
67+
} else {
68+
for (int i = 0; i < upper_size; ++i) {
69+
std::reverse_copy(input_data + i * (middle_size),
70+
input_data + i * middle_size + middle_size,
71+
output_data + i * (middle_size));
72+
}
73+
}
74+
}
75+
}
76+
77+
} // namespace reference_ops
78+
} // namespace tflite
79+
80+
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REVERSE_H_

0 commit comments

Comments
 (0)