Skip to content

Commit f2968d4

Browse files
committed
Tighten async CM exemption rules
Two refinements in response to review feedback: - If an `__aenter__`/`__aexit__` method contains any checkpoint-like construct (`await`, `async with`, or `async for`), it must always checkpoint. We no longer exempt such methods even when the partner provides a checkpoint -- conditional checkpoints are still flagged. - Only charitably assume a missing partner is inherited (with a checkpoint) when the class actually inherits from something. Classes with no base classes are treated as flat, and methods that don't checkpoint are flagged. `metaclass=` and other keyword arguments do not count as inheriting, since they live in `ClassDef.keywords` rather than `ClassDef.bases`. https://claude.ai/code/session_014jAydKywq31Ew4fVYGJdiG
1 parent 90d5b7b commit f2968d4

4 files changed

Lines changed: 147 additions & 22 deletions

File tree

flake8_async/visitors/visitor91x.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -467,11 +467,15 @@ def __init__(self, *args: Any, **kwargs: Any):
467467

468468
# Tracks whether the current scope is a class body and, if so, which of
469469
# `__aenter__`/`__aexit__` are directly defined on it (values: True if
470-
# that method contains an `await`, False otherwise, missing key if not
471-
# defined). Used to exempt async context manager methods from
472-
# ASYNC910/911 when their partner method provides the checkpoint, or
473-
# when the partner is assumed inherited (not defined on this class).
470+
# that method contains a checkpoint-like construct, False otherwise,
471+
# missing key if not defined). Used to exempt async context manager
472+
# methods from ASYNC910/911 when their partner method provides the
473+
# checkpoint, or when the partner is inherited from a base class.
474474
self.async_cm_class: dict[str, bool] | None = None
475+
# Whether the enclosing class has an explicit base class (other than
476+
# implicit `object`). We only assume a missing partner is inherited if
477+
# the class actually inherits from something.
478+
self.async_cm_class_has_bases = False
475479
# Set on entry to an exempt `__aenter__`/`__aexit__` so that
476480
# `error_91x` skips emitting ASYNC910/911.
477481
self.exempt_async_cm_method = False
@@ -548,17 +552,24 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
548552
# from a base class (which we charitably assume contains a checkpoint).
549553
# See https://github.com/python-trio/flake8-async/issues/441.
550554
def visit_ClassDef(self, node: cst.ClassDef) -> None:
551-
self.save_state(node, "async_cm_class")
555+
self.save_state(node, "async_cm_class", "async_cm_class_has_bases")
552556
defined: dict[str, bool] = {}
557+
checkpointy = (
558+
m.Await()
559+
| m.With(asynchronous=m.Asynchronous())
560+
| m.For(asynchronous=m.Asynchronous())
561+
)
553562
if isinstance(node.body, cst.IndentedBlock):
554563
for stmt in node.body.body:
555564
if (
556565
isinstance(stmt, cst.FunctionDef)
557566
and stmt.asynchronous is not None
558567
and stmt.name.value in ("__aenter__", "__aexit__")
559568
):
560-
defined[stmt.name.value] = bool(m.findall(stmt, m.Await()))
569+
defined[stmt.name.value] = bool(m.findall(stmt, checkpointy))
561570
self.async_cm_class = defined
571+
# Keyword args like `metaclass=` are in `node.keywords`, not `bases`.
572+
self.async_cm_class_has_bases = bool(node.bases)
562573

563574
def leave_ClassDef(
564575
self, original_node: cst.ClassDef, updated_node: cst.ClassDef
@@ -574,11 +585,16 @@ def _is_exempt_async_cm_method(self, node: cst.FunctionDef) -> bool:
574585
return False
575586
if name not in self.async_cm_class:
576587
return False
588+
# A method that contains any checkpoint must always checkpoint: we
589+
# still check it normally so conditional checkpoints are flagged.
590+
if self.async_cm_class[name]:
591+
return False
577592
partner = "__aexit__" if name == "__aenter__" else "__aenter__"
578-
# Partner not defined in this class -> assume inherited with checkpoint.
579593
if partner not in self.async_cm_class:
580-
return True
581-
# Partner defined and (charitably) contains a checkpoint.
594+
# Partner is not defined on this class; only assume it is inherited
595+
# (and contains a checkpoint) if the class inherits from something.
596+
return self.async_cm_class_has_bases
597+
# Partner defined; exempt iff it contains a checkpoint.
582598
return self.async_cm_class[partner]
583599

584600
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
@@ -609,6 +625,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
609625
"except_depth",
610626
"add_checkpoint_at_function_start",
611627
"async_cm_class",
628+
"async_cm_class_has_bases",
612629
"exempt_async_cm_method",
613630
copy=True,
614631
)
@@ -622,6 +639,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
622639
self.add_checkpoint_at_function_start = False
623640
# Class-level context does not apply to nested scopes.
624641
self.async_cm_class = None
642+
self.async_cm_class_has_bases = False
625643
self.exempt_async_cm_method = is_exempt_cm
626644

627645
self.async_function = (

tests/autofix_files/async910.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,10 @@ async def bar(): ...
640640

641641
# Issue #441: async context manager methods may legitimately skip checkpointing
642642
# if the partner method provides the checkpoint, or if the partner is inherited.
643+
class ACM: # a dummy base to opt into the charitable-inheritance assumption
644+
pass
645+
646+
643647
class CtxWithSetup: # safe: __aenter__ checkpoints, __aexit__ can be fast
644648
async def __aenter__(self):
645649
await foo()
@@ -676,17 +680,43 @@ async def __aexit__(self, *a): # error: 4, "exit", Stmt("function definition",
676680
# fmt: on
677681

678682

679-
# Only one method defined: charitably assume the other is inherited with a checkpoint.
680-
class CtxOnlyAenter: # safe: __aexit__ assumed inherited with checkpoint
683+
# A method that contains any checkpoint is still required to always checkpoint.
684+
class CtxAenterConditionalAexitFast(ACM):
685+
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
686+
if _:
687+
await foo()
688+
await trio.lowlevel.checkpoint()
689+
690+
async def __aexit__(self, *a):
691+
print("fast exit")
692+
693+
694+
# Only one method defined: charitably assume the other is inherited with a
695+
# checkpoint -- but only when the class inherits from something.
696+
class CtxOnlyAenterInherited(ACM): # safe: __aexit__ assumed inherited
681697
async def __aenter__(self):
682698
print("setup")
683699

684700

685-
class CtxOnlyAexit: # safe: __aenter__ assumed inherited with checkpoint
701+
class CtxOnlyAexitInherited(ACM): # safe: __aenter__ assumed inherited
686702
async def __aexit__(self, *a):
687703
print("teardown")
688704

689705

706+
# fmt: off
707+
class CtxOnlyAenter: # no base class -> don't assume inheritance
708+
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
709+
print("setup")
710+
await trio.lowlevel.checkpoint()
711+
712+
713+
class CtxOnlyAexit: # no base class -> don't assume inheritance
714+
async def __aexit__(self, *a): # error: 4, "exit", Stmt("function definition", line)
715+
print("teardown")
716+
await trio.lowlevel.checkpoint()
717+
# fmt: on
718+
719+
690720
class CtxOnlyAenterWithCheckpoint: # safe
691721
async def __aenter__(self):
692722
await foo()
@@ -697,22 +727,33 @@ async def __aexit__(self, *a):
697727
await foo()
698728

699729

730+
# keyword-only bases (like `metaclass=`) don't count as inheriting.
731+
class Meta(type):
732+
pass
733+
734+
735+
class CtxMetaclassOnly(metaclass=Meta):
736+
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
737+
print("setup")
738+
await trio.lowlevel.checkpoint()
739+
740+
700741
# a nested function named `__aenter__` inside another function is not a method
701742
def not_a_class():
702743
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
703744
print("setup")
704745
await trio.lowlevel.checkpoint()
705746

706747

707-
# class nested inside a function still gets the exemption
748+
# class nested inside a function still gets the exemption when it inherits
708749
def factory():
709-
class NestedCtx: # safe
750+
class NestedCtx(ACM): # safe
710751
async def __aenter__(self):
711752
print("setup")
712753

713754

714755
# nested class; outer class has nothing relevant
715756
class Outer:
716-
class Inner: # safe: charitable inheritance for __aexit__
757+
class Inner(ACM): # safe: charitable inheritance for __aexit__
717758
async def __aenter__(self):
718759
print("setup")

tests/autofix_files/async910.py.diff

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,10 +236,39 @@
236236

237237

238238
@@ x,6 x,7 @@
239+
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
240+
if _:
241+
await foo()
242+
+ await trio.lowlevel.checkpoint()
243+
244+
async def __aexit__(self, *a):
245+
print("fast exit")
246+
@@ x,11 x,13 @@
247+
class CtxOnlyAenter: # no base class -> don't assume inheritance
248+
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
249+
print("setup")
250+
+ await trio.lowlevel.checkpoint()
251+
252+
253+
class CtxOnlyAexit: # no base class -> don't assume inheritance
254+
async def __aexit__(self, *a): # error: 4, "exit", Stmt("function definition", line)
255+
print("teardown")
256+
+ await trio.lowlevel.checkpoint()
257+
# fmt: on
258+
259+
260+
@@ x,12 x,14 @@
261+
class CtxMetaclassOnly(metaclass=Meta):
262+
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
263+
print("setup")
264+
+ await trio.lowlevel.checkpoint()
265+
266+
267+
# a nested function named `__aenter__` inside another function is not a method
239268
def not_a_class():
240269
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
241270
print("setup")
242271
+ await trio.lowlevel.checkpoint()
243272

244273

245-
# class nested inside a function still gets the exemption
274+
# class nested inside a function still gets the exemption when it inherits

tests/eval_files/async910.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,10 @@ async def bar(): ...
610610

611611
# Issue #441: async context manager methods may legitimately skip checkpointing
612612
# if the partner method provides the checkpoint, or if the partner is inherited.
613+
class ACM: # a dummy base to opt into the charitable-inheritance assumption
614+
pass
615+
616+
613617
class CtxWithSetup: # safe: __aenter__ checkpoints, __aexit__ can be fast
614618
async def __aenter__(self):
615619
await foo()
@@ -644,17 +648,40 @@ async def __aexit__(self, *a): # error: 4, "exit", Stmt("function definition",
644648
# fmt: on
645649

646650

647-
# Only one method defined: charitably assume the other is inherited with a checkpoint.
648-
class CtxOnlyAenter: # safe: __aexit__ assumed inherited with checkpoint
651+
# A method that contains any checkpoint is still required to always checkpoint.
652+
class CtxAenterConditionalAexitFast(ACM):
653+
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
654+
if _:
655+
await foo()
656+
657+
async def __aexit__(self, *a):
658+
print("fast exit")
659+
660+
661+
# Only one method defined: charitably assume the other is inherited with a
662+
# checkpoint -- but only when the class inherits from something.
663+
class CtxOnlyAenterInherited(ACM): # safe: __aexit__ assumed inherited
649664
async def __aenter__(self):
650665
print("setup")
651666

652667

653-
class CtxOnlyAexit: # safe: __aenter__ assumed inherited with checkpoint
668+
class CtxOnlyAexitInherited(ACM): # safe: __aenter__ assumed inherited
654669
async def __aexit__(self, *a):
655670
print("teardown")
656671

657672

673+
# fmt: off
674+
class CtxOnlyAenter: # no base class -> don't assume inheritance
675+
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
676+
print("setup")
677+
678+
679+
class CtxOnlyAexit: # no base class -> don't assume inheritance
680+
async def __aexit__(self, *a): # error: 4, "exit", Stmt("function definition", line)
681+
print("teardown")
682+
# fmt: on
683+
684+
658685
class CtxOnlyAenterWithCheckpoint: # safe
659686
async def __aenter__(self):
660687
await foo()
@@ -665,21 +692,31 @@ async def __aexit__(self, *a):
665692
await foo()
666693

667694

695+
# keyword-only bases (like `metaclass=`) don't count as inheriting.
696+
class Meta(type):
697+
pass
698+
699+
700+
class CtxMetaclassOnly(metaclass=Meta):
701+
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
702+
print("setup")
703+
704+
668705
# a nested function named `__aenter__` inside another function is not a method
669706
def not_a_class():
670707
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
671708
print("setup")
672709

673710

674-
# class nested inside a function still gets the exemption
711+
# class nested inside a function still gets the exemption when it inherits
675712
def factory():
676-
class NestedCtx: # safe
713+
class NestedCtx(ACM): # safe
677714
async def __aenter__(self):
678715
print("setup")
679716

680717

681718
# nested class; outer class has nothing relevant
682719
class Outer:
683-
class Inner: # safe: charitable inheritance for __aexit__
720+
class Inner(ACM): # safe: charitable inheritance for __aexit__
684721
async def __aenter__(self):
685722
print("setup")

0 commit comments

Comments
 (0)