Skip to content

Commit 7db69c4

Browse files
committed
Fixes and more tests
1 parent 5eb25ba commit 7db69c4

File tree

3 files changed

+39
-13
lines changed

3 files changed

+39
-13
lines changed

flake8_trio.py

+21-7
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,10 @@ def __lt__(self, other: Any) -> bool:
127127
def __eq__(self, other: Any) -> bool:
128128
return isinstance(other, Error) and self.cmp() == other.cmp()
129129

130+
def __repr__(self) -> str:
131+
trailer = "".join(f", {x!r}" for x in self.args)
132+
return f"<{self.code} error at {self.line}:{self.col}{trailer}>"
133+
130134

131135
checkpoint_node_types = (ast.Await, ast.AsyncFor, ast.AsyncWith)
132136
cancel_scope_names = (
@@ -806,12 +810,13 @@ def iter_guaranteed_once(iterable: ast.expr) -> bool:
806810
)
807811

808812

809-
def is_nursery_call(node: ast.AST, name: str) -> bool:
813+
def is_nursery_call(node: ast.AST, name: str) -> bool: # pragma: no cover
810814
assert name in ("start", "start_soon")
811815
if isinstance(node, ast.Attribute):
812-
if not isinstance(node.value, ast.Name):
813-
return is_nursery_call(node.value, name) # might be self.nursery.start()
814-
return node.value.id.endswith("nursery") and node.attr == "start"
816+
if isinstance(node.value, ast.Name):
817+
return node.attr == name and node.value.id.endswith("nursery")
818+
if isinstance(node.value, ast.Attribute):
819+
return node.attr == name and node.value.attr.endswith("nursery")
815820
return False
816821

817822

@@ -827,8 +832,17 @@ def visit(self, node: ast.AST):
827832
self.node_stack.pop()
828833

829834
def visit_Call(self, node: ast.Call):
830-
if is_nursery_call(node.func, "start") and (
831-
len(self.node_stack) < 2 or not isinstance(self.node_stack[-2], ast.Await)
835+
if (
836+
isinstance(node.func, ast.Attribute)
837+
and isinstance(node.func.value, ast.Name)
838+
and (
839+
(node.func.value.id == "trio" and node.func.attr in trio_async_funcs)
840+
or is_nursery_call(node.func, "start")
841+
)
842+
and (
843+
len(self.node_stack) < 2
844+
or not isinstance(self.node_stack[-2], ast.Await)
845+
)
832846
):
833847
assert isinstance(node.func, ast.Attribute)
834848
self.error("TRIO105", node, node.func.attr)
@@ -1179,7 +1193,7 @@ def __init__(self, *args: Any, **kwargs: Any):
11791193
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
11801194
outer = self.aenter
11811195

1182-
self.aenter = (node.name == "__aenter__" and len(node.args.args) == 1) or any(
1196+
self.aenter = node.name == "__aenter__" or any(
11831197
_get_identifier(d) == "asynccontextmanager" for d in node.decorator_list
11841198
)
11851199

tests/test_flake8_trio.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,12 @@ def read_file(test_file: str):
159159

160160
def assert_expected_errors(plugin: Plugin, include: Iterable[str], *expected: Error):
161161
# initialize default option values
162-
om = OptionManager(version="", plugin_versions="", parents=[])
162+
om = OptionManager(
163+
version="",
164+
plugin_versions="",
165+
parents=[],
166+
formatter_names=["default"], # type: ignore
167+
)
163168
plugin.add_options(om)
164169
plugin.parse_options(om.parse_args(args=[""]))
165170

@@ -361,16 +366,21 @@ def test_107_permutations():
361366
def test_113_options():
362367
# check that no errors are given by default
363368
plugin = read_file("trio113.py")
364-
om = OptionManager(version="", plugin_versions="", parents=[])
369+
om = OptionManager(
370+
version="",
371+
plugin_versions="",
372+
parents=[],
373+
formatter_names=["default"], # type: ignore
374+
)
365375
plugin.add_options(om)
366376
plugin.parse_options(om.parse_args(args=["--startable-in-context-manager=''"]))
367-
assert not sorted(e for e in plugin.run() if e.code == "TRIO113")
377+
default = {repr(e) for e in plugin.run() if e.code == "TRIO113"}
368378

369379
# and that the expected errors are given if we empty it and then extend it
370-
arg = "--startable-in-context-manager='custom_startable_function'"
380+
arg = "--startable-in-context-manager=custom_startable_function"
371381
plugin.parse_options(om.parse_args(args=[arg]))
372-
errors = sorted(e for e in plugin.run() if e.code == "TRIO113")
373-
assert errors == [Error("TRIO113", 58, 8)]
382+
errors = {repr(e) for e in plugin.run() if e.code == "TRIO113"} - default
383+
assert errors == {repr(Error("TRIO113", 58, 8))}
374384

375385

376386
@pytest.mark.fuzz

tests/trio113.py

+2
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ async def __aenter__(_a_parameter_not_named_self):
6666
# might be monkeypatched onto an instance, count this as an error too
6767
async def __aenter__():
6868
nursery.start_soon(trio.run_process) # error: 4
69+
nursery.start_soon() # broken code, but our analysis shouldn't crash
70+
nursery.cancel_scope.cancel()
6971

7072

7173
# this only takes a single parameter ... right? :P

0 commit comments

Comments
 (0)