Skip to content

Commit 90d5b7b

Browse files
committed
Exempt async CM methods from ASYNC910/911 when partner checkpoints
ASYNC910 and ASYNC911 no longer require every `__aenter__`/`__aexit__` to contain a checkpoint. Per Trio's documentation, an async context manager only needs one of entry/exit to act as a checkpoint. When a class defines both methods, the one without an `await` is exempt if its partner contains one. When a class defines only one of the two, the partner is charitably assumed to be inherited from a base class and to contain a checkpoint, so the defined method is also exempt. Closes #441 https://claude.ai/code/session_014jAydKywq31Ew4fVYGJdiG
1 parent 9a49703 commit 90d5b7b

5 files changed

Lines changed: 240 additions & 0 deletions

File tree

docs/changelog.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Changelog
77
Unreleased
88
==========
99
- Autofix for :ref:`ASYNC910 <async910>` / :ref:`ASYNC911 <async911>` no longer inserts checkpoints inside ``except`` clauses (which would trigger :ref:`ASYNC120 <async120>`); instead the checkpoint is added at the top of the function or of the enclosing loop. `(issue #403) <https://github.com/python-trio/flake8-async/issues/403>`_
10+
- :ref:`ASYNC910 <async910>` and :ref:`ASYNC911 <async911>` now accept ``__aenter__`` / ``__aexit__`` methods when the partner method provides the checkpoint, or when only one of the two is defined on a class that inherits from another class (charitably assuming the partner is inherited and contains a checkpoint). `(issue #441) <https://github.com/python-trio/flake8-async/issues/441>`_
1011

1112
25.7.1
1213
======

flake8_async/visitors/visitor91x.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,17 @@ def __init__(self, *args: Any, **kwargs: Any):
465465
# used to transfer new body between visit_FunctionDef and leave_FunctionDef
466466
self.new_body: cst.BaseSuite | None = None
467467

468+
# Tracks whether the current scope is a class body and, if so, which of
469+
# `__aenter__`/`__aexit__` are directly defined on it (values: True if
470+
# that method contains an `await`, False otherwise, missing key if not
471+
# defined). Used to exempt async context manager methods from
472+
# ASYNC910/911 when their partner method provides the checkpoint, or
473+
# when the partner is assumed inherited (not defined on this class).
474+
self.async_cm_class: dict[str, bool] | None = None
475+
# Set on entry to an exempt `__aenter__`/`__aexit__` so that
476+
# `error_91x` skips emitting ASYNC910/911.
477+
self.exempt_async_cm_method = False
478+
468479
def should_autofix(self, node: cst.CSTNode, code: str | None = None) -> bool:
469480
if code is None:
470481
code = "ASYNC911" if self.has_yield else "ASYNC910"
@@ -532,6 +543,44 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
532543
self.suppress_imported_as.append("suppress")
533544
return
534545

546+
# Async context manager methods may legitimately skip checkpointing if the
547+
# partner method provides the checkpoint, or if the partner is inherited
548+
# from a base class (which we charitably assume contains a checkpoint).
549+
# See https://github.com/python-trio/flake8-async/issues/441.
550+
def visit_ClassDef(self, node: cst.ClassDef) -> None:
551+
self.save_state(node, "async_cm_class")
552+
defined: dict[str, bool] = {}
553+
if isinstance(node.body, cst.IndentedBlock):
554+
for stmt in node.body.body:
555+
if (
556+
isinstance(stmt, cst.FunctionDef)
557+
and stmt.asynchronous is not None
558+
and stmt.name.value in ("__aenter__", "__aexit__")
559+
):
560+
defined[stmt.name.value] = bool(m.findall(stmt, m.Await()))
561+
self.async_cm_class = defined
562+
563+
def leave_ClassDef(
564+
self, original_node: cst.ClassDef, updated_node: cst.ClassDef
565+
) -> cst.ClassDef:
566+
self.restore_state(original_node)
567+
return updated_node
568+
569+
def _is_exempt_async_cm_method(self, node: cst.FunctionDef) -> bool:
570+
if self.async_cm_class is None:
571+
return False
572+
name = node.name.value
573+
if name not in ("__aenter__", "__aexit__"):
574+
return False
575+
if name not in self.async_cm_class:
576+
return False
577+
partner = "__aexit__" if name == "__aenter__" else "__aenter__"
578+
# Partner not defined in this class -> assume inherited with checkpoint.
579+
if partner not in self.async_cm_class:
580+
return True
581+
# Partner defined and (charitably) contains a checkpoint.
582+
return self.async_cm_class[partner]
583+
535584
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
536585
# `await` in default values happen in parent scope
537586
# we also know we don't ever modify parameters so we can ignore the return value
@@ -543,6 +592,8 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
543592
if func_has_decorator(node, "overload", "fixture") or func_empty_body(node):
544593
return False # subnodes can be ignored
545594

595+
is_exempt_cm = self._is_exempt_async_cm_method(node)
596+
546597
self.save_state(
547598
node,
548599
"has_yield",
@@ -557,6 +608,8 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
557608
"suppress_imported_as", # a copy is saved, but state is not reset
558609
"except_depth",
559610
"add_checkpoint_at_function_start",
611+
"async_cm_class",
612+
"exempt_async_cm_method",
560613
copy=True,
561614
)
562615
self.uncheckpointed_statements = set()
@@ -567,6 +620,9 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
567620
self.taskgroup_has_start_soon = {}
568621
self.except_depth = 0
569622
self.add_checkpoint_at_function_start = False
623+
# Class-level context does not apply to nested scopes.
624+
self.async_cm_class = None
625+
self.exempt_async_cm_method = is_exempt_cm
570626

571627
self.async_function = (
572628
node.asynchronous is not None
@@ -747,6 +803,12 @@ def error_91x(
747803
) -> bool:
748804
assert not isinstance(statement, ArtificialStatement), statement
749805

806+
# Exempt `__aenter__`/`__aexit__` when the partner method contains a
807+
# checkpoint, or when the partner is missing and charitably assumed
808+
# inherited.
809+
if self.exempt_async_cm_method:
810+
return False
811+
750812
if isinstance(node, cst.FunctionDef):
751813
msg = "exit"
752814
else:

tests/autofix_files/async910.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,3 +636,83 @@ async def foo_nested_empty_async():
636636
async def bar(): ...
637637

638638
await foo()
639+
640+
641+
# Issue #441: async context manager methods may legitimately skip checkpointing
642+
# if the partner method provides the checkpoint, or if the partner is inherited.
643+
class CtxWithSetup: # safe: __aenter__ checkpoints, __aexit__ can be fast
644+
async def __aenter__(self):
645+
await foo()
646+
647+
async def __aexit__(self, exc_type, exc, tb):
648+
print("fast exit")
649+
650+
651+
class CtxWithTeardown: # safe: __aexit__ checkpoints, __aenter__ can be fast
652+
async def __aenter__(self):
653+
print("fast setup")
654+
655+
async def __aexit__(self, exc_type, exc, tb):
656+
await foo()
657+
658+
659+
class CtxWithBothCheckpoint: # safe: both checkpoint
660+
async def __aenter__(self):
661+
await foo()
662+
663+
async def __aexit__(self, exc_type, exc, tb):
664+
await foo()
665+
666+
667+
# fmt: off
668+
class CtxNeitherCheckpoint:
669+
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
670+
print("setup")
671+
await trio.lowlevel.checkpoint()
672+
673+
async def __aexit__(self, *a): # error: 4, "exit", Stmt("function definition", line)
674+
print("teardown")
675+
await trio.lowlevel.checkpoint()
676+
# fmt: on
677+
678+
679+
# Only one method defined: charitably assume the other is inherited with a checkpoint.
680+
class CtxOnlyAenter: # safe: __aexit__ assumed inherited with checkpoint
681+
async def __aenter__(self):
682+
print("setup")
683+
684+
685+
class CtxOnlyAexit: # safe: __aenter__ assumed inherited with checkpoint
686+
async def __aexit__(self, *a):
687+
print("teardown")
688+
689+
690+
class CtxOnlyAenterWithCheckpoint: # safe
691+
async def __aenter__(self):
692+
await foo()
693+
694+
695+
class CtxOnlyAexitWithCheckpoint: # safe
696+
async def __aexit__(self, *a):
697+
await foo()
698+
699+
700+
# a nested function named `__aenter__` inside another function is not a method
701+
def not_a_class():
702+
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
703+
print("setup")
704+
await trio.lowlevel.checkpoint()
705+
706+
707+
# class nested inside a function still gets the exemption
708+
def factory():
709+
class NestedCtx: # safe
710+
async def __aenter__(self):
711+
print("setup")
712+
713+
714+
# nested class; outer class has nothing relevant
715+
class Outer:
716+
class Inner: # safe: charitable inheritance for __aexit__
717+
async def __aenter__(self):
718+
print("setup")

tests/autofix_files/async910.py.diff

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,3 +223,23 @@
223223

224224

225225
async def foo_nested_empty_async():
226+
@@ x,9 x,11 @@
227+
class CtxNeitherCheckpoint:
228+
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
229+
print("setup")
230+
+ await trio.lowlevel.checkpoint()
231+
232+
async def __aexit__(self, *a): # error: 4, "exit", Stmt("function definition", line)
233+
print("teardown")
234+
+ await trio.lowlevel.checkpoint()
235+
# fmt: on
236+
237+
238+
@@ x,6 x,7 @@
239+
def not_a_class():
240+
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
241+
print("setup")
242+
+ await trio.lowlevel.checkpoint()
243+
244+
245+
# class nested inside a function still gets the exemption

tests/eval_files/async910.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,3 +606,80 @@ async def foo_nested_empty_async():
606606
async def bar(): ...
607607

608608
await foo()
609+
610+
611+
# Issue #441: async context manager methods may legitimately skip checkpointing
612+
# if the partner method provides the checkpoint, or if the partner is inherited.
613+
class CtxWithSetup: # safe: __aenter__ checkpoints, __aexit__ can be fast
614+
async def __aenter__(self):
615+
await foo()
616+
617+
async def __aexit__(self, exc_type, exc, tb):
618+
print("fast exit")
619+
620+
621+
class CtxWithTeardown: # safe: __aexit__ checkpoints, __aenter__ can be fast
622+
async def __aenter__(self):
623+
print("fast setup")
624+
625+
async def __aexit__(self, exc_type, exc, tb):
626+
await foo()
627+
628+
629+
class CtxWithBothCheckpoint: # safe: both checkpoint
630+
async def __aenter__(self):
631+
await foo()
632+
633+
async def __aexit__(self, exc_type, exc, tb):
634+
await foo()
635+
636+
637+
# fmt: off
638+
class CtxNeitherCheckpoint:
639+
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
640+
print("setup")
641+
642+
async def __aexit__(self, *a): # error: 4, "exit", Stmt("function definition", line)
643+
print("teardown")
644+
# fmt: on
645+
646+
647+
# Only one method defined: charitably assume the other is inherited with a checkpoint.
648+
class CtxOnlyAenter: # safe: __aexit__ assumed inherited with checkpoint
649+
async def __aenter__(self):
650+
print("setup")
651+
652+
653+
class CtxOnlyAexit: # safe: __aenter__ assumed inherited with checkpoint
654+
async def __aexit__(self, *a):
655+
print("teardown")
656+
657+
658+
class CtxOnlyAenterWithCheckpoint: # safe
659+
async def __aenter__(self):
660+
await foo()
661+
662+
663+
class CtxOnlyAexitWithCheckpoint: # safe
664+
async def __aexit__(self, *a):
665+
await foo()
666+
667+
668+
# a nested function named `__aenter__` inside another function is not a method
669+
def not_a_class():
670+
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
671+
print("setup")
672+
673+
674+
# class nested inside a function still gets the exemption
675+
def factory():
676+
class NestedCtx: # safe
677+
async def __aenter__(self):
678+
print("setup")
679+
680+
681+
# nested class; outer class has nothing relevant
682+
class Outer:
683+
class Inner: # safe: charitable inheritance for __aexit__
684+
async def __aenter__(self):
685+
print("setup")

0 commit comments

Comments
 (0)