@@ -636,7 +636,11 @@ ModelInstanceState::Run(
636
636
" error setting the binding dimension" );
637
637
638
638
TRITONSERVER_DataType datatype = batch_input.DataType ();
639
- size_t total_byte_size = GetByteSize (datatype, shape);
639
+ int64_t total_byte_size = 0 ;
640
+ FAIL_ALL_AND_RETURN_IF_ERROR (
641
+ payload_->requests_ , payload_->request_count_ , payload_->responses_ ,
642
+ GetByteSize (datatype, shape, &total_byte_size),
643
+ " error getting the batch input byte size" );
640
644
641
645
const char * dst_buffer;
642
646
size_t dst_buffer_byte_size;
@@ -690,7 +694,12 @@ ModelInstanceState::Run(
690
694
" '" )
691
695
.c_str ());
692
696
693
- ragged_shape[0 ] += backend::GetElementCount (shape, dims_count);
697
+ int64_t element_cnt = 0 ;
698
+ FAIL_ALL_AND_RETURN_IF_ERROR (
699
+ payload_->requests_ , payload_->request_count_ , payload_->responses_ ,
700
+ backend::GetElementCount (shape, dims_count, &element_cnt),
701
+ " error getting the input element count" );
702
+ ragged_shape[0 ] += element_cnt;
694
703
if (req_idx == 0 ) {
695
704
datatype = temp_dt;
696
705
}
@@ -702,7 +711,11 @@ ModelInstanceState::Run(
702
711
name, ragged_shape, citr->second , io_index, &input_dims),
703
712
" error setting the binding dimension" );
704
713
705
- size_t total_byte_size = GetByteSize (datatype, ragged_shape);
714
+ int64_t total_byte_size = 0 ;
715
+ FAIL_ALL_AND_RETURN_IF_ERROR (
716
+ payload_->requests_ , payload_->request_count_ , payload_->responses_ ,
717
+ GetByteSize (datatype, ragged_shape, &total_byte_size),
718
+ " error getting the input byte size" );
706
719
707
720
payload_->collector_ ->ProcessTensor (
708
721
name.c_str (), static_cast <char *>(io_binding_info.GetBuffer ()),
@@ -758,17 +771,23 @@ ModelInstanceState::Run(
758
771
" error setting the binding dimension" );
759
772
}
760
773
761
- size_t total_byte_size = 0 ;
774
+ int64_t total_byte_size = 0 ;
762
775
if (io_binding_info.GetFormat ().is_linear_format_ ) {
763
- total_byte_size = GetByteSize (datatype, batchn_shape);
776
+ FAIL_ALL_AND_RETURN_IF_ERROR (
777
+ payload_->requests_ , payload_->request_count_ , payload_->responses_ ,
778
+ GetByteSize (datatype, batchn_shape, &total_byte_size),
779
+ " error getting the batch input byte size" );
764
780
// For input tensors with a linear IO format, the request has already
765
781
// verified the byte size, so no further validation is needed here.
766
782
} else {
767
783
batchn_shape[io_binding_info.GetFormat ().vectorized_dim_ ] +=
768
784
(io_binding_info.GetFormat ().components_per_element_ -
769
785
(batchn_shape[io_binding_info.GetFormat ().vectorized_dim_ ] %
770
786
io_binding_info.GetFormat ().components_per_element_ ));
771
- total_byte_size = GetByteSize (datatype, batchn_shape);
787
+ FAIL_ALL_AND_RETURN_IF_ERROR (
788
+ payload_->requests_ , payload_->request_count_ , payload_->responses_ ,
789
+ GetByteSize (datatype, batchn_shape, &total_byte_size),
790
+ " error getting the batch input byte size" );
772
791
773
792
// Ensure the request data byte size matches the expected byte size for
774
793
// non-linear IO format tensors
@@ -823,8 +842,13 @@ ModelInstanceState::Run(
823
842
// Initialize additional entries in batch input
824
843
if (io_binding_info.GetBatchInput () != nullptr ) {
825
844
const auto & batch_input = io_binding_info.GetBatchInput ()->first ;
826
- const size_t total_byte_size = GetByteSize (
827
- batch_input.DataType (), cuda_graph->input_dims_ [input_idx]);
845
+ int64_t total_byte_size = 0 ;
846
+ FAIL_ALL_AND_RETURN_IF_ERROR (
847
+ payload_->requests_ , payload_->request_count_ , payload_->responses_ ,
848
+ GetByteSize (
849
+ batch_input.DataType (), cuda_graph->input_dims_ [input_idx],
850
+ &total_byte_size),
851
+ " error getting the batch input byte size" );
828
852
829
853
auto & allocated_memory = io_binding_info.GetBatchInput ()->second ;
830
854
TRITONSERVER_MemoryType mem_type = allocated_memory->MemoryType ();
@@ -841,7 +865,7 @@ ModelInstanceState::Run(
841
865
batch_input, input_buffer, total_byte_size,
842
866
{{mem_type, mem_type_id}}, &dst_buffer, &dst_buffer_byte_size,
843
867
&dst_memory_type, &dst_memory_type_id),
844
- " error setting the bath input value" );
868
+ " error setting the batch input value" );
845
869
846
870
if ((batch_input.BatchInputKind () !=
847
871
BatchInput::Kind::BATCH_MAX_ELEMENT_COUNT_AS_SHAPE) &&
@@ -1067,8 +1091,10 @@ ModelInstanceState::Run(
1067
1091
batchn_shape[0 ] = shape[0 ];
1068
1092
}
1069
1093
1070
- const size_t tensor_element_cnt =
1071
- backend::GetElementCount (batchn_shape);
1094
+ int64_t tensor_element_cnt = 0 ;
1095
+ RESPOND_AND_SET_NULL_IF_ERROR (
1096
+ &response,
1097
+ backend::GetElementCount (batchn_shape, &tensor_element_cnt));
1072
1098
1073
1099
TRITONSERVER_DataType dt = ConvertTrtTypeToDataType (
1074
1100
engine_->getTensorDataType (name.c_str ()));
@@ -1112,7 +1138,11 @@ ModelInstanceState::Run(
1112
1138
// FIXME process reformat-free output, need to update output
1113
1139
// process code to accept batch1_byte_size and request batch
1114
1140
// size to break down output buffer properly.
1115
- size_t batch1_byte_size = GetByteSize (dt, batchn_shape);
1141
+ int64_t batch1_byte_size = 0 ;
1142
+ FAIL_ALL_AND_RETURN_IF_ERROR (
1143
+ payload_->requests_ , payload_->request_count_ , payload_->responses_ ,
1144
+ GetByteSize (dt, batchn_shape, &batch1_byte_size),
1145
+ " error getting the batch byte size" );
1116
1146
if (support_batching_) {
1117
1147
batch1_byte_size /= payload_->total_batch_size_ ;
1118
1148
}
@@ -1371,7 +1401,9 @@ ModelInstanceState::GetRequestShapeValues(
1371
1401
.c_str ());
1372
1402
}
1373
1403
1374
- int64_t element_cnt = backend::GetElementCount (shape, dims_count);
1404
+ int64_t element_cnt = 0 ;
1405
+ RETURN_IF_ERROR (
1406
+ backend::GetElementCount (shape, dims_count, &element_cnt));
1375
1407
if (support_batching_) {
1376
1408
element_cnt /= shape[0 ];
1377
1409
}
@@ -1481,7 +1513,10 @@ ModelInstanceState::EvaluateTensorRTContext(
1481
1513
RETURN_IF_ERROR (TRITONBACKEND_InputProperties (
1482
1514
repr_input, nullptr , nullptr , &shape, &dims_count, nullptr ,
1483
1515
nullptr ));
1484
- shape_vec[0 ] += backend::GetElementCount (shape, dims_count);
1516
+ int64_t element_cnt = 0 ;
1517
+ RETURN_IF_ERROR (
1518
+ backend::GetElementCount (shape, dims_count, &element_cnt));
1519
+ shape_vec[0 ] += element_cnt;
1485
1520
}
1486
1521
auto err = ValidateDimension (
1487
1522
shape_vec, citr->second .min_dims_ [io_index],
@@ -2462,7 +2497,8 @@ ModelInstanceState::InitializeConfigShapeOutputBindings(
2462
2497
context.context_ ->getTensorShape (io_name.c_str ());
2463
2498
std::vector<int64_t > dim_vec;
2464
2499
DimsToDimVec (output_dim, &dim_vec);
2465
- int64_t byte_size = GetByteSize (dt, dim_vec);
2500
+ int64_t byte_size = 0 ;
2501
+ RETURN_IF_ERROR (GetByteSize (dt, dim_vec, &byte_size));
2466
2502
2467
2503
max_byte_size = std::max (max_byte_size, byte_size);
2468
2504
}
@@ -2691,13 +2727,13 @@ ModelInstanceState::InitializeExecuteInputBinding(
2691
2727
2692
2728
int64_t byte_size = 0 ;
2693
2729
if (io_binding_info.GetFormat ().is_linear_format_ ) {
2694
- byte_size = GetByteSize (dt, maximum_dims);
2730
+ RETURN_IF_ERROR ( GetByteSize (dt, maximum_dims, &byte_size) );
2695
2731
} else {
2696
2732
maximum_dims[io_binding_info.GetFormat ().vectorized_dim_ ] +=
2697
2733
(io_binding_info.GetFormat ().components_per_element_ -
2698
2734
(maximum_dims[io_binding_info.GetFormat ().vectorized_dim_ ] %
2699
2735
io_binding_info.GetFormat ().components_per_element_ ));
2700
- byte_size = GetByteSize (dt, maximum_dims);
2736
+ RETURN_IF_ERROR ( GetByteSize (dt, maximum_dims, &byte_size) );
2701
2737
}
2702
2738
2703
2739
if (byte_size == -1 ) {
@@ -3097,7 +3133,7 @@ ModelInstanceState::InitializeShapeInputBinding(
3097
3133
std::vector<int64_t > dim_vec;
3098
3134
DimsToDimVec (
3099
3135
context.context_ ->getTensorShape (input_name.c_str ()), &dim_vec);
3100
- byte_size = GetByteSize (dt, dim_vec);
3136
+ RETURN_IF_ERROR ( GetByteSize (dt, dim_vec, &byte_size) );
3101
3137
} else {
3102
3138
auto component_count = GetElementCount (
3103
3139
context.context_ ->getTensorStrides (input_name.c_str ()));
0 commit comments