Skip to content

Commit 9efb352

Browse files
committed
Add REDUCE_MIN to reduce kernel
@tensorflow/micro Add the REDUCE_MiN operator to the reduce kernel. Refactor reduce kernel to decrease number of methods in tflite namespace. Add REDUCE_MIN unit tests. Fix unit test axis data to match tensor shape. Make Xtensa reduce kernel use reference common code for REDUCE_MIN. bug=fixes Missing support for ReduceMin op #3108
1 parent 3200ccd commit 9efb352

File tree

5 files changed

+276
-117
lines changed

5 files changed

+276
-117
lines changed

tensorflow/lite/micro/kernels/reduce.cc

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -28,15 +28,17 @@ limitations under the License.
2828

2929
namespace tflite {
3030

31+
namespace {
32+
3133
void* InitReduce(TfLiteContext* context, const char* buffer, size_t length) {
3234
void* op_data =
3335
context->AllocatePersistentBuffer(context, sizeof(OpDataReduce));
3436
return new (op_data) OpDataReduce();
3537
}
3638

37-
TfLiteStatus PrepareMax(TfLiteContext* context, TfLiteNode* node) {
38-
return PrepareMaxHelper(context, node,
39-
static_cast<OpDataReduce*>(node->user_data));
39+
TfLiteStatus PrepareMinMax(TfLiteContext* context, TfLiteNode* node) {
40+
return PrepareMinMaxHelper(context, node,
41+
static_cast<OpDataReduce*>(node->user_data));
4042
}
4143

4244
TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) {
@@ -54,17 +56,28 @@ TfLiteStatus EvalMax(TfLiteContext* context, TfLiteNode* node) {
5456
return EvalMaxHelper(context, node, op_data);
5557
}
5658

59+
TfLiteStatus EvalMin(TfLiteContext* context, TfLiteNode* node) {
60+
OpDataReduce* op_data = static_cast<OpDataReduce*>(node->user_data);
61+
return EvalMinHelper(context, node, op_data);
62+
}
63+
5764
TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
5865
return EvalSumHelper(context, node,
5966
static_cast<OpDataReduce*>(node->user_data));
6067
}
6168

69+
} // namespace
70+
6271
TFLMRegistration Register_MEAN() {
6372
return tflite::micro::RegisterOp(InitReduce, PrepareMeanOrSum, EvalMean);
6473
}
6574

6675
TFLMRegistration Register_REDUCE_MAX() {
67-
return tflite::micro::RegisterOp(InitReduce, PrepareMax, EvalMax);
76+
return tflite::micro::RegisterOp(InitReduce, PrepareMinMax, EvalMax);
77+
}
78+
79+
TFLMRegistration Register_REDUCE_MIN() {
80+
return tflite::micro::RegisterOp(InitReduce, PrepareMinMax, EvalMin);
6881
}
6982

7083
TFLMRegistration Register_SUM() {

tensorflow/lite/micro/kernels/reduce.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -40,24 +40,24 @@ struct OpDataReduce {
4040
int num_axis;
4141
};
4242

43-
TfLiteStatus PrepareMaxHelper(TfLiteContext* context, TfLiteNode* node,
44-
OpDataReduce* op_data);
43+
TfLiteStatus PrepareMinMaxHelper(TfLiteContext* context, TfLiteNode* node,
44+
OpDataReduce* op_data);
4545

4646
TfLiteStatus PrepareMeanOrSumHelper(TfLiteContext* context, TfLiteNode* node,
4747
OpDataReduce* op_data);
4848

4949
TfLiteStatus EvalMaxHelper(TfLiteContext* context, TfLiteNode* node,
5050
OpDataReduce* op_data);
51+
TfLiteStatus EvalMinHelper(TfLiteContext* context, TfLiteNode* node,
52+
OpDataReduce* op_data);
5153
TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node,
5254
OpDataReduce* op_data);
5355
TfLiteStatus EvalSumHelper(TfLiteContext* context, TfLiteNode* node,
5456
OpDataReduce* op_data);
5557

56-
void ReduceResolveAxis(const int* axis_data, int axis_count,
57-
MeanParams* op_params);
58-
5958
TFLMRegistration Register_MEAN();
6059
TFLMRegistration Register_REDUCE_MAX();
60+
TFLMRegistration Register_REDUCE_MIN();
6161
TFLMRegistration Register_SUM();
6262

6363
} // namespace tflite

tensorflow/lite/micro/kernels/reduce_common.cc

Lines changed: 140 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -31,6 +31,8 @@ namespace tflite {
3131
const int kMaxNumberOfAxis = 5;
3232
const int kMaxNumberOfReducedAxis = 2;
3333

34+
namespace {
35+
3436
TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node,
3537
int32_t* multiplier, int* shift) {
3638
MicroContext* micro_context = GetMicroContext(context);
@@ -64,8 +66,138 @@ TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node,
6466
return kTfLiteOk;
6567
}
6668

67-
TfLiteStatus PrepareMaxHelper(TfLiteContext* context, TfLiteNode* node,
68-
OpDataReduce* op_data) {
69+
void ResolveAxis(const int* axis_data, int axis_count,
70+
tflite::MeanParams* op_params) {
71+
int i = 0;
72+
for (; i < axis_count; ++i) {
73+
op_params->axis[i] = static_cast<int16_t>(axis_data[i]);
74+
}
75+
for (; i < 4; ++i) {
76+
op_params->axis[i] = 1;
77+
}
78+
op_params->axis_count = axis_count;
79+
}
80+
81+
template <typename T>
82+
TfLiteStatus QuantizedMeanOrSum(TfLiteContext* context, TfLiteNode* node,
83+
int* temp_index, int* resolved_axis,
84+
int32_t* temp_sum, OpDataReduce* op_data,
85+
bool compute_sum) {
86+
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
87+
const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
88+
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
89+
TfLiteReducerParams* params =
90+
static_cast<TfLiteReducerParams*>(node->builtin_data);
91+
92+
bool result = reference_ops::QuantizedMeanOrSumExtraArgs<T, int32_t>(
93+
tflite::micro::GetTensorData<T>(input), op_data->input_zp,
94+
op_data->input_scale, &input->dims->data[0], input->dims->size,
95+
tflite::micro::GetTensorData<T>(output), op_data->output_scale,
96+
op_data->multiplier, op_data->shift, op_data->output_zp,
97+
&output->dims->data[0], output->dims->size,
98+
tflite::micro::GetTensorData<int>(axis), op_data->num_axis,
99+
params->keep_dims, temp_index, resolved_axis, temp_sum, compute_sum);
100+
TF_LITE_ENSURE(context, result);
101+
102+
return kTfLiteOk;
103+
}
104+
105+
template <typename integer_type>
106+
TfLiteStatus EvalIntegerMean(TfLiteContext* context, TfLiteNode* node,
107+
int num_axis, OpDataReduce* op_data,
108+
int* temp_index, int* resolved_axis) {
109+
int32_t* temp_sum = static_cast<int32_t*>(
110+
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
111+
112+
QuantizedMeanOrSum<integer_type>(context, node, temp_index, resolved_axis,
113+
temp_sum, op_data, /*compute_sum=*/false);
114+
115+
return kTfLiteOk;
116+
}
117+
118+
enum MinMaxEvalType { kEvalMin, kEvalMax };
119+
120+
template <typename T>
121+
struct MinMaxReducerParams {
122+
MinMaxReducerParams() = delete;
123+
MinMaxReducerParams(MinMaxEvalType evalType) : type_(evalType) {};
124+
125+
constexpr T initialValue() const {
126+
return (type_ == kEvalMin) ? std::numeric_limits<T>::max()
127+
: std::numeric_limits<T>::lowest();
128+
}
129+
130+
// should be able to use "auto" keyword here, but GCC and Clang blow a fuse
131+
T (*compare())(const T, const T) {
132+
if (type_ == kEvalMin) {
133+
return [](const T current, const T in) -> T {
134+
return (in < current) ? in : current;
135+
};
136+
} else {
137+
return [](const T current, const T in) -> T {
138+
return (in > current) ? in : current;
139+
};
140+
}
141+
}
142+
143+
private:
144+
MinMaxEvalType type_;
145+
};
146+
147+
TfLiteStatus EvalMinMaxHelper(TfLiteContext* context, TfLiteNode* node,
148+
OpDataReduce* op_data, MinMaxEvalType evalType) {
149+
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
150+
const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
151+
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
152+
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
153+
TfLiteReducerParams* params =
154+
static_cast<TfLiteReducerParams*>(node->builtin_data);
155+
156+
// Interpret an axis tensor with null dimensions as a scalar
157+
int num_axis = static_cast<int>(ElementCount(*axis->dims));
158+
int* temp_buffer = static_cast<int*>(
159+
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
160+
int* resolved_axis = static_cast<int*>(
161+
context->GetScratchBuffer(context, op_data->resolved_axis_idx));
162+
switch (input->type) {
163+
case kTfLiteFloat32: {
164+
MinMaxReducerParams<float> reducer(evalType);
165+
TF_LITE_ENSURE(
166+
context,
167+
reference_ops::ReduceGeneric<float>(
168+
tflite::micro::GetTensorData<float>(input), input->dims->data,
169+
input->dims->size, tflite::micro::GetTensorData<float>(output),
170+
output->dims->data, output->dims->size,
171+
tflite::micro::GetTensorData<int>(axis), num_axis,
172+
params->keep_dims, temp_buffer, resolved_axis,
173+
reducer.initialValue(), reducer.compare()));
174+
} break;
175+
case kTfLiteInt8: {
176+
MinMaxReducerParams<int8_t> reducer(evalType);
177+
TF_LITE_ENSURE_EQ(context, static_cast<double>(op_data->input_scale),
178+
static_cast<double>(op_data->output_scale));
179+
TF_LITE_ENSURE_EQ(context, op_data->input_zp, op_data->output_zp);
180+
TF_LITE_ENSURE(
181+
context,
182+
reference_ops::ReduceGeneric<int8_t>(
183+
tflite::micro::GetTensorData<int8_t>(input), input->dims->data,
184+
input->dims->size, tflite::micro::GetTensorData<int8_t>(output),
185+
output->dims->data, output->dims->size,
186+
tflite::micro::GetTensorData<int>(axis), num_axis,
187+
params->keep_dims, temp_buffer, resolved_axis,
188+
reducer.initialValue(), reducer.compare()));
189+
} break;
190+
default:
191+
MicroPrintf("Only float32 and int8 types are supported.");
192+
return kTfLiteError;
193+
}
194+
return kTfLiteOk;
195+
}
196+
197+
} // namespace
198+
199+
TfLiteStatus PrepareMinMaxHelper(TfLiteContext* context, TfLiteNode* node,
200+
OpDataReduce* op_data) {
69201
TF_LITE_ENSURE_OK(context, PrepareSimple(context, node, &op_data->multiplier,
70202
&op_data->shift));
71203

@@ -126,55 +258,6 @@ TfLiteStatus PrepareMeanOrSumHelper(TfLiteContext* context, TfLiteNode* node,
126258
return kTfLiteOk;
127259
}
128260

129-
void ResolveAxis(const int* axis_data, int axis_count,
130-
tflite::MeanParams* op_params) {
131-
int i = 0;
132-
for (; i < axis_count; ++i) {
133-
op_params->axis[i] = static_cast<int16_t>(axis_data[i]);
134-
}
135-
for (; i < 4; ++i) {
136-
op_params->axis[i] = 1;
137-
}
138-
op_params->axis_count = axis_count;
139-
}
140-
141-
template <typename T>
142-
TfLiteStatus QuantizedMeanOrSum(TfLiteContext* context, TfLiteNode* node,
143-
int* temp_index, int* resolved_axis,
144-
int32_t* temp_sum, OpDataReduce* op_data,
145-
bool compute_sum) {
146-
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
147-
const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
148-
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
149-
TfLiteReducerParams* params =
150-
static_cast<TfLiteReducerParams*>(node->builtin_data);
151-
152-
bool result = reference_ops::QuantizedMeanOrSumExtraArgs<T, int32_t>(
153-
tflite::micro::GetTensorData<T>(input), op_data->input_zp,
154-
op_data->input_scale, &input->dims->data[0], input->dims->size,
155-
tflite::micro::GetTensorData<T>(output), op_data->output_scale,
156-
op_data->multiplier, op_data->shift, op_data->output_zp,
157-
&output->dims->data[0], output->dims->size,
158-
tflite::micro::GetTensorData<int>(axis), op_data->num_axis,
159-
params->keep_dims, temp_index, resolved_axis, temp_sum, compute_sum);
160-
TF_LITE_ENSURE(context, result);
161-
162-
return kTfLiteOk;
163-
}
164-
165-
template <typename integer_type>
166-
TfLiteStatus EvalIntegerMean(TfLiteContext* context, TfLiteNode* node,
167-
int num_axis, OpDataReduce* op_data,
168-
int* temp_index, int* resolved_axis) {
169-
int32_t* temp_sum = static_cast<int32_t*>(
170-
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
171-
172-
QuantizedMeanOrSum<integer_type>(context, node, temp_index, resolved_axis,
173-
temp_sum, op_data, /*compute_sum=*/false);
174-
175-
return kTfLiteOk;
176-
}
177-
178261
TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node,
179262
OpDataReduce* op_data) {
180263
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
@@ -238,56 +321,12 @@ TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node,
238321

239322
TfLiteStatus EvalMaxHelper(TfLiteContext* context, TfLiteNode* node,
240323
OpDataReduce* op_data) {
241-
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
242-
const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
243-
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
244-
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
245-
TfLiteReducerParams* params =
246-
static_cast<TfLiteReducerParams*>(node->builtin_data);
324+
return EvalMinMaxHelper(context, node, op_data, kEvalMax);
325+
}
247326

248-
// Interpret an axis tensor with null dimensions as a scalar
249-
int num_axis = static_cast<int>(ElementCount(*axis->dims));
250-
int* temp_buffer = static_cast<int*>(
251-
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
252-
int* resolved_axis = static_cast<int*>(
253-
context->GetScratchBuffer(context, op_data->resolved_axis_idx));
254-
switch (input->type) {
255-
case kTfLiteFloat32:
256-
TF_LITE_ENSURE(
257-
context,
258-
reference_ops::ReduceGeneric<float>(
259-
tflite::micro::GetTensorData<float>(input), input->dims->data,
260-
input->dims->size, tflite::micro::GetTensorData<float>(output),
261-
output->dims->data, output->dims->size,
262-
tflite::micro::GetTensorData<int>(axis), num_axis,
263-
params->keep_dims, temp_buffer, resolved_axis,
264-
std::numeric_limits<float>::lowest(),
265-
[](const float current, const float in) -> float {
266-
return (in > current) ? in : current;
267-
}));
268-
break;
269-
case kTfLiteInt8:
270-
TF_LITE_ENSURE_EQ(context, static_cast<double>(op_data->input_scale),
271-
static_cast<double>(op_data->output_scale));
272-
TF_LITE_ENSURE_EQ(context, op_data->input_zp, op_data->output_zp);
273-
TF_LITE_ENSURE(
274-
context,
275-
reference_ops::ReduceGeneric<int8_t>(
276-
tflite::micro::GetTensorData<int8_t>(input), input->dims->data,
277-
input->dims->size, tflite::micro::GetTensorData<int8_t>(output),
278-
output->dims->data, output->dims->size,
279-
tflite::micro::GetTensorData<int>(axis), num_axis,
280-
params->keep_dims, temp_buffer, resolved_axis,
281-
std::numeric_limits<int8_t>::lowest(),
282-
[](const int8_t current, const int8_t in) -> int8_t {
283-
return (in > current) ? in : current;
284-
}));
285-
break;
286-
default:
287-
MicroPrintf("Only float32 and int8 types are supported.");
288-
return kTfLiteError;
289-
}
290-
return kTfLiteOk;
327+
TfLiteStatus EvalMinHelper(TfLiteContext* context, TfLiteNode* node,
328+
OpDataReduce* op_data) {
329+
return EvalMinMaxHelper(context, node, op_data, kEvalMin);
291330
}
292331

293332
TfLiteStatus EvalSumHelper(TfLiteContext* context, TfLiteNode* node,

0 commit comments

Comments
 (0)