Skip to content

Commit 9697d5e

Browse files
Copilotjason810496
andauthored
Add script to generate missing async methods for PGMQOperation (#37)
* Initial plan * Add compelete_missing_async_methods_for_operation.py script and operation_ast helper Co-authored-by: jason810496 <[email protected]> * Fix path resolution bug in operation_ast.py Co-authored-by: jason810496 <[email protected]> * Address code review feedback: improve docstring handling and remove unnecessary method Co-authored-by: jason810496 <[email protected]> * Improve docstring handling robustness with better length checks Co-authored-by: jason810496 <[email protected]> --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: jason810496 <[email protected]>
1 parent 5b7b999 commit 9697d5e

2 files changed

Lines changed: 260 additions & 0 deletions

File tree

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
#!/usr/bin/env python
2+
# /// script
3+
# requires-python = ">=3.10,<3.11"
4+
# dependencies = [
5+
# "rich>=13.6.0",
6+
# "libcst>=1.0.0",
7+
# ]
8+
# ///
9+
"""
10+
Script to check for missing async methods in PGMQOperation class and generate them.
11+
12+
For each public sync method (not starting with _), checks if there's a corresponding
13+
async method with the same name plus '_async' suffix. If missing, generates it.
14+
"""
15+
16+
import libcst as cst
17+
import sys
18+
from pathlib import Path
19+
import contextlib
20+
import shutil
21+
22+
import tempfile
23+
24+
25+
from scripts_utils.config import OPERATION_FILE, OPERATION_BACKUP_FILE
26+
from scripts_utils.console import console, user_input
27+
from scripts_utils.common_ast import (
28+
parse_methods_info_from_target_class,
29+
fill_missing_methods_to_class,
30+
)
31+
from scripts_utils.formatting import format_file, compare_file
32+
from scripts_utils.operation_ast import get_async_methods_to_add
33+
34+
35+
def main():
36+
"""Main function."""
37+
38+
module_tree = cst.parse_module(OPERATION_FILE.read_text())
39+
sync_methods, missing_async = parse_methods_info_from_target_class(
40+
module_tree, target_class="PGMQOperation"
41+
)
42+
43+
if not missing_async:
44+
console.print(
45+
"[bold green]SUCCESS:[/bold green] All public methods have corresponding async versions!"
46+
)
47+
sys.exit(0)
48+
49+
# log all the missing async methods
50+
console.print()
51+
console.print(
52+
f"[bold yellow]WARNING:[/bold yellow] Found {len(missing_async)} missing async methods:",
53+
style="bold",
54+
)
55+
for method in missing_async:
56+
console.print(f" [yellow]-[/yellow] {method}_async")
57+
console.print()
58+
59+
# create missing async method from
60+
async_methods_to_add = get_async_methods_to_add(sync_methods, missing_async)
61+
# insert back to class
62+
module_tree = fill_missing_methods_to_class(
63+
module_tree, "PGMQOperation", async_methods_to_add
64+
)
65+
66+
# write back to tmp file for comparison
67+
tmp_file = ""
68+
with tempfile.NamedTemporaryFile(mode="w+t", delete=False, suffix=".py") as f:
69+
f.write(module_tree.code)
70+
f.flush()
71+
tmp_file = f.name
72+
console.log(f"Complete missing async methods at {tmp_file}")
73+
74+
if tmp_file:
75+
max_formatting = 3
76+
for _ in range(max_formatting):
77+
if format_file(tmp_file):
78+
break
79+
80+
_, missing_async_for_tmp = parse_methods_info_from_target_class(
81+
cst.parse_module(Path(tmp_file).read_text()), "PGMQOperation"
82+
)
83+
84+
if missing_async_for_tmp:
85+
console.log(
86+
f"[error]Still get async methods to add after generating missing async methods in {tmp_file}: {missing_async_for_tmp}[/]"
87+
)
88+
else:
89+
console.log("[success]All missing async methods are generated[/]")
90+
91+
# compare existed operation.py and tmp.py
92+
with contextlib.suppress(Exception):
93+
compare_file(OPERATION_FILE, tmp_file)
94+
95+
# ask whether to apply the change
96+
if user_input(f"Do you want to apply change to {OPERATION_FILE}"):
97+
console.log(f"Backup existed {OPERATION_FILE} at {OPERATION_BACKUP_FILE}")
98+
shutil.copy(OPERATION_FILE, OPERATION_BACKUP_FILE)
99+
shutil.copy(tmp_file, OPERATION_FILE)
100+
console.log("Add missing async methods successfully")
101+
102+
sys.exit(0)
103+
104+
105+
if __name__ == "__main__":
106+
main()
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
import libcst as cst
2+
import re
3+
import sys
4+
from pathlib import Path
5+
from typing import List, Set, Dict
6+
7+
sys.path.insert(0, str(Path(__file__).parent.parent.joinpath("scripts").resolve()))
8+
9+
from scripts_utils.common_ast import MethodInfo # noqa: E402
10+
11+
12+
class AsyncOperationTransformer(cst.CSTTransformer):
13+
"""Transform sync PGMQOperation methods to async versions."""
14+
15+
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
16+
"""Transform session.execute() and session.commit() calls to await."""
17+
# Check if this is a session.execute() or session.commit() call
18+
if isinstance(updated_node.func, cst.Attribute):
19+
if isinstance(updated_node.func.value, cst.Name):
20+
if updated_node.func.value.value == "session":
21+
if updated_node.func.attr.value in ["execute", "commit"]:
22+
# Wrap in await
23+
return cst.Await(expression=updated_node)
24+
25+
return updated_node
26+
27+
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
28+
"""Transform function to async version."""
29+
# Transform function to async
30+
new_node = updated_node.with_changes(
31+
asynchronous=cst.Asynchronous(),
32+
name=cst.Name(f"{updated_node.name.value}_async")
33+
)
34+
35+
# Update parameters - change Session to AsyncSession
36+
if updated_node.params:
37+
new_params = self._transform_params(updated_node.params)
38+
new_node = new_node.with_changes(params=new_params)
39+
40+
# Transform docstring if exists
41+
if updated_node.body.body and isinstance(updated_node.body.body[0], cst.SimpleStatementLine):
42+
first_stmt = updated_node.body.body[0]
43+
if first_stmt.body and isinstance(first_stmt.body[0], cst.Expr):
44+
expr = first_stmt.body[0]
45+
if isinstance(expr.value, (cst.SimpleString, cst.ConcatenatedString)):
46+
# Extract docstring value
47+
if isinstance(expr.value, cst.SimpleString):
48+
docstring = expr.value.value
49+
else:
50+
# For concatenated strings, we'll skip transformation for now
51+
docstring = None
52+
53+
if docstring and len(docstring) >= 2:
54+
# Remove quotes to get actual string content
55+
if len(docstring) >= 6 and (docstring.startswith('"""') or docstring.startswith("'''")):
56+
quote = docstring[:3]
57+
content = docstring[3:-3]
58+
elif len(docstring) >= 2 and (docstring.startswith('"') or docstring.startswith("'")):
59+
quote = docstring[0]
60+
content = docstring[1:-1]
61+
else:
62+
content = docstring
63+
quote = '"""'
64+
65+
transformed_content = self.transform_docstring(content)
66+
new_docstring = f'{quote}{transformed_content}{quote}'
67+
68+
# Create new docstring node
69+
new_expr = expr.with_changes(value=cst.SimpleString(new_docstring))
70+
new_first_stmt = first_stmt.with_changes(body=[new_expr])
71+
72+
# Update body with new docstring
73+
new_body = [new_first_stmt] + list(updated_node.body.body[1:])
74+
new_node = new_node.with_changes(
75+
body=new_node.body.with_changes(body=new_body)
76+
)
77+
78+
return new_node
79+
80+
def _transform_params(self, params: cst.Parameters) -> cst.Parameters:
81+
"""Transform parameters - change Session to AsyncSession."""
82+
new_kwonly_params = []
83+
84+
if params.kwonly_params:
85+
for param in params.kwonly_params:
86+
if param.annotation:
87+
# Check if annotation is Session
88+
if isinstance(param.annotation.annotation, cst.Name):
89+
if param.annotation.annotation.value == "Session":
90+
# Replace Session with AsyncSession
91+
new_annotation = param.annotation.with_changes(
92+
annotation=cst.Name("AsyncSession")
93+
)
94+
new_param = param.with_changes(annotation=new_annotation)
95+
new_kwonly_params.append(new_param)
96+
continue
97+
new_kwonly_params.append(param)
98+
99+
return params.with_changes(kwonly_params=new_kwonly_params)
100+
101+
def transform_docstring(self, docstring: str) -> str:
102+
"""Transform docstring for async version."""
103+
# Replace references to sync version with async version
104+
modified = docstring
105+
106+
# Change method description to indicate it's async
107+
if "asynchronously" not in modified.lower() and "(async)" not in modified.lower():
108+
# Add async indication at the end of first sentence if not present
109+
modified = re.sub(
110+
r'(^[^.]+\.)',
111+
r'\1 (async)',
112+
modified,
113+
count=1
114+
)
115+
# Or if that didn't work, try to add it more explicitly
116+
if "(async)" not in modified:
117+
# Replace session reference
118+
modified = re.sub(
119+
r'SQLAlchemy session\.',
120+
r'Async SQLAlchemy session.',
121+
modified
122+
)
123+
modified = re.sub(
124+
r'session: SQLAlchemy session',
125+
r'session: Async SQLAlchemy session',
126+
modified
127+
)
128+
129+
return modified
130+
131+
132+
def transform_to_async_operation(
133+
transformer: AsyncOperationTransformer, method_info: MethodInfo
134+
) -> MethodInfo:
135+
"""Transform a sync method to async for PGMQOperation."""
136+
orig_sync_func_node = method_info.node
137+
async_node = orig_sync_func_node.visit(transformer)
138+
139+
return MethodInfo(f"{method_info.base_name}_async", async_node)
140+
141+
142+
def get_async_methods_to_add(
143+
sync_methods: List[MethodInfo], missing_async: Set[str]
144+
) -> Dict[str, MethodInfo]:
145+
"""Get async methods to add for missing sync methods."""
146+
transformer = AsyncOperationTransformer()
147+
async_methods: Dict[str, MethodInfo] = {}
148+
for method_info in sync_methods:
149+
if method_info.base_name in missing_async:
150+
async_methods[method_info.base_name] = transform_to_async_operation(
151+
transformer, method_info
152+
)
153+
154+
return async_methods

0 commit comments

Comments
 (0)