-
Notifications
You must be signed in to change notification settings - Fork 527
/
Copy pathop_registration_util.bzl
139 lines (124 loc) · 5.1 KB
/
op_registration_util.bzl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load("@fbsource//xplat/executorch/build:selects.bzl", "selects")
load(
"@fbsource//xplat/executorch/kernels/optimized:lib_defs.bzl",
"get_vec_deps",
"get_vec_preprocessor_flags",
)
load(
"@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl",
"get_compiler_optimization_flags",
)
def op_target(name, deps = [], compiler_flags = []):
    """Describes an optimized implementation of one operator overload group.

    An operator overload group is the set of operator overloads that share a
    common operator name; that common name is the base name of this target.

    For example, the "add" group (named "op_add" here) might implement:
    - add.Tensor
    - add_.Tensor
    - add.out
    - add.Scalar

    Headers or sources that need to be shared with another op target (helpers,
    utilities, ...) belong in a separate cxx_library that is then listed as a
    dep.

    Args:
        name: The name of the operator overload group; e.g., "op_add". This
            directory must contain a source file named "<name>.cpp"; e.g.,
            "op_add.cpp".
        deps: Optional extra deps to add to the cxx_library(). Note:
            - op targets may not depend on other op targets, to keep the
              dependencies manageable. If two op targets would like to share
              code, define a separate runtime.cxx_library that they both
              depend on.
        compiler_flags: Optional compiler flags to add to the cxx_library().

    Returns:
        A dict with the keys "compiler_flags", "deps" and "name", holding the
        corresponding arguments unchanged.
    """

    # No build target is created by this call. The returned record is only an
    # entry for a table that is later used to define the actual target.
    return dict(
        compiler_flags = compiler_flags,
        deps = deps,
        name = name,
    )
def _enforce_deps(deps, name):
    """Aborts the build if the dep list contains a disallowed dep.

    A dep is disallowed when its label string starts with ":op_".

    Args:
        deps: A list of build target strings.
        name: The name of the target being checked; e.g., "op_add". Only used
            in the failure message.
    """
    for candidate in deps:
        if not candidate.startswith(":op_"):
            continue

        # An op target must not depend on another op target; this keeps the
        # dependencies manageable. Code shared between two op targets should
        # live in a separate runtime.cxx_library that both of them depend on.
        fail("op_target {} may not depend on other op_target {}".format(
            name,
            candidate,
        ))
def define_op_library(name, compiler_flags, deps):
    """Declares the cxx_library target for the named operator overload group.

    Args:
        name: The name of the target; e.g., "op_add". The library's only
            source file is "<name>.cpp".
        compiler_flags: Extra compiler flags for the target. They are placed
            after the fixed warning suppressions and before the flags returned
            by get_compiler_optimization_flags().
        deps: List of deps for the target. They are validated with
            _enforce_deps before the target is declared.
    """

    # Reject forbidden deps up front; selects.apply is used so the check is
    # applied through `deps` as selects.apply sees fit.
    selects.apply(obj = deps, function = native.partial(_enforce_deps, name = name))

    # Caller-supplied deps followed by the optimized-kernel support libraries.
    deps_with_optimized_libs = deps + [
        "//executorch/kernels/optimized:libvec",
        "//executorch/kernels/optimized:libutils",
    ]

    allowed_visibility = [
        "//executorch/kernels/portable/test/...",
        "//executorch/kernels/quantized/test/...",
        "//executorch/kernels/optimized/test/...",
        "//executorch/kernels/test/...",
        "@EXECUTORCH_CLIENTS",
    ]

    suppressed_warnings = [
        # Kernels often contain helpers without prototypes. The headers are
        # codegen'd and linked in later, so the warning is simply disabled.
        "-Wno-missing-prototypes",
        # pragma unroll fails with -Os; there is no need to warn about it and
        # break Werror builds. See https://godbolt.org/z/zvf85vTsr
        "-Wno-pass-failed",
    ]
    all_compiler_flags = suppressed_warnings + compiler_flags + get_compiler_optimization_flags()

    all_deps = [
        "//executorch/runtime/kernel:kernel_includes",
    ] + deps_with_optimized_libs + get_vec_deps()

    # sleef has to be a direct dependency of the operator target when building
    # for Android, otherwise a linker error may occur. The cause is unclear;
    # it seems that fbandroid_platform_deps of dependencies are not
    # transitive.
    android_platform_deps = [
        (
            "^android-arm64.*$",
            [
                "fbsource//third-party/sleef:sleef_arm",
            ],
        ),
    ]

    runtime.cxx_library(
        name = "{}".format(name),
        srcs = ["{}.cpp".format(name)],
        visibility = allowed_visibility,
        compiler_flags = all_compiler_flags,
        deps = all_deps,
        preprocessor_flags = get_vec_preprocessor_flags(),
        fbandroid_platform_deps = android_platform_deps,
        # The operators register themselves through static initializers that
        # run at program startup, which is why link_whole is required.
        # @lint-ignore BUCKLINT link_whole
        link_whole = True,
    )
def define_op_target(name, compiler_flags, deps):
    """Defines the cxx_library target for the named operator group.

    This is a thin wrapper that forwards all of its arguments to
    define_op_library.

    Args:
        name: The base name of the target; e.g., "op_add"
        compiler_flags: Compiler flags passed through to define_op_library.
        deps: List of deps for the targets.
    """

    # Stated intent: when building in ATen mode, ATen-compatible (non-custom)
    # operators use the implementations provided by ATen, so the versions
    # defined here should not be built.
    # NOTE(review): no ATen-mode check is performed in this function; the
    # library is defined unconditionally. Confirm whether that gating is
    # expected to happen elsewhere.
    define_op_library(name, compiler_flags, deps)