@@ -341,6 +341,7 @@ static bool CheckDQRuleSet(const NodeUnit& node_unit,
341
341
}
342
342
}
343
343
344
+ // this check is if QLinear node feed into the output of src graph which expects quantized output
344
345
static bool CheckQFeedsIntoQuantizedOutput (const NodeUnit& node_unit,
345
346
const std::unordered_map<std::string, std::string> graph_op_data_type) {
346
347
auto op_of_quantized_layer = node_unit.Outputs ();
@@ -447,9 +448,17 @@ static bool HandleDoubleQDQ(onnxruntime::Graph& dst_graph, const onnxruntime::Gr
447
448
static void AddStandaloneNodeUnit (onnxruntime::Graph& dst_graph, const onnxruntime::GraphViewer& src_graph,
448
449
const NodeUnit& node_unit,
449
450
std::set<std::string>& initializers_to_keep,
450
- const logging::Logger& /* logger */ ) {
451
+ const logging::Logger& /* logger */ ,
452
+ bool IsWeightSharingWithoutOVEPQDQStripping) {
451
453
assert (node_unit.UnitType () == NodeUnit::Type::SingleNode);
452
454
455
+ // this is the scenario where WAI is enabled and ovep stripping is disabled
456
+ // do not strip off any Q or DQ node
457
+ if (IsWeightSharingWithoutOVEPQDQStripping) {
458
+ AddNode (initializers_to_keep, src_graph, dst_graph, node_unit.GetNode ());
459
+ return ;
460
+ }
461
+
453
462
if (HandleDoubleQDQ (dst_graph, src_graph, node_unit, initializers_to_keep)) return ;
454
463
455
464
auto add_identity_op = [&](bool duplicate_dq) {
@@ -511,7 +520,8 @@ static void AddQDQNodeUnit(onnxruntime::Graph& dst_graph,
511
520
const onnxruntime::GraphViewer& src_graph,
512
521
const NodeUnit& node_unit,
513
522
std::set<std::string>& initializers_to_keep,
514
- const logging::Logger& /* logger */ ) {
523
+ const logging::Logger& /* logger */ ,
524
+ bool IsWeightSharingWithoutOVEPQDQStripping) {
515
525
assert (node_unit.UnitType () == NodeUnit::Type::QDQGroup);
516
526
517
527
// Collect inputs coming into the node unit.
@@ -529,7 +539,7 @@ static void AddQDQNodeUnit(onnxruntime::Graph& dst_graph,
529
539
SkipReason reason = SkipReason::Other;
530
540
bool keep_dq = CheckDQRuleSet (node_unit, dq_node, src_graph, reason);
531
541
532
- if (keep_dq) {
542
+ if (IsWeightSharingWithoutOVEPQDQStripping || keep_dq) {
533
543
AddNode (initializers_to_keep, src_graph, dst_graph, *dq_node);
534
544
dq_node_args_to_keep.insert ({input_defs.at (0 )->Name (),
535
545
&dst_graph.GetOrCreateNodeArg (dq_node->OutputDefs ().at (0 )->Name (),
@@ -597,7 +607,7 @@ static void AddQDQNodeUnit(onnxruntime::Graph& dst_graph,
597
607
598
608
bool keep_q = CheckQRuleSet (node_unit, q_node, src_graph, reason);
599
609
600
- if (keep_q) {
610
+ if (IsWeightSharingWithoutOVEPQDQStripping || keep_q) {
601
611
AddNode (initializers_to_keep, src_graph, dst_graph, *q_node);
602
612
// if keep_q, then output defs of the target node doesn't change
603
613
output_args.push_back (&dst_graph.GetOrCreateNodeArg (target_node.OutputDefs ().at (i)->Name (),
@@ -675,7 +685,8 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph,
675
685
const logging::Logger& logger,
676
686
bool enable_ovep_weight_sharing,
677
687
/* out*/ std::unique_ptr<onnxruntime::Model>& model,
678
- /* out*/ sw& shared_weights) {
688
+ /* out*/ sw& shared_weights,
689
+ bool enable_ovep_qdq_optimizer) {
679
690
// NOTE: This function is a re-implementation of GraphViewerToProto() in core/graph/graph_proto_serializer.cc
680
691
// with the following differences:
681
692
// - Uses onnxruntime::Graph APIs instead of onnx::GraphProto APIs.
@@ -766,10 +777,12 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph,
766
777
continue ; // Already handled this node unit
767
778
}
768
779
780
+ bool IsWeightSharingWithoutOVEPQDQStripping = enable_ovep_weight_sharing && !enable_ovep_qdq_optimizer;
781
+
769
782
if (node_unit->UnitType () == NodeUnit::Type::SingleNode) {
770
- AddStandaloneNodeUnit (dst_graph, src_graph, *node_unit, initializers_to_keep, logger);
783
+ AddStandaloneNodeUnit (dst_graph, src_graph, *node_unit, initializers_to_keep, logger, IsWeightSharingWithoutOVEPQDQStripping );
771
784
} else {
772
- AddQDQNodeUnit (dst_graph, src_graph, *node_unit, initializers_to_keep, logger);
785
+ AddQDQNodeUnit (dst_graph, src_graph, *node_unit, initializers_to_keep, logger, IsWeightSharingWithoutOVEPQDQStripping );
773
786
}
774
787
775
788
seen_node_units.insert (node_unit);
0 commit comments