Skip to content

Commit 1e97232

Browse files
Arm backend: Add pre-push checks for op tests (#9899)
- Check that @SkipIfNoCorstone is not used - Check that @expectedFailureOnFVP is not used - Check that unittest.TestCase is not used - Check that on_fvp suffix is not used - Check that op and target is parsed from the test name Signed-off-by: Adrian Lundell <[email protected]>
1 parent db21501 commit 1e97232

File tree

2 files changed

+140
-0
lines changed

2 files changed

+140
-0
lines changed
+102
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from executorch.exir.dialects.edge.spec.utils import SAMPLE_INPUT
6+
7+
# Add edge ops which we lower but which are not included in exir/dialects/edge/edge.yaml here.
8+
CUSTOM_EDGE_OPS = ["linspace.default", "eye.default"]
9+
ALL_EDGE_OPS = SAMPLE_INPUT.keys() | CUSTOM_EDGE_OPS
10+
11+
# Add all targets and TOSA profiles we support here.
12+
TARGETS = {"tosa_BI", "tosa_MI", "u55_BI", "u85_BI"}
13+
14+
15+
def get_edge_ops():
16+
"""
17+
Returns a set with edge_ops with names on the form to be used in unittests:
18+
1. Names are in lowercase.
19+
2. Overload is ignored if it is 'default', otherwise its appended with an underscore.
20+
3. Overly verbose name are shortened by removing certain prefixes/suffixes.
21+
22+
Examples:
23+
abs.default -> abs
24+
split_copy.Tensor -> split_tensor
25+
"""
26+
edge_ops = set()
27+
for edge_name in ALL_EDGE_OPS:
28+
op, overload = edge_name.split(".")
29+
30+
# Normalize names
31+
op = op.lower()
32+
op = op.removeprefix("_")
33+
op = op.removesuffix("_copy")
34+
op = op.removesuffix("_with_indices")
35+
op = op.removesuffix("_no_training")
36+
overload = overload.lower()
37+
38+
if overload == "default":
39+
edge_ops.add(op)
40+
else:
41+
edge_ops.add(f"{op}_{overload}")
42+
43+
return edge_ops
44+
45+
46+
def parse_test_name(test_name: str, edge_ops: set[str]) -> tuple[str, str, bool]:
47+
"""
48+
Parses a test name on the form
49+
test_OP_TARGET_<not_delegated>_<any_other_info>
50+
where OP must match a string in edge_ops and TARGET must match one string in TARGETS.
51+
The "not_delegated" suffix indicates that the test tests that the op is not delegated.
52+
53+
Examples of valid names: "test_mm_u55_BI_not_delegated" or "test_add_scalar_tosa_MI_two_inputs".
54+
55+
Returns a tuple (OP, TARGET, IS_DELEGATED) if valid.
56+
"""
57+
test_name = test_name.removeprefix("test_")
58+
is_delegated = "not_delegated" not in test_name
59+
assert (
60+
"reject" not in test_name
61+
), f"Use 'not_delegated' instead of 'reject' in {test_name}"
62+
63+
op = "None"
64+
target = "None"
65+
for potential_target in TARGETS:
66+
index = test_name.find(potential_target)
67+
if index != -1:
68+
op = test_name[: index - 1]
69+
target = potential_target
70+
break
71+
# Special case for convolution
72+
op = op.removesuffix("_1d")
73+
op = op.removesuffix("_2d")
74+
75+
assert target != "None", f"{test_name} does not contain one of {TARGETS}"
76+
assert (
77+
op in edge_ops
78+
), f"Parsed unvalid OP from {test_name}, {op} does not exist in edge.yaml or CUSTOM_EDGE_OPS"
79+
80+
return op, target, is_delegated
81+
82+
83+
if __name__ == "__main__":
84+
"""Parses a list of test names given on the commandline."""
85+
import sys
86+
87+
sys.tracebacklimit = 0 # Do not print stack trace
88+
89+
edge_ops = get_edge_ops()
90+
exit_code = 0
91+
92+
for test_name in sys.argv[1:]:
93+
try:
94+
assert test_name[:5] == "test_", f"Unexpected input: {test_name}"
95+
parse_test_name(test_name, edge_ops)
96+
except AssertionError as e:
97+
print(e)
98+
exit_code = 1
99+
else:
100+
print(f"{test_name} OK")
101+
102+
sys.exit(exit_code)

backends/arm/scripts/pre-push

+38
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,44 @@ for COMMIT in ${COMMITS}; do
166166
fi
167167
fi
168168

169+
# Op test checks
170+
op_test_files=$(echo $commit_files | grep -oE 'backends/arm/test/ops/\S+')
171+
if [ "$op_test_files" ]; then
172+
173+
# TODO: These checks can be removed when all unittests are refactored.
174+
if grep -icq "SkipIfNoCorstone" $op_test_files; then
175+
echo -e "${ERROR} @SkipIfNoCorstone300/320 is deprecated;"\
176+
"please use XfailIfNoCorstone300/320 instead." >&2
177+
FAILED=1
178+
fi
179+
180+
if grep -icq "conftest.expectedFailureOnFVP" $op_test_files; then
181+
echo -e "${ERROR} @conftest.expectedFailureOnFVP is deprecated;"\
182+
"please use XfailIfCorstone300/320 instead." >&2
183+
FAILED=1
184+
fi
185+
186+
if grep -icq "unittest.TestCase" $op_test_files; then
187+
echo -e "${ERROR} Use of the Unittest test framework is deprecated;"\
188+
"please use Pytest instead." >&2
189+
FAILED=1
190+
fi
191+
192+
if grep -icq "on_fvp(" $op_test_files; then
193+
echo -e "${ERROR} All unittests should run on FVP if relevant,"\
194+
"on_fvp suffix can be excluded." >&2
195+
FAILED=1
196+
fi
197+
198+
# Check that the tested op and target is parsed correctly from the test name
199+
test_names=$(grep -h "def test_" $op_test_files | cut -d"(" -f1 | cut -d" " -f2)
200+
python ./backends/arm/scripts/parse_test_names.py $test_names
201+
if [ $? -ne 0 ]; then
202+
echo -e "${ERROR} Failed op test name check." >&2
203+
FAILED=1
204+
fi
205+
fi
206+
169207
echo "" # Newline to visually separate commit processing
170208
done
171209

0 commit comments

Comments
 (0)