@@ -135,10 +135,6 @@ def __eq__(self, other: Any) -> bool:
135
135
"move_on_at" ,
136
136
"CancelScope" ,
137
137
)
138
- context_manager_names = (
139
- "contextmanager" ,
140
- "asynccontextmanager" ,
141
- )
142
138
143
139
144
140
class Flake8TrioVisitor (ast .NodeVisitor ):
@@ -312,8 +308,7 @@ def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]):
312
308
outer = self .get_state ()
313
309
self .set_state (self .defaults , copy = True )
314
310
315
- # check for @<context_manager_name> and @<library>.<context_manager_name>
316
- if has_decorator (node .decorator_list , * context_manager_names ):
311
+ if has_decorator (node .decorator_list , "contextmanager" , "asynccontextmanager" ):
317
312
self ._safe_decorator = True
318
313
319
314
self .generic_visit (node )
@@ -509,21 +504,16 @@ def __init__(self, *args: Any, **kwargs: Any):
509
504
super ().__init__ (* args , ** kwargs )
510
505
self ._critical_scope : Optional [Statement ] = None
511
506
self ._trio_context_managers : List [Visitor102 .TrioScope ] = []
512
- self ._safe_decorator = False
513
507
514
- # if we're inside a finally, and not inside a context_manager, and we're not
515
- # inside a scope that doesn't have both a timeout and shield
508
+ # if we're inside a finally, and we're not inside a scope that doesn't have
509
+ # both a timeout and shield
516
510
def visit_Await (
517
511
self ,
518
512
node : Union [ast .Await , ast .AsyncFor , ast .AsyncWith ],
519
513
visit_children : bool = True ,
520
514
):
521
- if (
522
- self ._critical_scope is not None
523
- and not self ._safe_decorator
524
- and not any (
525
- cm .has_timeout and cm .shielded for cm in self ._trio_context_managers
526
- )
515
+ if self ._critical_scope is not None and not any (
516
+ cm .has_timeout and cm .shielded for cm in self ._trio_context_managers
527
517
):
528
518
self .error ("TRIO102" , node , self ._critical_scope )
529
519
if visit_children :
@@ -560,19 +550,6 @@ def visit_AsyncWith(self, node: ast.AsyncWith):
560
550
self .visit_Await (node , visit_children = False )
561
551
self .visit_With (node )
562
552
563
- def visit_FunctionDef (self , node : Union [ast .FunctionDef , ast .AsyncFunctionDef ]):
564
- outer = self .get_state ("_safe_decorator" )
565
-
566
- # check for @<context_manager_name> and @<library>.<context_manager_name>
567
- if has_decorator (node .decorator_list , * context_manager_names ):
568
- self ._safe_decorator = True
569
-
570
- self .generic_visit (node )
571
-
572
- self .set_state (outer )
573
-
574
- visit_AsyncFunctionDef = visit_FunctionDef
575
-
576
553
def critical_visit (
577
554
self ,
578
555
node : Union [ast .ExceptHandler , Iterable [ast .AST ]],
0 commit comments