From dcc0c0a2efa7d6cf6a0e0698ef24df0ddcd2c1b6 Mon Sep 17 00:00:00 2001 From: Yingge He Date: Thu, 1 May 2025 17:44:06 -0700 Subject: [PATCH] Initial commit --- src/onnxruntime.cc | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/onnxruntime.cc b/src/onnxruntime.cc index 40301da..acbb19d 100644 --- a/src/onnxruntime.cc +++ b/src/onnxruntime.cc @@ -2238,12 +2238,16 @@ ModelInstanceState::SetInputTensors( TRITONBACKEND_RequestInput(requests[idx], input_name, &input)); const int64_t* input_shape; uint32_t input_dims_count; + int64_t element_cnt = 0; RESPOND_AND_SET_NULL_IF_ERROR( &((*responses)[idx]), TRITONBACKEND_InputProperties( input, nullptr, nullptr, &input_shape, &input_dims_count, nullptr, nullptr)); + RESPOND_AND_SET_NULL_IF_ERROR( + &((*responses)[idx]), + GetElementCount(input_shape, input_dims_count, &element_cnt)); - batchn_shape[0] += GetElementCount(input_shape, input_dims_count); + batchn_shape[0] += element_cnt; } } // The shape for the entire input batch, [total_batch_size, ...] @@ -2402,8 +2406,10 @@ ModelInstanceState::SetStringInputTensor( expected_byte_sizes.push_back(0); expected_element_cnts.push_back(0); } else { - expected_element_cnts.push_back( - GetElementCount(input_shape, input_dims_count)); + int64_t element_cnt = 0; + RETURN_IF_ERROR( + GetElementCount(input_shape, input_dims_count, &element_cnt)); + expected_element_cnts.push_back(element_cnt); expected_byte_sizes.push_back(input_byte_size); } @@ -2573,8 +2579,9 @@ ModelInstanceState::ReadOutputTensor( ONNXTensorElementDataType type; RETURN_IF_ORT_ERROR(ort_api->GetTensorElementType(type_and_shape, &type)); if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { - const size_t element_count = GetElementCount(batchn_shape); + int64_t element_count = 0; size_t total_length = 0; + RETURN_IF_ERROR(GetElementCount(batchn_shape, &element_count)); RETURN_IF_ORT_ERROR( ort_api->GetStringTensorDataLength(output_tensor, &total_length)); @@ -2776,7 +2783,9 @@ ModelInstanceState::SetStringBuffer( (*batchn_shape)[0] = shape[0]; } - const size_t expected_element_cnt = GetElementCount(*batchn_shape); + int64_t expected_element_cnt = 0; + RESPOND_AND_SET_NULL_IF_ERROR( + &response, GetElementCount(*batchn_shape, &expected_element_cnt)); // If 'request' requested this output then copy it from // 'content'. If it did not request this output then just skip it