
Commit 4f32f22

Fix dynamically sized i/o
1 parent: 12dd635

File tree: 4 files changed, +62 −38 lines


onnxruntime/core/providers/openvino/backends/basic_backend.cc

Lines changed: 45 additions & 24 deletions
@@ -361,7 +361,7 @@ void BasicBackend::SetNumThreads(ov::AnyMap& device_config) {
 
 #ifdef IO_BUFFER_ENABLED
 // Wait for Remote asynchronous inference completion
-void BasicBackend::RemoteInfer(Ort::KernelContext& context, OVInferRequestPtr infer_request) {
+void BasicBackend::RemoteInfer(Ort::KernelContext& context, OVInferRequestPtr infer_request) const {
   try {
     auto graph_input_info = exe_network_.Get().inputs();
     int input_idx = 0;
@@ -467,7 +467,7 @@ void BasicBackend::RemoteInfer(Ort::KernelContext& context, OVInferRequestPtr in
 }
 #endif
 
-void BasicBackend::Infer(OrtKernelContext* ctx) {
+void BasicBackend::Infer(OrtKernelContext* ctx) const {
   Ort::KernelContext context(ctx);
 
   LOGS_DEFAULT(INFO) << log_tag << "Running graph " << subgraph_context_.subgraph_name;
@@ -492,43 +492,64 @@ void BasicBackend::Infer(OrtKernelContext* ctx) {
     return;
   }
 
-  bool gpu = session_context_.device_type.find("GPU") != std::string::npos;
-  bool cpu_or_gpu = gpu || (session_context_.device_type.find("CPU") != std::string::npos);
-
   // guarded_request will be released back to the pool when it goes out of scope
   auto guarded_request = infer_req_pool_->getRequest();
   auto& infer_request = guarded_request.infer_request_;
 #ifdef IO_BUFFER_ENABLED
-  if (gpu &&
+  if (session_context_.device_type.find("GPU") != std::string::npos &&
       (session_context_.context != nullptr) && session_context_.is_wholly_supported_graph) {
     RemoteInfer(context, infer_request);
   } else
 #else
   {  // scope for else if IO_BUFFER_ENABLED
 
-    // Bind inputs
-    for (const auto& input_info : bindings_->network_inputs_) {
-      if (subgraph_context_.has_dynamic_input_shape &&
-          !session_context_.disable_dynamic_shapes &&
-          cpu_or_gpu) {
-        // copy the input to set current shape.
-        auto input_info_copy = input_info;
+    if (bindings_->has_dynamic_io_ ||
+        (subgraph_context_.has_dynamic_input_shape &&
+         !session_context_.disable_dynamic_shapes)) {
+      // Dynamic shape inference
+
+      // We don't know the output shapes so we need to get the outputs from the infer request and copy them into the ort
+      // tensors instead of binding them to the infer request directly.
+
+      // Bind inputs
+      for (const auto& input_info : bindings_->network_inputs_) {
+        // Set the input shape based on the input tensor from ort
         auto tensor = context.GetInput(input_info.onnx_index);
-        input_info_copy.shape = ParameterShape(tensor.GetTensorTypeAndShapeInfo().GetShape());
+        auto input_shape = ParameterShape(tensor.GetTensorTypeAndShapeInfo().GetShape());
 
-        infer_request->SetTensor(input_info_copy, const_cast<void*>(tensor.GetTensorRawData()));
-      } else {
+        infer_request->SetTensorShapeOverride(input_info, input_shape, const_cast<void*>(tensor.GetTensorRawData()));
+      }
+
+      // Run Inference
+      infer_request->Infer();
+
+      // Copy outputs
+      for (const auto& output_info : bindings_->network_outputs_) {
+        auto ov_tensor = infer_request->GetTensor(output_info.name);
+        auto output_shape = ParameterShape::ToOnnxShape(ov_tensor->get_shape());
+        auto ort_tensor = context.GetOutput(output_info.onnx_index, output_shape);
+
+        memcpy_s(ort_tensor.GetTensorMutableRawData(),
+                 ort_tensor.GetTensorSizeInBytes(),
+                 ov_tensor->data(),
+                 ov_tensor->get_byte_size());
+      }
+    } else {
+      // Static shape inference
+
+      // Bind inputs
+      for (const auto& input_info : bindings_->network_inputs_) {
         infer_request->SetTensor(input_info, const_cast<void*>(context.GetInput(input_info.onnx_index).GetTensorRawData()));
       }
-    }
 
-    // Bind outputs
-    for (const auto& output_info : bindings_->network_outputs_) {
-      infer_request->SetTensor(output_info, context.GetOutput(output_info.onnx_index, output_info.shape.onnx()).GetTensorMutableRawData());
-    }
+      // Bind outputs
+      for (const auto& output_info : bindings_->network_outputs_) {
+        infer_request->SetTensor(output_info, context.GetOutput(output_info.onnx_index, output_info.shape.onnx()).GetTensorMutableRawData());
+      }
 
-    // Run Inference
-    infer_request->Infer();
+      // Run Inference
+      infer_request->Infer();
+    }
 
     // Fill constant outputs if needed
     for (const auto& [name, node] : const_outputs_map_) {
@@ -552,7 +573,7 @@ void BasicBackend::Infer(OrtKernelContext* ctx) {
 
 #ifndef NDEBUG
 #ifndef IO_BUFFER_ENABLED
-  // Print performance counts before releasing the infer_request for potential thread safety
+  // Print performance counts before releasing the infer_request for thread safety
   if (openvino_ep::backend_utils::IsDebugEnabled()) {
     std::string& hw_target = session_context_.device_type;
     printPerformanceCounts(infer_request, std::cout, hw_target);
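
In essence, the new dynamic path wraps each ORT input buffer in an ov::Tensor carrying the shape observed at the current call, runs inference, and then copies each output out, because output shapes are only known after the run. Below is a minimal standalone sketch of that flow against plain OpenVINO 2.x APIs; run_dynamic and the ort_* parameters are hypothetical stand-ins for the ORT plumbing, not code from this commit.

#include <algorithm>
#include <cstring>
#include <openvino/openvino.hpp>

// Sketch of the dynamic-shape path: bind the input with its runtime shape,
// infer, then copy the output since its shape is unknown until after infer().
void run_dynamic(ov::InferRequest& request,
                 const ov::Output<const ov::Node>& input,
                 void* ort_input_data, const ov::Shape& runtime_shape,
                 void* ort_output_data, size_t ort_output_capacity) {
  // Zero-copy input: the ov::Tensor merely wraps ORT's buffer.
  ov::Tensor in(input.get_element_type(), runtime_shape, ort_input_data);
  request.set_tensor(input, in);

  request.infer();

  // The output shape is known only now; copy into the ORT-allocated buffer,
  // mirroring the memcpy_s in the diff above.
  ov::Tensor out = request.get_output_tensor(0);
  std::memcpy(ort_output_data, out.data(),
              std::min(out.get_byte_size(), ort_output_capacity));
}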

onnxruntime/core/providers/openvino/backends/basic_backend.h

Lines changed: 7 additions & 4 deletions
@@ -28,6 +28,7 @@ namespace openvino_ep {
 struct OnnxToOvNetworkBindings {
   std::vector<ParameterInfo> network_outputs_;
   std::vector<ParameterInfo> network_inputs_;
+  bool has_dynamic_io_ = false;
 
   OnnxToOvNetworkBindings(OVExeNetwork& exec_network, SubGraphContext& subgraph_context) {
     auto populate = [&](auto& input_output_map, const SubGraphContext::string_index_map_t& onnx_input_map, const auto& ov_parameters) {
@@ -42,6 +43,9 @@ struct OnnxToOvNetworkBindings {
       auto ov_param_index = std::distance(ov_parameters.begin(), it);
 
       auto shape = ov_parameters[ov_param_index].get_partial_shape();
+      if (shape.is_dynamic()) {
+        has_dynamic_io_ = true;
+      }
       auto type = ov_parameters[ov_param_index].get_element_type();
       ParameterInfo info{onnx_name, ov_param_index, onnx_param_index, type, shape};
       input_output_map.push_back(std::move(info));
@@ -62,7 +66,7 @@ class BasicBackend : public IBackend {
                SharedContext& shared_context,
                ptr_stream_t& model_stream);
 
-  void Infer(OrtKernelContext* context) override;
+  void Infer(OrtKernelContext* context) const override;
   ~BasicBackend() override = default;
   ov::CompiledModel GetOVCompiledModel() override {
     return exe_network_.Get();
@@ -77,13 +81,12 @@ class BasicBackend : public IBackend {
   void SetNumThreads(ov::AnyMap& device_config);
 
 #ifdef IO_BUFFER_ENABLED
-  void RemoteInfer(Ort::KernelContext& context, std::shared_ptr<OVInferRequest> infer_request);
+  void RemoteInfer(Ort::KernelContext& context, std::shared_ptr<OVInferRequest> infer_request) const;
 #endif
 
   SessionContext& session_context_;
   SubGraphContext subgraph_context_;
   SharedContext& shared_context_;
-  mutable std::mutex compute_lock_;
   OVExeNetwork exe_network_;
   std::map<std::string, std::shared_ptr<ov::Node>> const_outputs_map_;
   std::unique_ptr<InferRequestPool> infer_req_pool_;
@@ -92,7 +95,7 @@ class BasicBackend : public IBackend {
 #endif
 
   using ort_tensor_key_t = const std::string;
-  std::unique_ptr<OnnxToOvNetworkBindings> bindings_;
+  std::unique_ptr<const OnnxToOvNetworkBindings> bindings_;
 };
 
 class InferRequestPool {
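
The new has_dynamic_io_ flag is computed once, when the bindings are built, from the partial shapes of the OV parameters. A rough standalone equivalent of that check, written against a plain ov::CompiledModel for illustration; HasDynamicIo is a hypothetical helper, not part of the EP.

#include <openvino/openvino.hpp>

// True if any input or output port of the compiled model carries a dynamic
// dimension; this mirrors the shape.is_dynamic() test added in populate().
bool HasDynamicIo(const ov::CompiledModel& compiled) {
  for (const auto& port : compiled.inputs()) {
    if (port.get_partial_shape().is_dynamic()) return true;
  }
  for (const auto& port : compiled.outputs()) {
    if (port.get_partial_shape().is_dynamic()) return true;
  }
  return false;
}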

onnxruntime/core/providers/openvino/ibackend.h

Lines changed: 2 additions & 2 deletions
@@ -14,8 +14,8 @@ namespace openvino_ep {
 
 class IBackend {
  public:
-  virtual void Infer(OrtKernelContext* context) = 0;
-  virtual ov::CompiledModel GetOVCompiledModel() = 0;
+  virtual void Infer(OrtKernelContext* context) const = 0;
+  virtual ov::CompiledModel& GetOVCompiledModel() = 0;
   virtual ~IBackend() = default;
 };
 using ptr_stream_t = std::unique_ptr<std::istream>;
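
Qualifying Infer() as const across the interface (together with dropping compute_lock_ in basic_backend.h) encodes the threading contract in the type system: a backend may be invoked concurrently, provided any mutable state it touches is internally synchronized. A toy sketch of that contract follows; ToyBackend is illustrative and unrelated to onnxruntime's actual types.

#include <atomic>
#include <thread>
#include <vector>

// A const Infer() may be called from many threads at once; the only mutable
// member is internally synchronized (an atomic counter standing in for the
// infer-request pool).
struct ToyBackend {
  void Infer(int /*input*/) const { ++served_; }
  mutable std::atomic<int> served_{0};
};

int main() {
  ToyBackend backend;
  std::vector<std::thread> workers;
  for (int i = 0; i < 4; ++i)
    workers.emplace_back([&backend, i] { backend.Infer(i); });
  for (auto& t : workers) t.join();
  return backend.served_ == 4 ? 0 : 1;
}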

onnxruntime/core/providers/openvino/ov_interface.h

Lines changed: 8 additions & 8 deletions
@@ -38,7 +38,6 @@ typedef ov::intel_gpu::ocl::ClContext* OVRemoteContextPtr;
 typedef ov::RemoteContext OVRemoteContext;
 #endif
 
-
 struct ParameterShape {
   using onnx_shape_t = std::vector<int64_t>;
 
@@ -55,16 +54,12 @@ struct ParameterShape {
     return ov::PartialShape(ov_shape);
   }
 
-  static ov::Shape ToOvShape(const onnx_shape_t& onnx_shape) {
-    return ToOvPartialShape(onnx_shape).get_shape();
-  }
-
   static onnx_shape_t ToOnnxShape(const ov::PartialShape& ov_shape) {
-      onnx_shape_t onnx_shape(ov_shape.size());
+    onnx_shape_t onnx_shape(ov_shape.size());
     std::transform(ov_shape.begin(), ov_shape.end(), onnx_shape.begin(), [](const auto& dim) {
       return dim.is_dynamic() ? -1 : dim.get_length();
     });
-      return onnx_shape;
+    return onnx_shape;
   }
 
   static bool IsDynamic(const ov::PartialShape& ov_shape) {
@@ -189,9 +184,14 @@ class OVInferRequest {
 
   // Set tensor described param_info and ort_ptr. Call infer req tensor if ort_ptr is last set.
   void SetTensor(const ParameterInfo& param_info, void* ort_ptr) {
+    SetTensorShapeOverride(param_info, param_info.shape, ort_ptr);
+  }
+
+  // Set tensor described param_info and ort_ptr. Overrides shape in param_info with shape_override. Call infer req tensor if ort_ptr is last set.
+  void SetTensorShapeOverride(const ParameterInfo& param_info, const ParameterShape& shape_override, void* ort_ptr) {
     auto& cached_binding = bindings_cache_[param_info.name];
     if (cached_binding.ort_ptr != ort_ptr) {
-      auto tensor_ptr = std::make_shared<ov::Tensor>(param_info.type, param_info.shape.ov_shape(), const_cast<void*>(ort_ptr));
+      auto tensor_ptr = std::make_shared<ov::Tensor>(param_info.type, shape_override.ov_shape(), const_cast<void*>(ort_ptr));
       SetTensor(param_info.name, tensor_ptr);
       cached_binding = {tensor_ptr, ort_ptr};
     }
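
The SetTensor/SetTensorShapeOverride pair caches one binding per parameter name and only re-wraps the buffer when ORT supplies a different pointer, so repeated runs over the same buffers skip tensor construction entirely. A condensed sketch of that caching idea; TensorBindingCache and Bind are illustrative names, not the EP's API.

#include <memory>
#include <string>
#include <unordered_map>
#include <openvino/openvino.hpp>

// Rebind a named input/output only when the caller's buffer pointer changes;
// otherwise the previously created ov::Tensor wrapper is still valid.
// Rebinding is keyed on the pointer only, matching the cache in the diff above.
struct TensorBindingCache {
  struct Binding {
    std::shared_ptr<ov::Tensor> tensor;
    void* ort_ptr = nullptr;
  };

  void Bind(ov::InferRequest& request, const std::string& name,
            ov::element::Type type, const ov::Shape& shape, void* ort_ptr) {
    auto& cached = cache_[name];
    if (cached.ort_ptr != ort_ptr) {
      cached.tensor = std::make_shared<ov::Tensor>(type, shape, ort_ptr);
      cached.ort_ptr = ort_ptr;
      request.set_tensor(name, *cached.tensor);
    }
  }

  std::unordered_map<std::string, Binding> cache_;
};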
