@@ -70,6 +70,9 @@ BackendManager::BackendManager(SessionContext& session_context,
   // Save the indexes of graph inputs among fused_node's inputDefs
   // (which also contains initializers).
   for (uint32_t index = 0; const auto& node : subgraph.GetInputs()) {
+    if (subgraph.GetGraph().GetConsumerNodes(node->Name()).size() == 0) {
+      continue;  // Skip if the input is a dangling node
+    }
     subgraph_context_.input_names.insert({node->Name(), index++});
   }
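
Note on the hunk above: Graph::GetConsumerNodes() returns every node that reads a given
NodeArg, so a graph input with an empty consumer list is dangling after partitioning and
is deliberately left out of subgraph_context_.input_names. A minimal sketch of the same
check, reusing only the calls already present in the patch (variable names are
illustrative, not part of the change):

    // Sketch: skip fused-subgraph inputs that no node consumes.
    for (const auto* input : subgraph.GetInputs()) {
      if (subgraph.GetGraph().GetConsumerNodes(input->Name()).empty()) {
        continue;  // dangling input, never feeds the compiled OpenVINO graph
      }
      // ...record the input index, as the constructor does above...
    }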
@@ -110,7 +113,7 @@ BackendManager::BackendManager(SessionContext& session_context,
     subgraph_context_.has_dynamic_input_shape = true;
     LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has symbolic input dims";
     if (cpu_or_gpu || (npu && session_context_.enable_causallm) &&
-            !session_context_.disable_dynamic_shapes) {
+        !session_context_.disable_dynamic_shapes) {
       LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Starting backend initialization. "
                          << "Creating backend Dynamic Shapes";
       try {
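
A small aside on the guard in this hunk, since the change is only a re-indentation: in
C++, && binds tighter than ||, so the condition groups as sketched below (an explicitly
parenthesized restatement for readability; the bool name is illustrative, the operands
are the ones in the hunk):

    bool build_dynamic_backend =
        cpu_or_gpu ||
        ((npu && session_context_.enable_causallm) && !session_context_.disable_dynamic_shapes);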
@@ -291,24 +294,83 @@ bool BackendManager::ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& mod
 }

 bool BackendManager::ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& subgraph) const {
-  bool has_sym_dims = false;
-  auto graph_inputs = subgraph.GetInputs();
-  for (auto input : graph_inputs) {
+  const auto& graph_inputs = subgraph.GetInputs();
+
+  // First validate shapes if provided by user
+  bool shapes_valid = true;
+  if (!session_context_.reshape.empty()) {
+    try {
+      ValidateInputShapes(session_context_.reshape, graph_inputs);
+    } catch (const std::exception& e) {
+      LOGS_DEFAULT(ERROR) << "[OpenVINO-EP] Shape validation failed: " << e.what();
+      session_context_.reshape.clear();  // Clear the shape map as it's invalid
+      shapes_valid = false;
+    }
+  }
+
+  // Count dynamic inputs and check if reshape covers all of them
+  size_t dynamic_input_count = 0;
+  bool all_dynamic_inputs_covered = true;
+
+  for (const auto* input : graph_inputs) {
+    // Skip dangling inputs (no consumers)
+    if (subgraph.GetGraph().GetConsumerNodes(input->Name()).empty()) {
+      continue;
+    }
+
+    // Check if input has dynamic dimensions
+    bool has_dynamic_dim = false;
+
+    // Case 1: Completely undefined shape
     if (input->Shape() == nullptr) {
-      has_sym_dims = true;
-      break;
+      has_dynamic_dim = true;
     }
-    for (auto& dim : input->Shape()->dim()) {
-      if (dim.value_case() != dim.kDimValue) {
-        has_sym_dims = true;
-        break;
+    // Case 2: Shape defined but with symbolic dimensions
+    else {
+      for (const auto& dim : input->Shape()->dim()) {
+        if (dim.value_case() != dim.kDimValue) {
+          has_dynamic_dim = true;
+          break;
+        }
       }
     }
-    if (has_sym_dims) {
-      break;
+
+    // If dynamic, count it and check if reshape covers it
+    if (has_dynamic_dim) {
+      dynamic_input_count++;
+
+      // Check if this dynamic input is covered by reshape input
+      if (!session_context_.reshape.empty() &&
+          session_context_.reshape.find(input->Name()) == session_context_.reshape.end()) {
+        all_dynamic_inputs_covered = false;
+        LOGS_DEFAULT(WARNING) << "[OpenVINO-EP] reshape_input is provided but doesn't cover dynamic input: "
+                              << input->Name();
+      }
     }
   }
-  return has_sym_dims;
+
+  const bool has_symbolic_dims = (dynamic_input_count > 0);
+
+  // Early return if no reshape input provided
+  if (session_context_.reshape.empty()) {
+    return has_symbolic_dims;  // Return based on whether model has symbolic dims
+  }
+
+  // For dynamic models with incomplete reshape coverage, clear shapes
+  if (has_symbolic_dims && !all_dynamic_inputs_covered) {
+    session_context_.reshape.clear();
+    LOGS_DEFAULT(WARNING) << "reshape_input does not cover all dynamic dimensions, "
+                          << "ignoring all provided shapes";
+    return true;  // Model is dynamic
+  }
+
+  // If shapes are valid with complete coverage for dynamic model, treat as concrete
+  if (has_symbolic_dims && shapes_valid && all_dynamic_inputs_covered) {
+    LOGS_DEFAULT(INFO) << "All dynamic dimensions successfully covered by reshape_input";
+    return false;  // Model is now effectively static with concrete shapes
+  }
+
+  return has_symbolic_dims;  // Return dynamic status based on symbolic dimensions
 }

 // Check to see if the graph is QDQ
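
The rewrite above folds the reshape_input session option into the dynamic-shape decision.
A condensed restatement of the return logic, using the same flags the new code computes
(has_symbolic_dims, shapes_valid, all_dynamic_inputs_covered); this is a reading aid, not
the patch itself:

    // no reshape_input       -> report whatever the graph declares
    // incomplete coverage    -> drop the user shapes, model stays dynamic
    // valid + full coverage  -> model is treated as static
    if (session_context_.reshape.empty()) return has_symbolic_dims;
    if (has_symbolic_dims && !all_dynamic_inputs_covered) {
      session_context_.reshape.clear();
      return true;
    }
    if (has_symbolic_dims && shapes_valid && all_dynamic_inputs_covered) return false;
    return has_symbolic_dims;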
@@ -386,7 +448,7 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node,
   const auto& onnx_model_path_name = subgraph.ModelPath();
   // QDQ stripping enabled only for the NPU and experimentally on the GPU
   if ((session_context_.device_type.find("NPU") != std::string::npos ||
-      session_context_.device_type.find("GPU") != std::string::npos) &&
+       session_context_.device_type.find("GPU") != std::string::npos) &&
       (enable_ovep_qdq_optimizer || session_context_.so_share_ep_contexts)) {
     std::unique_ptr<onnxruntime::Model> model;
     Status status = CreateModelWithStrippedQDQNodes(subgraph, logger, session_context_.so_share_ep_contexts, enable_ovep_qdq_optimizer, model, shared_context_.shared_weights);
@@ -480,6 +542,40 @@ BackendManager::ReWriteBatchDimWithOne(const ONNX_NAMESPACE::ModelProto& model_p
   return model_copy;
 }

+void BackendManager::ValidateInputShapes(const reshape_t& shapes,
+                                         const std::vector<const NodeArg*>& graph_inputs) const {
+  for (const auto& [tensor_name, requested_shape] : shapes) {
+    // Find matching input in graph
+    const NodeArg* graph_input = nullptr;
+    for (const auto* input : graph_inputs) {
+      if (input->Name() == tensor_name) {
+        graph_input = input;
+        break;
+      }
+    }
+
+    if (!graph_input) {
+      ORT_THROW("Input '" + tensor_name + "' specified in reshape_input does not exist in the graph");
+    }
+
+    const ONNX_NAMESPACE::TensorShapeProto* graph_shape = graph_input->Shape();
+    if (!graph_shape) {
+      ORT_THROW("Graph input '" + tensor_name + "' has no shape information");
+    }
+
+    // Check dimensions count matches
+    size_t graph_dim_count = graph_shape->dim_size();
+    size_t requested_dim_count = requested_shape.get_max_shape().size();
+
+    if (graph_dim_count != requested_dim_count) {
+      ORT_THROW("Dimensions mismatch for input '" + tensor_name +
+                "': graph expects " + std::to_string(graph_dim_count) +
+                " dimensions but reshape_input specifies " +
+                std::to_string(requested_dim_count) + " dimensions");
+    }
+  }
+}
+
 void BackendManager::Compute(OrtKernelContext* context) {
   Ort::KernelContext ctx(context);
   std::chrono::high_resolution_clock::time_point start_compute, end_compute;
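
For context on ValidateInputShapes() above: reshape_t maps a tensor name to the shape the
user supplied through the provider's reshape_input option, and the get_max_shape() call
suggests the mapped value behaves like an OpenVINO partial shape. The snippet below is a
hypothetical usage sketch under that assumption (the tensor name and dimensions are made
up); it is not part of the patch:

    reshape_t shapes;
    shapes["input_ids"] = ov::PartialShape{1, 128};  // user requests rank-2 concrete dims
    // If the graph declares "input_ids" with a different rank, ValidateInputShapes()
    // throws a "Dimensions mismatch" error; ModelHasSymbolicInputDims() catches it and
    // clears session_context_.reshape, falling back to dynamic-shape handling.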