Skip to content

Commit 30e5f93

Browse files
lixinqiXrekihxzd5568
authored
[AP] Add paddle.cc.ap.FacadeOp: a custom op machanism for ap pass only (#72525)
* support ap.facade and infer_symbolic/infer_meta in python * minor fix * paddle.cc.ap.FacadeOp * support zero inputs for pd_op.ap_facade * Fix compiling error in CI and refine some error messages. * Correct the copyright. * Polish error messages and remove some unused header files. * Fix compiling when cinn is not enabled. * Add InferMeta to list. --------- Co-authored-by: Liu Yiqun <[email protected]> Co-authored-by: hxzd5568 <[email protected]>
1 parent c7761d7 commit 30e5f93

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+2542
-98
lines changed

paddle/ap/CMakeLists.txt

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,13 @@ cc_library(
5252
SRCS ${ap_pir_srcs}
5353
DEPS ${AP_COMMON_DEPS} ${ap_pir_deps})
5454

55+
file(GLOB_RECURSE ap_hlir_srcs "src/paddle/hlir/*.cc")
56+
set(ap_hlir_deps axpr ap_drr ap_pir)
57+
cc_library(
58+
ap_hlir
59+
SRCS ${ap_hlir_srcs}
60+
DEPS ${AP_COMMON_DEPS} ${ap_hlir_deps})
61+
5562
file(GLOB_RECURSE ap_reified_drr_srcs "src/reified_drr/*.cc")
5663
set(ap_reified_drr_deps axpr ap_drr ap_code_module ap_code_gen)
5764
cc_library(
@@ -60,7 +67,7 @@ cc_library(
6067
DEPS ${AP_COMMON_DEPS} ${ap_reified_drr_deps})
6168

6269
file(GLOB_RECURSE ap_pass_srcs "src/paddle/pass/*.cc")
63-
set(ap_pass_deps axpr ap_pir ap_drr ap_code_module ap_code_gen ap_reified_drr)
70+
set(ap_pass_deps axpr ap_hlir ap_drr ap_code_module ap_code_gen ap_reified_drr)
6471
cc_library(
6572
ap_pass
6673
SRCS ${ap_pass_srcs}

paddle/ap/include/axpr/attr_map_method_class.h

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,35 @@ struct AttrMapMethodClass {
5050
}
5151
};
5252

53+
template <typename ValueT>
54+
struct TypeImplBuiltinAttrMapMethodClass {
55+
using This = TypeImplBuiltinAttrMapMethodClass;
56+
using Self = TypeImpl<AttrMap<ValueT>>;
57+
58+
adt::Result<ValueT> Call(const Self&) { return &This::StaticConstruct; }
59+
60+
static adt::Result<ValueT> StaticConstruct(const ValueT&,
61+
const std::vector<ValueT>& args) {
62+
return This{}.Construct(args);
63+
}
64+
65+
adt::Result<ValueT> Construct(const std::vector<ValueT>& args) {
66+
const auto& packed_args = CastToPackedArgs(args);
67+
const auto& [pos_args, kwargs] = *packed_args;
68+
ADT_CHECK(pos_args->empty())
69+
<< adt::errors::TypeError{std::string() +
70+
"the construct of AttrMap "
71+
"takes no positional arguments."};
72+
return kwargs;
73+
}
74+
};
75+
5376
template <typename ValueT>
5477
struct MethodClassImpl<ValueT, AttrMap<ValueT>>
5578
: public AttrMapMethodClass<ValueT> {};
5679

5780
template <typename ValueT>
5881
struct MethodClassImpl<ValueT, TypeImpl<AttrMap<ValueT>>>
59-
: public EmptyMethodClass<ValueT> {};
82+
: public TypeImplBuiltinAttrMapMethodClass<ValueT> {};
6083

6184
} // namespace ap::axpr

paddle/ap/include/axpr/binary_func.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ namespace ap::axpr {
2323
_(Sub, -) \
2424
_(Mul, *) \
2525
_(Div, /) \
26+
_(FloorDiv, /) \
2627
_(Mod, %) \
2728
_(EQ, ==) \
2829
_(NE, !=) \

paddle/ap/include/axpr/builtin_class_instance_method_class.h

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,58 @@ struct MethodClassImpl<ValueT, BuiltinClassInstance<ValueT>> {
114114
return class_ops->Equals(self, rhs_val);
115115
}
116116

117+
adt::Result<ValueT> Add(InterpreterBase<ValueT>* interpreter,
118+
const Self& self,
119+
const ValueT& rhs_val) {
120+
const auto& opt_func = GetClassAttr(self, "__add__");
121+
const auto& class_attrs = self.type.class_attrs();
122+
ADT_CHECK(opt_func.has_value())
123+
<< adt::errors::AttributeError{std::string() + class_attrs->class_name +
124+
" class has no attribute '__add__'"};
125+
std::vector<ValueT> args{rhs_val};
126+
ADT_LET_CONST_REF(ret, interpreter->InterpretCall(opt_func.value(), args));
127+
return ret;
128+
}
129+
130+
adt::Result<ValueT> Sub(InterpreterBase<ValueT>* interpreter,
131+
const Self& self,
132+
const ValueT& rhs_val) {
133+
const auto& opt_func = GetClassAttr(self, "__sub__");
134+
const auto& class_attrs = self.type.class_attrs();
135+
ADT_CHECK(opt_func.has_value())
136+
<< adt::errors::AttributeError{std::string() + class_attrs->class_name +
137+
" class has no attribute '__sub__'"};
138+
std::vector<ValueT> args{rhs_val};
139+
ADT_LET_CONST_REF(ret, interpreter->InterpretCall(opt_func.value(), args));
140+
return ret;
141+
}
142+
143+
adt::Result<ValueT> Mul(InterpreterBase<ValueT>* interpreter,
144+
const Self& self,
145+
const ValueT& rhs_val) {
146+
const auto& opt_func = GetClassAttr(self, "__mul__");
147+
const auto& class_attrs = self.type.class_attrs();
148+
ADT_CHECK(opt_func.has_value())
149+
<< adt::errors::AttributeError{std::string() + class_attrs->class_name +
150+
" class has no attribute '__mul__'"};
151+
std::vector<ValueT> args{rhs_val};
152+
ADT_LET_CONST_REF(ret, interpreter->InterpretCall(opt_func.value(), args));
153+
return ret;
154+
}
155+
156+
adt::Result<ValueT> FloorDiv(InterpreterBase<ValueT>* interpreter,
157+
const Self& self,
158+
const ValueT& rhs_val) {
159+
const auto& opt_func = GetClassAttr(self, "__floordiv__");
160+
const auto& class_attrs = self.type.class_attrs();
161+
ADT_CHECK(opt_func.has_value()) << adt::errors::AttributeError{
162+
std::string() + class_attrs->class_name +
163+
" class has no attribute '__floordiv__'"};
164+
std::vector<ValueT> args{rhs_val};
165+
ADT_LET_CONST_REF(ret, interpreter->InterpretCall(opt_func.value(), args));
166+
return ret;
167+
}
168+
117169
adt::Result<ValueT> GetItem(InterpreterBase<ValueT>* interpreter,
118170
const Self& self,
119171
const ValueT& idx_val) {

paddle/ap/include/axpr/builtin_frame_util.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ void VisitEachBuiltinFrameAttr(const YieldT& Yield) {
4141
Yield("__builtin_not__", &BuiltinNot);
4242

4343
Yield("__builtin__foreach", &ForEach);
44+
4445
auto YieldTwice = [&](const auto& name, const auto& value) {
4546
Yield(name, value);
4647
Yield(std::string("__builtin__") + name, value);

paddle/ap/include/axpr/dim_expr_method_class.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
#include "paddle/pir/include/dialect/shape/utils/dim_expr.h"
2323

2424
namespace ap::axpr {
25+
template <typename ValueT>
26+
axpr::TypeImpl<axpr::BuiltinClassInstance<ValueT>> GetDimExprClass();
2527

2628
template <typename ValueT>
2729
struct DimExprMethodClass {
@@ -41,6 +43,38 @@ struct DimExprMethodClass {
4143
return hash_value;
4244
}
4345

46+
static adt::Result<ValueT> Add(const ValueT& self_val,
47+
const std::vector<ValueT>& args) {
48+
ADT_LET_CONST_REF(lhs, self_val.template CastTo<Self>());
49+
ADT_CHECK(args.size() == 1);
50+
ADT_LET_CONST_REF(rhs, args.at(0).template CastTo<Self>());
51+
return GetDimExprClass<ValueT>().New(lhs + rhs);
52+
}
53+
54+
static adt::Result<ValueT> Sub(const ValueT& self_val,
55+
const std::vector<ValueT>& args) {
56+
ADT_LET_CONST_REF(lhs, self_val.template CastTo<Self>());
57+
ADT_CHECK(args.size() == 1);
58+
ADT_LET_CONST_REF(rhs, args.at(0).template CastTo<Self>());
59+
return GetDimExprClass<ValueT>().New(lhs - rhs);
60+
}
61+
62+
static adt::Result<ValueT> Mul(const ValueT& self_val,
63+
const std::vector<ValueT>& args) {
64+
ADT_LET_CONST_REF(lhs, self_val.template CastTo<Self>());
65+
ADT_CHECK(args.size() == 1);
66+
ADT_LET_CONST_REF(rhs, args.at(0).template CastTo<Self>());
67+
return GetDimExprClass<ValueT>().New(lhs * rhs);
68+
}
69+
70+
static adt::Result<ValueT> FloorDiv(const ValueT& self_val,
71+
const std::vector<ValueT>& args) {
72+
ADT_LET_CONST_REF(lhs, self_val.template CastTo<Self>());
73+
ADT_CHECK(args.size() == 1);
74+
ADT_LET_CONST_REF(rhs, args.at(0).template CastTo<Self>());
75+
return GetDimExprClass<ValueT>().New(lhs / rhs);
76+
}
77+
4478
static adt::Result<ValueT> Match(axpr::InterpreterBase<ValueT>* interpreter,
4579
const ValueT& self_val,
4680
const std::vector<ValueT>& packed_args_val) {
@@ -93,6 +127,10 @@ axpr::TypeImpl<axpr::BuiltinClassInstance<ValueT>> GetDimExprClass() {
93127
static auto cls(
94128
axpr::MakeBuiltinClass<ValueT>("DimExpr", [&](const auto& Define) {
95129
Define("__str__", &Impl::ToString);
130+
Define("__add__", &Impl::Add);
131+
Define("__sub__", &Impl::Sub);
132+
Define("__mul__", &Impl::Mul);
133+
Define("__floordiv__", &Impl::FloorDiv);
96134
Define("__hash__", &Impl::Hash);
97135
Define("match", &Impl::Match);
98136
}));

paddle/ap/include/axpr/type_util.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ AttrMap<ValueT> GetObjectTypeName2Type() {
9898
OrderedDict<ValueT>,
9999
MutableOrderedDict<ValueT>,
100100
AttrMap<axpr::SerializableValue>,
101+
AttrMap<ValueT>,
101102
ValueImplTypes...>::Call(&object);
102103
return object;
103104
}

paddle/ap/include/drr/builtin_frame_util.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,7 @@ void VisitEachBuiltinFrameClass(const DoEachT& DoEach) {
2727
DoEach(drr::Type<DrrCtx>{}.GetClass());
2828
}
2929

30-
template <typename VisitorT>
31-
ap::axpr::AttrMap<axpr::Value> MakeBuiltinFrameAttrMap(
32-
const VisitorT& Visitor) {
30+
inline ap::axpr::AttrMap<axpr::Value> MakeBuiltinFrameAttrMap() {
3331
ap::axpr::AttrMap<axpr::Value> attr_map;
3432
ap::axpr::VisitEachBuiltinFrameAttr<axpr::Value>(
3533
[&](const std::string& k, const axpr::Value& v) { attr_map->Set(k, v); });
@@ -38,7 +36,6 @@ ap::axpr::AttrMap<axpr::Value> MakeBuiltinFrameAttrMap(
3836
attr_map->Set(std::string("__builtin__") + cls.Name(), cls);
3937
};
4038
VisitEachBuiltinFrameClass(Insert);
41-
Visitor(Insert);
4239
return attr_map;
4340
}
4441

paddle/ap/include/drr/drr_interpreter.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,8 @@ namespace ap::drr {
2323

2424
class DrrInterpreter {
2525
public:
26-
explicit DrrInterpreter(
27-
const axpr::TypeImpl<axpr::BuiltinClassInstance<axpr::Value>>&
28-
backend_ir_ctx,
29-
const std::weak_ptr<ap::memory::CirclableRefListBase>&
30-
circlable_ref_list);
26+
explicit DrrInterpreter(const std::weak_ptr<ap::memory::CirclableRefListBase>&
27+
circlable_ref_list);
3128

3229
using Function = ap::axpr::Value;
3330

paddle/ap/include/paddle/pir/manual_op.h renamed to paddle/ap/include/paddle/hlir/manual_op.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#pragma once
1616

17+
#include "paddle/ap/include/axpr/attr_map.h"
1718
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h"
1819
#include "paddle/phi/core/infermeta_utils.h"
1920
#include "paddle/pir/include/core/builder.h"
@@ -25,6 +26,22 @@
2526

2627
namespace ap::dialect {
2728

29+
class IR_API FacadeOp
30+
: public pir::Op<FacadeOp, ::paddle::dialect::InferSymbolicShapeInterface> {
31+
public:
32+
using Op::Op;
33+
static const char *name() { return "ap_op.facade"; }
34+
static constexpr uint32_t attributes_num = 3;
35+
static const char *attributes_name[attributes_num];
36+
static void Build(pir::Builder &builder, // NOLINT
37+
pir::OperationArgument &argument, // NOLINT
38+
const std::vector<pir::Value> &inputs,
39+
const pir::AttributeMap &attributes,
40+
const std::vector<pir::Type> &output_types);
41+
void VerifySig() const {}
42+
bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
43+
};
44+
2845
class IR_API UpSpiderOp
2946
: public pir::Op<UpSpiderOp,
3047
pir::SideEffectTrait,
@@ -134,6 +151,7 @@ class IR_API StoreToGlobalOp
134151

135152
} // namespace ap::dialect
136153

154+
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ap::dialect::FacadeOp);
137155
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ap::dialect::UpSpiderOp);
138156
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ap::dialect::DownSpiderOp);
139157
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ap::dialect::LoadFromRegisterOp);

paddle/ap/include/paddle/meta_tensor_ptr_method_class.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,11 @@ struct MetaTensorPtrMethodClass {
103103

104104
adt::Result<axpr::Value> SetDims(const Self& self,
105105
const axpr::Value& dims_val) {
106+
if (dims_val.CastableTo<DDim>()) {
107+
ADT_LET_CONST_REF(ddim, dims_val.CastTo<DDim>());
108+
return SetDimsByDDim(self, ddim);
109+
}
106110
return dims_val.Match(
107-
[&](const DDim& ddims) -> adt::Result<axpr::Value> {
108-
return SetDimsByDDim(self, ddims);
109-
},
110111
[&](const adt::List<axpr::Value>& list) -> adt::Result<axpr::Value> {
111112
return SetDimsByIntList(self, list);
112113
},
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <memory>
18+
#include <optional>
19+
#include "paddle/pir/include/pass/pass.h"
20+
21+
namespace ap::memory {
22+
23+
class CirclableRefListBase;
24+
25+
}
26+
27+
namespace ap::axpr {
28+
29+
struct Value;
30+
31+
}
32+
33+
namespace cinn {
34+
namespace dialect {
35+
namespace ir {
36+
37+
std::unique_ptr<::pir::Pass> CreateConvertPdFacadeToApFacadePass();
38+
39+
} // namespace ir
40+
} // namespace dialect
41+
} // namespace cinn

paddle/ap/include/paddle/pass/ir_helper_method_class.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
#include "paddle/ap/include/axpr/callable_helper.h"
1919
#include "paddle/ap/include/axpr/lambda_expr_builder.h"
2020
#include "paddle/ap/include/drr/drr_value_helper.h"
21+
#include "paddle/ap/include/paddle/hlir/op_dialect.h"
2122
#include "paddle/ap/include/paddle/pass/ap_drr_helper.h"
2223
#include "paddle/ap/include/paddle/pass/ap_generic_drr_pass.h"
2324
#include "paddle/ap/include/paddle/pass/ir_helper.h"
24-
#include "paddle/ap/include/paddle/pir/op_dialect.h"
2525
#include "paddle/ap/include/paddle/pir/packed_ir_op_inner_source_pattern_helper.h"
2626
#include "paddle/ap/include/paddle/pir/pass_manager_method_class.h"
2727
#include "paddle/ap/include/paddle/pir/pass_method_class.h"

paddle/ap/include/paddle/phi/ap_infer_meta_helper.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@
1515
#pragma once
1616

1717
#include "paddle/ap/include/adt/adt.h"
18+
#include "paddle/ap/include/axpr/attr_map.h"
1819
#include "paddle/ap/include/axpr/core_expr.h"
20+
#include "paddle/ap/include/axpr/value.h"
1921
#include "paddle/phi/core/meta_tensor.h"
22+
#include "paddle/pir/include/core/operation_utils.h"
2023

2124
namespace phi {
2225

@@ -29,6 +32,12 @@ struct ApInferMetaHelper {
2932
adt::Result<adt::Ok> InferMeta(const std::string& lambda,
3033
const std::vector<const MetaTensor*>* inputs,
3134
std::vector<MetaTensor*>* outputs);
35+
36+
adt::Result<adt::Ok> InferMetaByAxprHook(
37+
const ::paddle::optional<std::vector<const MetaTensor*>>& inputs,
38+
const std::string& infer_meta_func_name,
39+
const std::string& serialized_attributes,
40+
const std::vector<MetaTensor*>& outputs);
3241
};
3342

3443
} // namespace phi

0 commit comments

Comments
 (0)