Skip to content

Commit b321100

Browse files
andyfengHKU and ray6080 authored
Remove enabled check from semi mask (#5221)
* Remove unnecessary enabled check from NodeOffsetSemiMask
* Run clang-format

---------

Co-authored-by: CI Bot <[email protected]>
Co-authored-by: Guodong Jin <[email protected]>
Co-authored-by: CI Bot <[email protected]>
1 parent 43c0cbd commit b321100

File tree

13 files changed

+97
-109
lines changed

13 files changed

+97
-109
lines changed

src/function/gds/gds_utils.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,7 @@ void GDSUtils::runFrontiersUntilConvergence(ExecutionContext* context, GDSComput
9999
compState.edgeCompute->resetSingleThreadState();
100100
while (frontierPair->continueNextIter(maxIteration)) {
101101
frontierPair->beginNewIteration();
102-
if (outputNodeMask != nullptr && outputNodeMask->enabled() &&
103-
compState.edgeCompute->terminate(*outputNodeMask)) {
102+
if (outputNodeMask != nullptr && compState.edgeCompute->terminate(*outputNodeMask)) {
104103
break;
105104
}
106105
runOnGraph(context, graph, extendDirection, compState, propertyToScan);

src/function/table/table_function.cpp

+9-9
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ std::unique_ptr<TableFuncLocalState> TableFunction::initEmptyLocalState(
3636
}
3737

3838
std::unique_ptr<TableFuncSharedState> TableFunction::initEmptySharedState(
39-
const kuzu::function::TableFuncInitSharedStateInput& /*input*/) {
39+
const TableFuncInitSharedStateInput& /*input*/) {
4040
return std::make_unique<TableFuncSharedState>();
4141
}
4242

@@ -88,15 +88,15 @@ void TableFunction::getLogicalPlan(planner::Planner* planner,
8888
}
8989
}
9090

91-
std::unique_ptr<processor::PhysicalOperator> TableFunction::getPhysicalPlan(
92-
processor::PlanMapper* planMapper, const planner::LogicalOperator* logicalOp) {
93-
std::vector<processor::DataPos> outPosV;
94-
auto& call = logicalOp->constCast<planner::LogicalTableFunctionCall>();
91+
std::unique_ptr<PhysicalOperator> TableFunction::getPhysicalPlan(PlanMapper* planMapper,
92+
const LogicalOperator* logicalOp) {
93+
std::vector<DataPos> outPosV;
94+
auto& call = logicalOp->constCast<LogicalTableFunctionCall>();
9595
auto outSchema = call.getSchema();
9696
for (auto& expr : call.getBindData()->columns) {
9797
outPosV.emplace_back(planMapper->getDataPos(*expr, *outSchema));
9898
}
99-
auto info = processor::TableFunctionCallInfo();
99+
auto info = TableFunctionCallInfo();
100100
info.function = call.getTableFunc();
101101
info.bindData = call.getBindData()->copy();
102102
info.outPosV = outPosV;
@@ -110,9 +110,9 @@ std::unique_ptr<processor::PhysicalOperator> TableFunction::getPhysicalPlan(
110110
logicalSemiMasker->addTarget(logicalOp);
111111
}
112112
}
113-
auto printInfo = std::make_unique<processor::TableFunctionCallPrintInfo>(
114-
call.getTableFunc().name, call.getBindData()->columns);
115-
return std::make_unique<processor::TableFunctionCall>(std::move(info), sharedState,
113+
auto printInfo = std::make_unique<TableFunctionCallPrintInfo>(call.getTableFunc().name,
114+
call.getBindData()->columns);
115+
return std::make_unique<TableFunctionCall>(std::move(info), sharedState,
116116
planMapper->getOperatorID(), std::move(printInfo));
117117
}
118118

src/include/common/mask.h

+1-6
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,7 @@ struct SemiMaskUtil {
4040

4141
class NodeOffsetMaskMap {
4242
public:
43-
NodeOffsetMaskMap() : enabled_{false} {}
44-
45-
void enable() { enabled_ = true; }
46-
bool enabled() const { return enabled_; }
43+
NodeOffsetMaskMap() = default;
4744

4845
offset_t getNumMaskedNode() const;
4946

@@ -88,8 +85,6 @@ class NodeOffsetMaskMap {
8885
private:
8986
table_id_map_t<std::unique_ptr<SemiMask>> maskMap;
9087
SemiMask* pinnedMask = nullptr;
91-
// If mask map is enabled, then some nodes might be masked.
92-
bool enabled_;
9388
};
9489

9590
} // namespace common

src/include/planner/operator/extend/logical_recursive_extend.h

+17-10
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,9 @@ class LogicalRecursiveExtend final : public LogicalOperator {
1111

1212
public:
1313
LogicalRecursiveExtend(std::unique_ptr<function::RJAlgorithm> function,
14-
function::RJBindData bindData, binder::expression_vector resultColumns,
15-
common::table_id_set_t nbrTableIDSet)
14+
const function::RJBindData& bindData, binder::expression_vector resultColumns)
1615
: LogicalOperator{operatorType_}, function{std::move(function)}, bindData{bindData},
17-
resultColumns{std::move(resultColumns)}, nbrTableIDSet{std::move(nbrTableIDSet)},
18-
limitNum{common::INVALID_LIMIT} {}
16+
resultColumns{std::move(resultColumns)}, limitNum{common::INVALID_LIMIT} {}
1917

2018
void computeFlatSchema() override;
2119
void computeFactorizedSchema() override;
@@ -29,12 +27,15 @@ class LogicalRecursiveExtend final : public LogicalOperator {
2927
void setResultColumns(binder::expression_vector exprs) { resultColumns = std::move(exprs); }
3028
binder::expression_vector getResultColumns() const { return resultColumns; }
3129

32-
bool hasNbrTableIDSet() const { return !nbrTableIDSet.empty(); }
33-
common::table_id_set_t getNbrTableIDSet() const { return nbrTableIDSet; }
34-
3530
void setLimitNum(common::offset_t num) { limitNum = num; }
3631
common::offset_t getLimitNum() const { return limitNum; }
3732

33+
bool hasInputNodeMask() const { return hasInputNodeMask_; }
34+
void setInputNodeMask() { hasInputNodeMask_ = true; }
35+
36+
bool hasOutputNodeMask() const { return hasOutputNodeMask_; }
37+
void setOutputNodeMask() { hasOutputNodeMask_ = true; }
38+
3839
bool hasNodePredicate() const { return !children.empty(); }
3940
std::shared_ptr<LogicalOperator> getNodeMaskRoot() const {
4041
if (!children.empty()) {
@@ -46,17 +47,23 @@ class LogicalRecursiveExtend final : public LogicalOperator {
4647
std::string getExpressionsForPrinting() const override { return function->getFunctionName(); }
4748

4849
std::unique_ptr<LogicalOperator> copy() override {
49-
return std::make_unique<LogicalRecursiveExtend>(function->copy(), bindData, resultColumns,
50-
nbrTableIDSet);
50+
auto result =
51+
std::make_unique<LogicalRecursiveExtend>(function->copy(), bindData, resultColumns);
52+
result->limitNum = limitNum;
53+
result->hasInputNodeMask_ = hasInputNodeMask_;
54+
result->hasOutputNodeMask_ = hasOutputNodeMask_;
55+
return result;
5156
}
5257

5358
private:
5459
std::unique_ptr<function::RJAlgorithm> function;
5560
function::RJBindData bindData;
5661
binder::expression_vector resultColumns;
5762

58-
common::table_id_set_t nbrTableIDSet;
5963
common::offset_t limitNum; // TODO: remove this once recursive extend is pipelined.
64+
65+
bool hasInputNodeMask_ = false;
66+
bool hasOutputNodeMask_ = false;
6067
};
6168

6269
} // namespace planner

src/include/planner/operator/sip/logical_semi_masker.h

+1-8
Original file line numberDiff line numberDiff line change
@@ -61,18 +61,11 @@ class LogicalSemiMasker final : public LogicalOperator {
6161
static constexpr LogicalOperatorType type_ = LogicalOperatorType::SEMI_MASKER;
6262

6363
public:
64-
// Constructor that does not specify operators accepting semi masks. Later stage there must be
65-
// logics filling "targetOps" field.
6664
LogicalSemiMasker(SemiMaskKeyType keyType, SemiMaskTargetType targetType,
6765
std::shared_ptr<binder::Expression> key, std::vector<common::table_id_t> nodeTableIDs,
6866
std::shared_ptr<LogicalOperator> child)
69-
: LogicalSemiMasker{keyType, targetType, std::move(key), std::move(nodeTableIDs),
70-
std::vector<const LogicalOperator*>{}, std::move(child)} {}
71-
LogicalSemiMasker(SemiMaskKeyType keyType, SemiMaskTargetType targetType,
72-
std::shared_ptr<binder::Expression> key, std::vector<common::table_id_t> nodeTableIDs,
73-
std::vector<const LogicalOperator*> ops, std::shared_ptr<LogicalOperator> child)
7467
: LogicalOperator{type_, std::move(child)}, keyType{keyType}, targetType{targetType},
75-
key{std::move(key)}, nodeTableIDs{std::move(nodeTableIDs)}, targetOps{std::move(ops)} {}
68+
key{std::move(key)}, nodeTableIDs{std::move(nodeTableIDs)} {}
7669

7770
void computeFactorizedSchema() override { copyChildSchema(0); }
7871
void computeFlatSchema() override { copyChildSchema(0); }

src/include/processor/operator/recursive_extend_shared_state.h

-11
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,11 @@ struct RecursiveExtendSharedState {
2323
void setInputNodeMask(std::unique_ptr<common::NodeOffsetMaskMap> maskMap) {
2424
inputNodeMask = std::move(maskMap);
2525
}
26-
void enableInputNodeMask() { inputNodeMask->enable(); }
2726
common::NodeOffsetMaskMap* getInputNodeMaskMap() const { return inputNodeMask.get(); }
2827

2928
void setOutputNodeMask(std::unique_ptr<common::NodeOffsetMaskMap> maskMap) {
3029
outputNodeMask = std::move(maskMap);
3130
}
32-
void enableOutputNodeMask() { outputNodeMask->enable(); }
3331
common::NodeOffsetMaskMap* getOutputNodeMaskMap() const { return outputNodeMask.get(); }
3432

3533
void setPathNodeMask(std::unique_ptr<common::NodeOffsetMaskMap> maskMap) {
@@ -39,22 +37,13 @@ struct RecursiveExtendSharedState {
3937

4038
bool exceedLimit() const { return !(counter == nullptr) && counter->exceedLimit(); }
4139

42-
void setNbrTableIDSet(common::table_id_set_t set) { nbrTableIDSet = std::move(set); }
43-
bool inNbrTableIDs(common::table_id_t tableID) const {
44-
if (nbrTableIDSet.empty()) {
45-
return true;
46-
}
47-
return nbrTableIDSet.contains(tableID);
48-
}
49-
5040
public:
5141
FactorizedTablePool factorizedTablePool;
5242

5343
private:
5444
std::unique_ptr<common::NodeOffsetMaskMap> inputNodeMask = nullptr;
5545
std::unique_ptr<common::NodeOffsetMaskMap> outputNodeMask = nullptr;
5646
std::unique_ptr<common::NodeOffsetMaskMap> pathNodeMask = nullptr;
57-
common::table_id_set_t nbrTableIDSet;
5847
};
5948

6049
} // namespace processor

src/include/processor/plan_mapper.h

+1-3
Original file line numberDiff line numberDiff line change
@@ -237,11 +237,9 @@ class PlanMapper {
237237
static FactorizedTableSchema createFlatFTableSchema(
238238
const binder::expression_vector& expressions, const planner::Schema& schema);
239239
std::unique_ptr<common::SemiMask> createSemiMask(common::table_id_t tableID) const;
240-
std::unique_ptr<common::NodeOffsetMaskMap> createNodeOffsetMaskMap(
241-
const binder::Expression& expr) const;
242240

243241
public:
244-
processor::ExecutionContext* executionContext;
242+
ExecutionContext* executionContext;
245243
main::ClientContext* clientContext;
246244
std::unordered_map<const planner::LogicalOperator*, PhysicalOperator*> logicalOpToPhysicalOpMap;
247245

src/optimizer/acc_hash_join_optimizer.cpp

+23-14
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ static bool sameTableIDs(const std::unordered_set<table_id_t>& set,
7272
return true;
7373
}
7474

75-
static bool haveSameTableIDs(const std::vector<const LogicalOperator*>& ops,
75+
static bool haveSameTableIDs(const std::vector<LogicalOperator*>& ops,
7676
SemiMaskTargetType targetType) {
7777
std::unordered_set<table_id_t> tableIDSet;
7878
for (auto id : getTableIDs(ops[0], targetType)) {
@@ -86,7 +86,7 @@ static bool haveSameTableIDs(const std::vector<const LogicalOperator*>& ops,
8686
return true;
8787
}
8888

89-
static bool haveSameType(const std::vector<const LogicalOperator*>& ops) {
89+
static bool haveSameType(const std::vector<LogicalOperator*>& ops) {
9090
for (auto i = 0u; i < ops.size(); ++i) {
9191
if (ops[i]->getOperatorType() != ops[0]->getOperatorType()) {
9292
return false;
@@ -95,7 +95,7 @@ static bool haveSameType(const std::vector<const LogicalOperator*>& ops) {
9595
return true;
9696
}
9797

98-
bool sanityCheckCandidates(const std::vector<const LogicalOperator*>& ops,
98+
bool sanityCheckCandidates(const std::vector<LogicalOperator*>& ops,
9999
SemiMaskTargetType targetType) {
100100
KU_ASSERT(!ops.empty());
101101
if (!haveSameType(ops)) {
@@ -109,10 +109,13 @@ bool sanityCheckCandidates(const std::vector<const LogicalOperator*>& ops,
109109

110110
static std::shared_ptr<LogicalSemiMasker> appendSemiMasker(SemiMaskKeyType keyType,
111111
SemiMaskTargetType targetType, std::shared_ptr<Expression> key,
112-
std::vector<const LogicalOperator*> candidates, std::shared_ptr<LogicalOperator> child) {
112+
std::vector<LogicalOperator*> candidates, std::shared_ptr<LogicalOperator> child) {
113113
auto tableIDs = getTableIDs(candidates[0], targetType);
114114
auto semiMasker =
115-
std::make_shared<LogicalSemiMasker>(keyType, targetType, key, tableIDs, candidates, child);
115+
std::make_shared<LogicalSemiMasker>(keyType, targetType, key, tableIDs, child);
116+
for (auto candidate : candidates) {
117+
semiMasker->addTarget(candidate);
118+
}
116119
semiMasker->computeFlatSchema();
117120
return semiMasker;
118121
}
@@ -152,9 +155,9 @@ static bool isProbeSideQualified(LogicalOperator* probeRoot) {
152155

153156
// Find all ScanNodeIDs under root which scans parameter nodeID. Note that there might be
154157
// multiple ScanNodeIDs matches because both node and rel table scans will trigger scanNodeIDs.
155-
static std::vector<const LogicalOperator*> getScanNodeCandidates(const Expression& nodeID,
158+
static std::vector<LogicalOperator*> getScanNodeCandidates(const Expression& nodeID,
156159
LogicalOperator* root) {
157-
std::vector<const LogicalOperator*> result;
160+
std::vector<LogicalOperator*> result;
158161
auto collector = LogicalScanNodeTableCollector();
159162
collector.collect(root);
160163
for (auto& op : collector.getOperators()) {
@@ -170,9 +173,9 @@ static std::vector<const LogicalOperator*> getScanNodeCandidates(const Expressio
170173
return result;
171174
}
172175

173-
static std::vector<const LogicalOperator*> getRecursiveExtendInputNodeCandidates(
174-
const Expression& nodeID, LogicalOperator* root) {
175-
std::vector<const LogicalOperator*> result;
176+
static std::vector<LogicalOperator*> getRecursiveExtendInputNodeCandidates(const Expression& nodeID,
177+
LogicalOperator* root) {
178+
std::vector<LogicalOperator*> result;
176179
auto collector = LogicalRecursiveExtendCollector();
177180
collector.collect(root);
178181
for (auto& op : collector.getOperators()) {
@@ -185,9 +188,9 @@ static std::vector<const LogicalOperator*> getRecursiveExtendInputNodeCandidates
185188
return result;
186189
}
187190

188-
static std::vector<const LogicalOperator*> getRecursiveExtendOutputNodeCandidates(
191+
static std::vector<LogicalOperator*> getRecursiveExtendOutputNodeCandidates(
189192
const Expression& nodeID, LogicalOperator* root) {
190-
std::vector<const LogicalOperator*> result;
193+
std::vector<LogicalOperator*> result;
191194
auto collector = LogicalRecursiveExtendCollector();
192195
collector.collect(root);
193196
for (auto& op : collector.getOperators()) {
@@ -207,13 +210,19 @@ static std::shared_ptr<LogicalOperator> tryApplySemiMask(std::shared_ptr<Express
207210
auto recursiveExtendInputNodeCandidates =
208211
getRecursiveExtendInputNodeCandidates(*nodeID, toRoot);
209212
if (!recursiveExtendInputNodeCandidates.empty()) {
213+
for (auto& op : recursiveExtendInputNodeCandidates) {
214+
op->cast<LogicalRecursiveExtend>().setInputNodeMask();
215+
}
210216
auto targetType = SemiMaskTargetType::RECURSIVE_EXTEND_INPUT_NODE;
211217
KU_ASSERT(sanityCheckCandidates(recursiveExtendInputNodeCandidates, targetType));
212218
return appendSemiMasker(SemiMaskKeyType::NODE, targetType, std::move(nodeID),
213219
recursiveExtendInputNodeCandidates, std::move(fromRoot));
214220
}
215221
auto recursiveExtendNodeCandidates = getRecursiveExtendOutputNodeCandidates(*nodeID, toRoot);
216222
if (!recursiveExtendNodeCandidates.empty()) {
223+
for (auto& op : recursiveExtendNodeCandidates) {
224+
op->cast<LogicalRecursiveExtend>().setOutputNodeMask();
225+
}
217226
auto targetType = SemiMaskTargetType::RECURSIVE_EXTEND_OUTPUT_NODE;
218227
KU_ASSERT(sanityCheckCandidates(recursiveExtendNodeCandidates, targetType));
219228
return appendSemiMasker(SemiMaskKeyType::NODE, targetType, std::move(nodeID),
@@ -330,7 +339,7 @@ void HashJoinSIPOptimizer::visitIntersect(LogicalOperator* op) {
330339
auto probeRoot = intersect.getChild(0);
331340
auto hasSemiMaskApplied = false;
332341
for (auto& nodeID : intersect.getKeyNodeIDs()) {
333-
std::vector<const LogicalOperator*> ops;
342+
std::vector<LogicalOperator*> ops;
334343
for (auto i = 1u; i < intersect.getNumChildren(); ++i) {
335344
auto buildRoot = intersect.getChild(i);
336345
for (auto& op_ : getScanNodeCandidates(*nodeID, buildRoot.get())) {
@@ -367,7 +376,7 @@ void HashJoinSIPOptimizer::visitPathPropertyProbe(LogicalOperator* op) {
367376
}
368377
auto recursiveRel = pathPropertyProbe.getRel();
369378
auto nodeID = recursiveRel->getRecursiveInfo()->node->getInternalID();
370-
std::vector<const LogicalOperator*> opsToApplySemiMask;
379+
std::vector<LogicalOperator*> opsToApplySemiMask;
371380
if (pathPropertyProbe.getNodeChild() != nullptr) {
372381
auto child = pathPropertyProbe.getNodeChild().get();
373382
for (auto op_ : getScanNodeCandidates(*nodeID, child)) {

src/planner/plan/append_extend.cpp

+1-16
Original file line numberDiff line numberDiff line change
@@ -118,24 +118,9 @@ void Planner::appendRecursiveExtend(const std::shared_ptr<NodeExpression>& bound
118118
*recursiveInfo->node, recursiveInfo->nodePredicate);
119119
nodeMaskRoot = p.getLastOperator();
120120
}
121-
// E.g. Given schema person-knows->person & person-knows->animal
122-
// And query MATCH (a:person:animal)-[e*]->(b:person)
123-
// The destination node b after GDS will contain both person & animal label. We need to prune
124-
// the animal out.
125-
common::table_id_set_t nbrTableIDSet;
126-
auto targetNbrTableIDSet = nbrNode->getTableIDsSet();
127-
auto recursiveNbrTableIDSet = recursiveInfo->node->getTableIDsSet();
128-
for (auto& tableID : recursiveNbrTableIDSet) {
129-
if (targetNbrTableIDSet.contains(tableID)) {
130-
nbrTableIDSet.insert(tableID);
131-
}
132-
}
133-
if (nbrTableIDSet.size() >= recursiveNbrTableIDSet.size()) {
134-
nbrTableIDSet.clear(); // No need to prune nbr table id.
135-
}
136121
auto probePlan = LogicalPlan();
137122
auto recursiveExtend = std::make_shared<LogicalRecursiveExtend>(recursiveInfo->function->copy(),
138-
*recursiveInfo->bindData, resultColumns, nbrTableIDSet);
123+
*recursiveInfo->bindData, resultColumns);
139124
if (nodeMaskRoot != nullptr) {
140125
recursiveExtend->addChild(nodeMaskRoot);
141126
}

src/processor/map/map_recursive_extend.cpp

+16-4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include "binder/expression/node_expression.h"
12
#include "graph/on_disk_graph.h"
23
#include "planner/operator/extend/logical_recursive_extend.h"
34
#include "planner/operator/sip/logical_semi_masker.h"
@@ -12,6 +13,16 @@ using namespace kuzu::common;
1213
namespace kuzu {
1314
namespace processor {
1415

16+
std::unique_ptr<NodeOffsetMaskMap> createNodeOffsetMaskMap(const Expression& expr,
17+
PlanMapper* mapper) {
18+
auto& node = expr.constCast<NodeExpression>();
19+
auto maskMap = std::make_unique<NodeOffsetMaskMap>();
20+
for (auto tableID : node.getTableIDs()) {
21+
maskMap->addMask(tableID, mapper->createSemiMask(tableID));
22+
}
23+
return maskMap;
24+
}
25+
1526
std::unique_ptr<PhysicalOperator> PlanMapper::mapRecursiveExtend(
1627
const LogicalOperator* logicalOperator) {
1728
auto& extend = logicalOperator->constCast<LogicalRecursiveExtend>();
@@ -23,11 +34,12 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapRecursiveExtend(
2334
auto graph = std::make_unique<OnDiskGraph>(clientContext, bindData.graphEntry.copy());
2435
auto sharedState =
2536
std::make_shared<RecursiveExtendSharedState>(table, std::move(graph), extend.getLimitNum());
26-
if (extend.hasNbrTableIDSet()) {
27-
sharedState->setNbrTableIDSet(extend.getNbrTableIDSet());
37+
if (extend.hasInputNodeMask()) {
38+
sharedState->setInputNodeMask(createNodeOffsetMaskMap(*bindData.nodeInput, this));
39+
}
40+
if (extend.hasOutputNodeMask()) {
41+
sharedState->setOutputNodeMask(createNodeOffsetMaskMap(*bindData.nodeOutput, this));
2842
}
29-
sharedState->setInputNodeMask(createNodeOffsetMaskMap(*bindData.nodeInput));
30-
sharedState->setOutputNodeMask(createNodeOffsetMaskMap(*bindData.nodeOutput));
3143
auto printInfo =
3244
std::make_unique<RecursiveExtendPrintInfo>(extend.getFunction().getFunctionName());
3345
auto descriptor = std::make_unique<ResultSetDescriptor>();

0 commit comments

Comments
 (0)