Skip to content

Commit d6fc36e

Browse files
committed
Add pre-commit hooks and scripts for async method checks in PGMQueue
- Introduced a pre-commit hook to check for missing async methods in PGMQueue. - Added scripts to identify and generate missing async methods. - Created utility functions for AST manipulation and method transformation. - Established configuration for project paths and console output.
1 parent f9feea8 commit d6fc36e

9 files changed

Lines changed: 427 additions & 1 deletion

File tree

.pre-commit-config.yaml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,10 @@ repos:
77
- id: ruff
88
args: [ --fix ]
99
# Run the formatter.
10-
- id: ruff-format
10+
- id: ruff-format
11+
- repo: local
12+
hooks:
13+
- id: check-sync-async-method-for-queue
14+
name: Check sync/async method for queue
15+
entry: ./scripts/ci/pre_commit/check_sync_async_method_for_queue.py
16+
language: python
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#!/usr/bin/env python
2+
# /// script
3+
# requires-python = ">=3.10,<3.11"
4+
# dependencies = [
5+
# "rich>=13.6.0",
6+
# ]
7+
# ///
8+
"""
9+
Script to check for missing async methods in PGMQueue for per-commit.
10+
11+
For each public sync method (not starting with _), checks if there's a corresponding
12+
async method with the same name plus '_async' suffix.
13+
"""
14+
15+
import ast
16+
import sys
17+
from pathlib import Path
18+
19+
20+
sys.path.insert(0, str(Path(__name__).parent.parent.joinpath("scripts").resolve()))
21+
22+
from scripts_utils.config import QUEUE_FILE # noqa: E402
23+
from scripts_utils.console import console # noqa: E402
24+
from scripts_utils.common_ast import parse_methods_info_from_target_class # noqa: E402
25+
26+
27+
def main():
28+
"""Main function."""
29+
30+
module_tree = ast.parse(source=QUEUE_FILE.read_text(), filename=QUEUE_FILE)
31+
_, missing_async = parse_methods_info_from_target_class(
32+
module_tree, target_class="PGMQueue"
33+
)
34+
35+
if not missing_async:
36+
console.print(
37+
"[bold green]SUCCESS:[/bold green] All public methods have corresponding async versions!"
38+
)
39+
sys.exit(0)
40+
41+
# log all the missing async methods
42+
console.print()
43+
console.print(
44+
f"[bold yellow]WARNING:[/bold yellow] Found {len(missing_async)} missing async methods:",
45+
style="bold",
46+
)
47+
for method in missing_async:
48+
console.print(f" [yellow]-[/yellow] {method}_async")
49+
console.print()
50+
51+
sys.exit(1)
52+
53+
54+
if __name__ == "__main__":
55+
main()
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
#!/usr/bin/env python
2+
# /// script
3+
# requires-python = ">=3.10,<3.11"
4+
# dependencies = [
5+
# "rich>=13.6.0",
6+
# ]
7+
# ///
8+
"""
9+
Script to check for missing async methods in PGMQueue class and generate them.
10+
11+
For each public sync method (not starting with _), checks if there's a corresponding
12+
async method with the same name plus '_async' suffix. If missing, generates it.
13+
"""
14+
15+
import ast
16+
import sys
17+
from pathlib import Path
18+
import contextlib
19+
20+
import tempfile
21+
22+
23+
from scripts_utils.config import QUEUE_FILE
24+
from scripts_utils.console import console
25+
from scripts_utils.common_ast import (
26+
parse_methods_info_from_target_class,
27+
fill_missing_methods_to_class,
28+
)
29+
from scripts_utils.formatting import format_file, compare_file
30+
from scripts_utils.queue_ast import get_async_methods_to_add
31+
32+
33+
def main():
34+
"""Main function."""
35+
36+
module_tree = ast.parse(source=QUEUE_FILE.read_text(), filename=QUEUE_FILE)
37+
sync_methods, missing_async = parse_methods_info_from_target_class(
38+
module_tree, target_class="PGMQueue"
39+
)
40+
41+
if not missing_async:
42+
console.print(
43+
"[bold green]SUCCESS:[/bold green] All public methods have corresponding async versions!"
44+
)
45+
sys.exit(0)
46+
47+
# log all the missing async methods
48+
console.print()
49+
console.print(
50+
f"[bold yellow]WARNING:[/bold yellow] Found {len(missing_async)} missing async methods:",
51+
style="bold",
52+
)
53+
for method in missing_async:
54+
console.print(f" [yellow]-[/yellow] {method}_async")
55+
console.print()
56+
57+
# create missing async method from
58+
async_methods_to_add = get_async_methods_to_add(sync_methods, missing_async)
59+
# insert back to class
60+
fill_missing_methods_to_class(module_tree, "PGMQueue", async_methods_to_add)
61+
module_tree = ast.fix_missing_locations(module_tree)
62+
63+
# write back to tmp file for comparison
64+
tmp_file = ""
65+
with tempfile.NamedTemporaryFile(mode="w+t", delete=False, suffix=".py") as f:
66+
f.write(ast.unparse(module_tree))
67+
f.flush()
68+
tmp_file = f.name
69+
console.log(f"Complete missing async methods at {tmp_file}")
70+
71+
if tmp_file:
72+
max_formatting = 3
73+
for _ in range(max_formatting):
74+
if format_file(tmp_file):
75+
break
76+
77+
_, missing_async_for_tmp = parse_methods_info_from_target_class(
78+
ast.parse(Path(tmp_file).read_text()), "PGMQueue"
79+
)
80+
81+
if missing_async_for_tmp:
82+
console.log(
83+
f"[error]Still get async methods to add after generating missing async methods in {tmp_file}: {missing_async_for_tmp}[/]"
84+
)
85+
else:
86+
console.log("[success]All missing async methods are generated[/]")
87+
88+
# compare existed queue.py and tmp.py
89+
with contextlib.suppress(Exception):
90+
compare_file(QUEUE_FILE, tmp_file)
91+
92+
sys.exit(0)
93+
94+
95+
if __name__ == "__main__":
96+
main()

scripts/scripts_utils/__init__.py

Whitespace-only changes.
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
from typing import List, Dict, Tuple, Literal
2+
3+
import ast
4+
5+
6+
class MethodInfo:
7+
"""Information about a method."""
8+
9+
def __init__(self, name: str, node: ast.FunctionDef):
10+
self.name = name
11+
self.node = node
12+
self.is_target = not name.startswith(
13+
"_"
14+
) # all the public method is our target method for further processing
15+
self.is_async = name.endswith("_async")
16+
self.base_name = name[:-6] if self.is_async else name
17+
18+
19+
class ParseTargetClassFunctionsVisitor(ast.NodeVisitor):
20+
"""AST visitor to parse functions out of target class name for given module tree"""
21+
22+
def __init__(self, class_name: str):
23+
self.class_name = class_name
24+
self.methods: List[MethodInfo] = []
25+
self.is_cur_node_in_target_class = False
26+
27+
def visit_ClassDef(self, node: ast.ClassDef):
28+
if node.name == self.class_name:
29+
self.is_cur_node_in_target_class = True
30+
self.generic_visit(node)
31+
self.is_cur_node_in_target_class = False
32+
else:
33+
self.generic_visit(node)
34+
35+
def visit_FunctionDef(self, node: ast.FunctionDef):
36+
if self.is_cur_node_in_target_class:
37+
# add all the method to the methods
38+
self.methods.append(MethodInfo(node.name, node))
39+
self.generic_visit(node)
40+
41+
def visit_AsyncFunctionDef(self, node):
42+
if self.is_cur_node_in_target_class:
43+
# add all the method to the methods
44+
self.methods.append(MethodInfo(node.name, node))
45+
self.generic_visit(node)
46+
47+
48+
class FillMissingMethodsToClass(ast.NodeTransformer):
49+
"""AST Transformer to fill missing async_methods back to target class"""
50+
51+
def __init__(self, class_name: str, to_add_async_methods: Dict[str, MethodInfo]):
52+
self.class_name = class_name
53+
self.to_add_async_methods = to_add_async_methods
54+
55+
def visit_ClassDef(self, node: ast.ClassDef):
56+
if node.name == self.class_name:
57+
for sync_func_name, async_func_node in self.to_add_async_methods.items():
58+
idx = next(
59+
(
60+
i
61+
for i, stmt in enumerate(node.body)
62+
if isinstance(stmt, ast.FunctionDef)
63+
and stmt.name == sync_func_name
64+
),
65+
-1,
66+
)
67+
68+
if idx != -1:
69+
node.body.insert(idx + 1, async_func_node.node)
70+
71+
return self.generic_visit(node)
72+
73+
74+
def parse_methods_info_from_target_class(
75+
module_tree: ast.Module, target_class: Literal["PGMQueue", "PGMQOperation"]
76+
) -> Tuple[List[MethodInfo], set[str]]:
77+
"""
78+
Parse methods of target class from give module AST Tree
79+
80+
Args:
81+
module_tree: ast.Module
82+
target_class: either "PGMQueue" or "PGMQOperation" str
83+
84+
Returns:
85+
Tuple of sync_methods, missing_async_set
86+
"""
87+
88+
analyzer = ParseTargetClassFunctionsVisitor(target_class)
89+
analyzer.visit(module_tree)
90+
91+
# Categorize methods
92+
# We use sync methods as source of truth
93+
async_methods_set = set()
94+
missing_async_set = set()
95+
96+
for method_info in analyzer.methods:
97+
# skip non target methods
98+
if not method_info.is_target:
99+
continue
100+
101+
if method_info.is_async:
102+
async_methods_set.add(method_info.base_name)
103+
104+
# Find missing async methods and generate class with interleaved methods
105+
for method_info in analyzer.methods:
106+
# skip non target methods
107+
if not method_info.is_target:
108+
continue
109+
110+
if method_info.base_name not in async_methods_set:
111+
missing_async_set.add(method_info.base_name)
112+
113+
return analyzer.methods, missing_async_set
114+
115+
116+
def fill_missing_methods_to_class(
117+
module_tree: ast.Module,
118+
target_class: Literal["PGMQueue", "PGMQOperation"],
119+
to_add_async_methods: Dict[str, MethodInfo],
120+
):
121+
transformer = FillMissingMethodsToClass(
122+
class_name=target_class, to_add_async_methods=to_add_async_methods
123+
)
124+
transformer.visit(module_tree)

scripts/scripts_utils/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from pathlib import Path
2+
3+
PROJECT_ROOT = Path(__file__).parent.parent.parent
4+
SOURCE_PATH = PROJECT_ROOT / "pgmq_sqlalchemy"
5+
QUEUE_FILE = SOURCE_PATH / "queue.py"

scripts/scripts_utils/console.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from rich.console import Console
2+
from rich.theme import Theme
3+
4+
5+
console = Console(
6+
force_terminal=True,
7+
color_system="standard",
8+
theme=Theme(
9+
{
10+
"success": "green",
11+
"info": "bright_blue",
12+
"warning": "bright_yellow",
13+
"error": "red",
14+
"special": "magenta",
15+
}
16+
),
17+
width=202,
18+
)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import subprocess
2+
import sys
3+
from pathlib import Path
4+
5+
6+
sys.path.insert(0, str(Path(__name__).parent.parent.joinpath("scripts").resolve()))
7+
8+
9+
def format_file(file_path: str) -> bool:
10+
try:
11+
ruff_stdout = subprocess.check_output(["ruff", "format", file_path]).decode()
12+
except Exception as e:
13+
raise e
14+
15+
return "unchanged" in ruff_stdout
16+
17+
18+
def compare_file(existed_file: str, new_file: str):
19+
try:
20+
subprocess.check_call(
21+
["git", "difftool", "--tool=vimdiff", "--no-index", existed_file, new_file]
22+
)
23+
except Exception as e:
24+
raise e

0 commit comments

Comments
 (0)