@@ -72,7 +72,7 @@ static bool sameTableIDs(const std::unordered_set<table_id_t>& set,
72
72
return true ;
73
73
}
74
74
75
- static bool haveSameTableIDs (const std::vector<const LogicalOperator*>& ops,
75
+ static bool haveSameTableIDs (const std::vector<LogicalOperator*>& ops,
76
76
SemiMaskTargetType targetType) {
77
77
std::unordered_set<table_id_t > tableIDSet;
78
78
for (auto id : getTableIDs (ops[0 ], targetType)) {
@@ -86,7 +86,7 @@ static bool haveSameTableIDs(const std::vector<const LogicalOperator*>& ops,
86
86
return true ;
87
87
}
88
88
89
- static bool haveSameType (const std::vector<const LogicalOperator*>& ops) {
89
+ static bool haveSameType (const std::vector<LogicalOperator*>& ops) {
90
90
for (auto i = 0u ; i < ops.size (); ++i) {
91
91
if (ops[i]->getOperatorType () != ops[0 ]->getOperatorType ()) {
92
92
return false ;
@@ -95,7 +95,7 @@ static bool haveSameType(const std::vector<const LogicalOperator*>& ops) {
95
95
return true ;
96
96
}
97
97
98
- bool sanityCheckCandidates (const std::vector<const LogicalOperator*>& ops,
98
+ bool sanityCheckCandidates (const std::vector<LogicalOperator*>& ops,
99
99
SemiMaskTargetType targetType) {
100
100
KU_ASSERT (!ops.empty ());
101
101
if (!haveSameType (ops)) {
@@ -109,10 +109,13 @@ bool sanityCheckCandidates(const std::vector<const LogicalOperator*>& ops,
109
109
110
110
static std::shared_ptr<LogicalSemiMasker> appendSemiMasker (SemiMaskKeyType keyType,
111
111
SemiMaskTargetType targetType, std::shared_ptr<Expression> key,
112
- std::vector<const LogicalOperator*> candidates, std::shared_ptr<LogicalOperator> child) {
112
+ std::vector<LogicalOperator*> candidates, std::shared_ptr<LogicalOperator> child) {
113
113
auto tableIDs = getTableIDs (candidates[0 ], targetType);
114
114
auto semiMasker =
115
- std::make_shared<LogicalSemiMasker>(keyType, targetType, key, tableIDs, candidates, child);
115
+ std::make_shared<LogicalSemiMasker>(keyType, targetType, key, tableIDs, child);
116
+ for (auto candidate : candidates) {
117
+ semiMasker->addTarget (candidate);
118
+ }
116
119
semiMasker->computeFlatSchema ();
117
120
return semiMasker;
118
121
}
@@ -152,9 +155,9 @@ static bool isProbeSideQualified(LogicalOperator* probeRoot) {
152
155
153
156
// Find all ScanNodeIDs under root which scans parameter nodeID. Note that there might be
154
157
// multiple ScanNodeIDs matches because both node and rel table scans will trigger scanNodeIDs.
155
- static std::vector<const LogicalOperator*> getScanNodeCandidates (const Expression& nodeID,
158
+ static std::vector<LogicalOperator*> getScanNodeCandidates (const Expression& nodeID,
156
159
LogicalOperator* root) {
157
- std::vector<const LogicalOperator*> result;
160
+ std::vector<LogicalOperator*> result;
158
161
auto collector = LogicalScanNodeTableCollector ();
159
162
collector.collect (root);
160
163
for (auto & op : collector.getOperators ()) {
@@ -170,9 +173,9 @@ static std::vector<const LogicalOperator*> getScanNodeCandidates(const Expressio
170
173
return result;
171
174
}
172
175
173
- static std::vector<const LogicalOperator*> getRecursiveExtendInputNodeCandidates (
174
- const Expression& nodeID, LogicalOperator* root) {
175
- std::vector<const LogicalOperator*> result;
176
+ static std::vector<LogicalOperator*> getRecursiveExtendInputNodeCandidates (const Expression& nodeID,
177
+ LogicalOperator* root) {
178
+ std::vector<LogicalOperator*> result;
176
179
auto collector = LogicalRecursiveExtendCollector ();
177
180
collector.collect (root);
178
181
for (auto & op : collector.getOperators ()) {
@@ -185,9 +188,9 @@ static std::vector<const LogicalOperator*> getRecursiveExtendInputNodeCandidates
185
188
return result;
186
189
}
187
190
188
- static std::vector<const LogicalOperator*> getRecursiveExtendOutputNodeCandidates (
191
+ static std::vector<LogicalOperator*> getRecursiveExtendOutputNodeCandidates (
189
192
const Expression& nodeID, LogicalOperator* root) {
190
- std::vector<const LogicalOperator*> result;
193
+ std::vector<LogicalOperator*> result;
191
194
auto collector = LogicalRecursiveExtendCollector ();
192
195
collector.collect (root);
193
196
for (auto & op : collector.getOperators ()) {
@@ -207,13 +210,19 @@ static std::shared_ptr<LogicalOperator> tryApplySemiMask(std::shared_ptr<Express
207
210
auto recursiveExtendInputNodeCandidates =
208
211
getRecursiveExtendInputNodeCandidates (*nodeID, toRoot);
209
212
if (!recursiveExtendInputNodeCandidates.empty ()) {
213
+ for (auto & op : recursiveExtendInputNodeCandidates) {
214
+ op->cast <LogicalRecursiveExtend>().setInputNodeMask ();
215
+ }
210
216
auto targetType = SemiMaskTargetType::RECURSIVE_EXTEND_INPUT_NODE;
211
217
KU_ASSERT (sanityCheckCandidates (recursiveExtendInputNodeCandidates, targetType));
212
218
return appendSemiMasker (SemiMaskKeyType::NODE, targetType, std::move (nodeID),
213
219
recursiveExtendInputNodeCandidates, std::move (fromRoot));
214
220
}
215
221
auto recursiveExtendNodeCandidates = getRecursiveExtendOutputNodeCandidates (*nodeID, toRoot);
216
222
if (!recursiveExtendNodeCandidates.empty ()) {
223
+ for (auto & op : recursiveExtendNodeCandidates) {
224
+ op->cast <LogicalRecursiveExtend>().setOutputNodeMask ();
225
+ }
217
226
auto targetType = SemiMaskTargetType::RECURSIVE_EXTEND_OUTPUT_NODE;
218
227
KU_ASSERT (sanityCheckCandidates (recursiveExtendNodeCandidates, targetType));
219
228
return appendSemiMasker (SemiMaskKeyType::NODE, targetType, std::move (nodeID),
@@ -330,7 +339,7 @@ void HashJoinSIPOptimizer::visitIntersect(LogicalOperator* op) {
330
339
auto probeRoot = intersect.getChild (0 );
331
340
auto hasSemiMaskApplied = false ;
332
341
for (auto & nodeID : intersect.getKeyNodeIDs ()) {
333
- std::vector<const LogicalOperator*> ops;
342
+ std::vector<LogicalOperator*> ops;
334
343
for (auto i = 1u ; i < intersect.getNumChildren (); ++i) {
335
344
auto buildRoot = intersect.getChild (i);
336
345
for (auto & op_ : getScanNodeCandidates (*nodeID, buildRoot.get ())) {
@@ -367,7 +376,7 @@ void HashJoinSIPOptimizer::visitPathPropertyProbe(LogicalOperator* op) {
367
376
}
368
377
auto recursiveRel = pathPropertyProbe.getRel ();
369
378
auto nodeID = recursiveRel->getRecursiveInfo ()->node ->getInternalID ();
370
- std::vector<const LogicalOperator*> opsToApplySemiMask;
379
+ std::vector<LogicalOperator*> opsToApplySemiMask;
371
380
if (pathPropertyProbe.getNodeChild () != nullptr ) {
372
381
auto child = pathPropertyProbe.getNodeChild ().get ();
373
382
for (auto op_ : getScanNodeCandidates (*nodeID, child)) {
0 commit comments