@@ -361,7 +361,7 @@ void BasicBackend::SetNumThreads(ov::AnyMap& device_config) {
 
 #ifdef IO_BUFFER_ENABLED
 // Wait for Remote Asynchronous inference completion
-void BasicBackend::RemoteInfer(Ort::KernelContext& context, OVInferRequestPtr infer_request) {
+void BasicBackend::RemoteInfer(Ort::KernelContext& context, OVInferRequestPtr infer_request) const {
   try {
     auto graph_input_info = exe_network_.Get().inputs();
     int input_idx = 0;
@@ -467,7 +467,7 @@ void BasicBackend::RemoteInfer(Ort::KernelContext& context, OVInferRequestPtr in
 }
 #endif
 
-void BasicBackend::Infer(OrtKernelContext* ctx) {
+void BasicBackend::Infer(OrtKernelContext* ctx) const {
   Ort::KernelContext context(ctx);
 
   LOGS_DEFAULT(INFO) << log_tag << "Running graph " << subgraph_context_.subgraph_name;
@@ -492,43 +492,64 @@ void BasicBackend::Infer(OrtKernelContext* ctx) {
     return;
   }
 
-  bool gpu = session_context_.device_type.find("GPU") != std::string::npos;
-  bool cpu_or_gpu = gpu || (session_context_.device_type.find("CPU") != std::string::npos);
-
   // guarded_request will be released back to the pool when it goes out of scope
   auto guarded_request = infer_req_pool_->getRequest();
   auto& infer_request = guarded_request.infer_request_;
 #ifdef IO_BUFFER_ENABLED
-  if (gpu &&
+  if (session_context_.device_type.find("GPU") != std::string::npos &&
       (session_context_.context != nullptr) && session_context_.is_wholly_supported_graph) {
     RemoteInfer(context, infer_request);
   } else
 #else
   {  // scope for else if IO_BUFFER_ENABLED
 
-    // Bind inputs
-    for (const auto& input_info : bindings_->network_inputs_) {
-      if (subgraph_context_.has_dynamic_input_shape &&
-          !session_context_.disable_dynamic_shapes &&
-          cpu_or_gpu) {
-        // copy the input to set current shape.
-        auto input_info_copy = input_info;
+    if (bindings_->has_dynamic_io_ ||
+        (subgraph_context_.has_dynamic_input_shape &&
+         !session_context_.disable_dynamic_shapes)) {
+      // Dynamic shape inference
+
+      // We don't know the output shapes so we need to get the outputs from the infer request and copy them into the ort
+      // tensors instead of binding them to the infer request directly.
+
+      // Bind inputs
+      for (const auto& input_info : bindings_->network_inputs_) {
+        // Set the input shape based on the input tensor from ort
         auto tensor = context.GetInput(input_info.onnx_index);
-        input_info_copy.shape = ParameterShape(tensor.GetTensorTypeAndShapeInfo().GetShape());
+        auto input_shape = ParameterShape(tensor.GetTensorTypeAndShapeInfo().GetShape());
 
-        infer_request->SetTensor(input_info_copy, const_cast<void*>(tensor.GetTensorRawData()));
-      } else {
+        infer_request->SetTensorShapeOverride(input_info, input_shape, const_cast<void*>(tensor.GetTensorRawData()));
+      }
+
+      // Run Inference
+      infer_request->Infer();
+
+      // Copy outputs
+      for (const auto& output_info : bindings_->network_outputs_) {
+        auto ov_tensor = infer_request->GetTensor(output_info.name);
+        auto output_shape = ParameterShape::ToOnnxShape(ov_tensor->get_shape());
+        auto ort_tensor = context.GetOutput(output_info.onnx_index, output_shape);
+
+        memcpy_s(ort_tensor.GetTensorMutableRawData(),
+                 ort_tensor.GetTensorSizeInBytes(),
+                 ov_tensor->data(),
+                 ov_tensor->get_byte_size());
+      }
+    } else {
+      // Static shape inference
+
+      // Bind inputs
+      for (const auto& input_info : bindings_->network_inputs_) {
         infer_request->SetTensor(input_info, const_cast<void*>(context.GetInput(input_info.onnx_index).GetTensorRawData()));
       }
-    }
 
-    // Bind outputs
-    for (const auto& output_info : bindings_->network_outputs_) {
-      infer_request->SetTensor(output_info, context.GetOutput(output_info.onnx_index, output_info.shape.onnx()).GetTensorMutableRawData());
-    }
+      // Bind outputs
+      for (const auto& output_info : bindings_->network_outputs_) {
+        infer_request->SetTensor(output_info, context.GetOutput(output_info.onnx_index, output_info.shape.onnx()).GetTensorMutableRawData());
+      }
 
-    // Run Inference
-    infer_request->Infer();
+      // Run Inference
+      infer_request->Infer();
+    }
 
     // Fill constant outputs if needed
     for (const auto& [name, node] : const_outputs_map_) {
@@ -552,7 +573,7 @@ void BasicBackend::Infer(OrtKernelContext* ctx) {
 
 #ifndef NDEBUG
 #ifndef IO_BUFFER_ENABLED
-  // Print performance counts before releasing the infer_request for potential thread safety
+  // Print performance counts before releasing the infer_request for thread safety
   if (openvino_ep::backend_utils::IsDebugEnabled()) {
     std::string& hw_target = session_context_.device_type;
     printPerformanceCounts(infer_request, std::cout, hw_target);
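
The core behavioral change in the hunk at old line 492 is that dynamic-shape subgraphs now take a separate path: each input is bound together with a per-request shape override, inference runs, and only afterwards are the outputs copied into the ORT output tensors, because their shapes are only known once Infer() has produced them. The standalone C++ sketch below illustrates that ordering; FakeRequest, FakeTensor, and their members are hypothetical stand-ins, not the backend's or OpenVINO's API.

// Minimal sketch of the dynamic-shape flow in this diff (not the real backend):
// the output shape is only known after Infer(), so the caller queries the
// backend tensor afterwards, allocates the output of that shape, and copies
// the bytes out. All names here are illustrative only.
#include <cstring>
#include <iostream>
#include <vector>

using Shape = std::vector<size_t>;

struct FakeTensor {
  Shape shape;
  std::vector<float> data;
};

struct FakeRequest {
  FakeTensor input, output;

  // Dynamic path: the input shape is taken from the caller's tensor at request time.
  void SetInput(const Shape& shape, const float* data, size_t count) {
    input.shape = shape;
    input.data.assign(data, data + count);
  }

  // Pretend inference: the output shape depends on the input, so it is only
  // known after this call (here it simply mirrors the input shape).
  void Infer() {
    output.shape = input.shape;
    output.data = input.data;
    for (auto& v : output.data) v *= 2.0f;
  }

  const FakeTensor& GetOutput() const { return output; }
};

int main() {
  std::vector<float> ort_input = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
  Shape ort_shape = {2, 3};  // the shape arrives with the request, not at compile time

  FakeRequest request;
  request.SetInput(ort_shape, ort_input.data(), ort_input.size());
  request.Infer();

  // Only now is the output shape known; allocate the destination buffer and copy.
  const FakeTensor& out = request.GetOutput();
  std::vector<float> ort_output(out.data.size());
  std::memcpy(ort_output.data(), out.data.data(), out.data.size() * sizeof(float));

  std::cout << "output[0] = " << ort_output[0] << ", rank = " << out.shape.size() << "\n";
}

The static path in the diff can keep binding output buffers up front precisely because their shapes are known before Infer() is called, which avoids the extra copy.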
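
The same hunk also leans on infer_req_pool_->getRequest() returning a guard whose destruction hands the request back ("guarded_request will be released back to the pool when it goes out of scope"). A minimal sketch of that RAII pattern follows; SimpleRequestPool, GuardedRequest, and FakeInferRequest are illustrative names only, and the sketch assumes the pool is never exhausted, which a real pool would have to handle.

// Minimal sketch of an RAII-guarded request pool (not the ONNX Runtime implementation).
#include <iostream>
#include <memory>
#include <mutex>
#include <vector>

struct FakeInferRequest {
  int id;
  void Infer() { std::cout << "running inference on request " << id << "\n"; }
};

class SimpleRequestPool {
 public:
  explicit SimpleRequestPool(size_t n) {
    for (size_t i = 0; i < n; ++i) {
      free_.push_back(std::make_shared<FakeInferRequest>(FakeInferRequest{static_cast<int>(i)}));
    }
  }

  // RAII guard: hands the request back to the pool when it goes out of scope.
  struct GuardedRequest {
    std::shared_ptr<FakeInferRequest> infer_request_;
    SimpleRequestPool* pool_ = nullptr;
    ~GuardedRequest() {
      if (pool_ && infer_request_) pool_->putBack(infer_request_);
    }
  };

  GuardedRequest getRequest() {
    std::lock_guard<std::mutex> lock(mutex_);
    auto req = free_.back();  // sketch assumption: the pool is never empty
    free_.pop_back();
    return GuardedRequest{req, this};
  }

 private:
  void putBack(std::shared_ptr<FakeInferRequest> req) {
    std::lock_guard<std::mutex> lock(mutex_);
    free_.push_back(std::move(req));
  }

  std::mutex mutex_;
  std::vector<std::shared_ptr<FakeInferRequest>> free_;
};

int main() {
  SimpleRequestPool pool(2);
  {
    auto guarded_request = pool.getRequest();  // borrow a request from the pool
    auto& infer_request = guarded_request.infer_request_;
    infer_request->Infer();
  }  // guard destroyed here: the request is returned to the pool
}

This is also why the performance counters in the last hunk are printed before the guard goes out of scope: once the guard is destroyed, another thread may pick up the same request from the pool.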