Skip to content

Commit 6b77bc7

Browse files
authored
refactor: Use safer backend APIs (#111)
1 parent e4c91d4 commit 6b77bc7

File tree

1 file changed

+54
-18
lines changed

1 file changed

+54
-18
lines changed

src/instance_state.cc

Lines changed: 54 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,11 @@ ModelInstanceState::Run(
636636
"error setting the binding dimension");
637637

638638
TRITONSERVER_DataType datatype = batch_input.DataType();
639-
size_t total_byte_size = GetByteSize(datatype, shape);
639+
int64_t total_byte_size = 0;
640+
FAIL_ALL_AND_RETURN_IF_ERROR(
641+
payload_->requests_, payload_->request_count_, payload_->responses_,
642+
GetByteSize(datatype, shape, &total_byte_size),
643+
"error getting the batch input byte size");
640644

641645
const char* dst_buffer;
642646
size_t dst_buffer_byte_size;
@@ -690,7 +694,12 @@ ModelInstanceState::Run(
690694
"'")
691695
.c_str());
692696

693-
ragged_shape[0] += backend::GetElementCount(shape, dims_count);
697+
int64_t element_cnt = 0;
698+
FAIL_ALL_AND_RETURN_IF_ERROR(
699+
payload_->requests_, payload_->request_count_, payload_->responses_,
700+
backend::GetElementCount(shape, dims_count, &element_cnt),
701+
"error getting the input element count");
702+
ragged_shape[0] += element_cnt;
694703
if (req_idx == 0) {
695704
datatype = temp_dt;
696705
}
@@ -702,7 +711,11 @@ ModelInstanceState::Run(
702711
name, ragged_shape, citr->second, io_index, &input_dims),
703712
"error setting the binding dimension");
704713

705-
size_t total_byte_size = GetByteSize(datatype, ragged_shape);
714+
int64_t total_byte_size = 0;
715+
FAIL_ALL_AND_RETURN_IF_ERROR(
716+
payload_->requests_, payload_->request_count_, payload_->responses_,
717+
GetByteSize(datatype, ragged_shape, &total_byte_size),
718+
"error getting the input byte size");
706719

707720
payload_->collector_->ProcessTensor(
708721
name.c_str(), static_cast<char*>(io_binding_info.GetBuffer()),
@@ -758,17 +771,23 @@ ModelInstanceState::Run(
758771
"error setting the binding dimension");
759772
}
760773

761-
size_t total_byte_size = 0;
774+
int64_t total_byte_size = 0;
762775
if (io_binding_info.GetFormat().is_linear_format_) {
763-
total_byte_size = GetByteSize(datatype, batchn_shape);
776+
FAIL_ALL_AND_RETURN_IF_ERROR(
777+
payload_->requests_, payload_->request_count_, payload_->responses_,
778+
GetByteSize(datatype, batchn_shape, &total_byte_size),
779+
"error getting the batch input byte size");
764780
// For input tensors with a linear IO format, the request has already
765781
// verified the byte size, so no further validation is needed here.
766782
} else {
767783
batchn_shape[io_binding_info.GetFormat().vectorized_dim_] +=
768784
(io_binding_info.GetFormat().components_per_element_ -
769785
(batchn_shape[io_binding_info.GetFormat().vectorized_dim_] %
770786
io_binding_info.GetFormat().components_per_element_));
771-
total_byte_size = GetByteSize(datatype, batchn_shape);
787+
FAIL_ALL_AND_RETURN_IF_ERROR(
788+
payload_->requests_, payload_->request_count_, payload_->responses_,
789+
GetByteSize(datatype, batchn_shape, &total_byte_size),
790+
"error getting the batch input byte size");
772791

773792
// Ensure the request data byte size matches the expected byte size for
774793
// non-linear IO format tensors
@@ -823,8 +842,13 @@ ModelInstanceState::Run(
823842
// Initialize additional entries in batch input
824843
if (io_binding_info.GetBatchInput() != nullptr) {
825844
const auto& batch_input = io_binding_info.GetBatchInput()->first;
826-
const size_t total_byte_size = GetByteSize(
827-
batch_input.DataType(), cuda_graph->input_dims_[input_idx]);
845+
int64_t total_byte_size = 0;
846+
FAIL_ALL_AND_RETURN_IF_ERROR(
847+
payload_->requests_, payload_->request_count_, payload_->responses_,
848+
GetByteSize(
849+
batch_input.DataType(), cuda_graph->input_dims_[input_idx],
850+
&total_byte_size),
851+
"error getting the batch input byte size");
828852

829853
auto& allocated_memory = io_binding_info.GetBatchInput()->second;
830854
TRITONSERVER_MemoryType mem_type = allocated_memory->MemoryType();
@@ -841,7 +865,7 @@ ModelInstanceState::Run(
841865
batch_input, input_buffer, total_byte_size,
842866
{{mem_type, mem_type_id}}, &dst_buffer, &dst_buffer_byte_size,
843867
&dst_memory_type, &dst_memory_type_id),
844-
"error setting the bath input value");
868+
"error setting the batch input value");
845869

846870
if ((batch_input.BatchInputKind() !=
847871
BatchInput::Kind::BATCH_MAX_ELEMENT_COUNT_AS_SHAPE) &&
@@ -1067,8 +1091,10 @@ ModelInstanceState::Run(
10671091
batchn_shape[0] = shape[0];
10681092
}
10691093

1070-
const size_t tensor_element_cnt =
1071-
backend::GetElementCount(batchn_shape);
1094+
int64_t tensor_element_cnt = 0;
1095+
RESPOND_AND_SET_NULL_IF_ERROR(
1096+
&response,
1097+
backend::GetElementCount(batchn_shape, &tensor_element_cnt));
10721098

10731099
TRITONSERVER_DataType dt = ConvertTrtTypeToDataType(
10741100
engine_->getTensorDataType(name.c_str()));
@@ -1112,7 +1138,11 @@ ModelInstanceState::Run(
11121138
// FIXME process reformat-free output, need to update output
11131139
// process code to accept batch1_byte_size and request batch
11141140
// size to break down output buffer properly.
1115-
size_t batch1_byte_size = GetByteSize(dt, batchn_shape);
1141+
int64_t batch1_byte_size = 0;
1142+
FAIL_ALL_AND_RETURN_IF_ERROR(
1143+
payload_->requests_, payload_->request_count_, payload_->responses_,
1144+
GetByteSize(dt, batchn_shape, &batch1_byte_size),
1145+
"error getting the batch byte size");
11161146
if (support_batching_) {
11171147
batch1_byte_size /= payload_->total_batch_size_;
11181148
}
@@ -1371,7 +1401,9 @@ ModelInstanceState::GetRequestShapeValues(
13711401
.c_str());
13721402
}
13731403

1374-
int64_t element_cnt = backend::GetElementCount(shape, dims_count);
1404+
int64_t element_cnt = 0;
1405+
RETURN_IF_ERROR(
1406+
backend::GetElementCount(shape, dims_count, &element_cnt));
13751407
if (support_batching_) {
13761408
element_cnt /= shape[0];
13771409
}
@@ -1481,7 +1513,10 @@ ModelInstanceState::EvaluateTensorRTContext(
14811513
RETURN_IF_ERROR(TRITONBACKEND_InputProperties(
14821514
repr_input, nullptr, nullptr, &shape, &dims_count, nullptr,
14831515
nullptr));
1484-
shape_vec[0] += backend::GetElementCount(shape, dims_count);
1516+
int64_t element_cnt = 0;
1517+
RETURN_IF_ERROR(
1518+
backend::GetElementCount(shape, dims_count, &element_cnt));
1519+
shape_vec[0] += element_cnt;
14851520
}
14861521
auto err = ValidateDimension(
14871522
shape_vec, citr->second.min_dims_[io_index],
@@ -2462,7 +2497,8 @@ ModelInstanceState::InitializeConfigShapeOutputBindings(
24622497
context.context_->getTensorShape(io_name.c_str());
24632498
std::vector<int64_t> dim_vec;
24642499
DimsToDimVec(output_dim, &dim_vec);
2465-
int64_t byte_size = GetByteSize(dt, dim_vec);
2500+
int64_t byte_size = 0;
2501+
RETURN_IF_ERROR(GetByteSize(dt, dim_vec, &byte_size));
24662502

24672503
max_byte_size = std::max(max_byte_size, byte_size);
24682504
}
@@ -2691,13 +2727,13 @@ ModelInstanceState::InitializeExecuteInputBinding(
26912727

26922728
int64_t byte_size = 0;
26932729
if (io_binding_info.GetFormat().is_linear_format_) {
2694-
byte_size = GetByteSize(dt, maximum_dims);
2730+
RETURN_IF_ERROR(GetByteSize(dt, maximum_dims, &byte_size));
26952731
} else {
26962732
maximum_dims[io_binding_info.GetFormat().vectorized_dim_] +=
26972733
(io_binding_info.GetFormat().components_per_element_ -
26982734
(maximum_dims[io_binding_info.GetFormat().vectorized_dim_] %
26992735
io_binding_info.GetFormat().components_per_element_));
2700-
byte_size = GetByteSize(dt, maximum_dims);
2736+
RETURN_IF_ERROR(GetByteSize(dt, maximum_dims, &byte_size));
27012737
}
27022738

27032739
if (byte_size == -1) {
@@ -3097,7 +3133,7 @@ ModelInstanceState::InitializeShapeInputBinding(
30973133
std::vector<int64_t> dim_vec;
30983134
DimsToDimVec(
30993135
context.context_->getTensorShape(input_name.c_str()), &dim_vec);
3100-
byte_size = GetByteSize(dt, dim_vec);
3136+
RETURN_IF_ERROR(GetByteSize(dt, dim_vec, &byte_size));
31013137
} else {
31023138
auto component_count = GetElementCount(
31033139
context.context_->getTensorStrides(input_name.c_str()));

0 commit comments

Comments
 (0)