Skip to content

Track #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,26 @@ class VarNode : public ExprNode {

RELAY_DEFINE_NODE_REF(Var, VarNode, Expr);

/*! \brief Hash Var by it's id.
* Different VarNode might has same vid, and they are considered to be the same var in such case.
* Use VarHash to hash Var by id.
*/
struct VarHash {
size_t operator()(const Var& v) const {
return v->vid.hash();
}
};

/*! \brief Compare Var by it's id.
* Different VarNode might has same vid, and they are considered to be the same var in such case.
* Use VarEqual to compare Var by id.
*/
struct VarEqual {
bool operator()(const Var& l, const Var& r) const {
return l->vid.get() == r->vid.get();
}
};

/*!
* \brief Global variable that leaves in the top-level module.
* This is used to enable recursive calls between function.
Expand Down Expand Up @@ -521,7 +541,7 @@ RELAY_DEFINE_NODE_REF(RefWrite, RefWriteNode, Expr);
* rewriting pass such as layout or type transformation.
*
* Subclass TempExprNode allows us to pattern match on
* specific kind TempExpr and use them for expression rewriting.
* specific kind of TempExpr and use them for expression rewriting.
*
* TempExpr should only be used within a pass,
*/
Expand All @@ -539,6 +559,25 @@ class TempExprNode : public ExprNode {

RELAY_DEFINE_NODE_REF(TempExpr, TempExprNode, Expr);

class Annotate;
class AnnotateNode : public ExprNode {
public:
Expr expr;
NodeRef annotation;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("expr", &expr);
v->Visit("annotation", &annotation);
v->Visit("_checked_type_", &checked_type_);
}

TVM_DLL static Annotate make(Expr expr, NodeRef annotation);

static constexpr const char* _type_key = "relay.AnnotateNode";
TVM_DECLARE_NODE_TYPE_INFO(AnnotateNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Annotate, AnnotateNode, Expr);

// implementataions
inline const Type& ExprNode::checked_type() const {
CHECK(checked_type_.defined()) << "internal error: the type checker has "
Expand Down
4 changes: 4 additions & 0 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const AnnotateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key());
}
Expand All @@ -140,6 +141,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode);
RELAY_EXPR_FUNCTOR_DISPATCH(ConstructorNode);
RELAY_EXPR_FUNCTOR_DISPATCH(MatchNode);
RELAY_EXPR_FUNCTOR_DISPATCH(AnnotateNode);
return vtable;
}
};
Expand Down Expand Up @@ -170,6 +172,7 @@ class ExprVisitor
void VisitExpr_(const RefWriteNode* op) override;
void VisitExpr_(const ConstructorNode* op) override;
void VisitExpr_(const MatchNode* op) override;
void VisitExpr_(const AnnotateNode* op) override;
virtual void VisitType(const Type& t);
virtual void VisitClause(const Clause& c);
virtual void VisitPattern(const Pattern& c);
Expand Down Expand Up @@ -212,6 +215,7 @@ class ExprMutator
Expr VisitExpr_(const RefWriteNode* op) override;
Expr VisitExpr_(const ConstructorNode* op) override;
Expr VisitExpr_(const MatchNode* op) override;
Expr VisitExpr_(const AnnotateNode* op) override;

/*!
* \brief Used to visit the types inside of expressions.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def schedule_batch_matmul(attrs, outputs, target):
with target:
return topi.generic.schedule_batch_matmul(outputs)

reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
reg.register_pattern("nn.batch_matmul", reg.OpPattern.OPAQUE)


# conv2d
Expand Down
15 changes: 13 additions & 2 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,7 @@ TVM_REGISTER_API("relay._make.Call")

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) {
p->stream << "CallNode(" << node->op << ", " << node->args << ", "
<< node->attrs << ", " << node->type_args << ")";
p->stream << "CallNode(" << node->op << ")";
});

Let LetNode::make(Var var, Expr value, Expr body) {
Expand Down Expand Up @@ -349,5 +348,17 @@ TVM_REGISTER_API("relay._expr.TempExprRealize")
*ret = temp->Realize();
});

Annotate AnnotateNode::make(Expr expr, NodeRef annotation) {
NodePtr<AnnotateNode> n = make_node<AnnotateNode>();
n->expr = std::move(expr);
n->annotation = std::move(annotation);
return Annotate(n);
}

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<AnnotateNode>([](const AnnotateNode* node, tvm::IRPrinter* p) {
p->stream << "AnnotateNode(" << node->expr << ")";
});

} // namespace relay
} // namespace tvm
8 changes: 8 additions & 0 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,10 @@ Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; }

Type ExprMutator::VisitType(const Type& t) { return t; }

Expr ExprMutator::VisitExpr_(const AnnotateNode* op) {
return AnnotateNode::make(VisitExpr(op->expr), op->annotation);
}

void ExprVisitor::VisitExpr(const Expr& expr) {
auto it = visit_counter_.find(expr.get());
if (it != visit_counter_.end()) {
Expand Down Expand Up @@ -315,6 +319,10 @@ void ExprVisitor::VisitExpr_(const MatchNode* op) {
}
}

void ExprVisitor::VisitExpr_(const AnnotateNode* op) {
this->VisitExpr(op->expr);
}

void ExprVisitor::VisitClause(const Clause& op) {
this->VisitPattern(op->lhs);
this->VisitExpr(op->rhs);
Expand Down
Loading