Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 87 additions & 28 deletions language_tool_python/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,51 @@
import re
import sys
import traceback
from collections.abc import Sequence
from importlib.metadata import PackageNotFoundError, version
from logging.config import dictConfig
from pathlib import Path
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, TypedDict, cast

from ._compat import toml_loads
from .exceptions import LanguageToolError
from .server import LanguageTool

if TYPE_CHECKING:
from collections.abc import Sequence
from typing import TextIO


class _PyProjectProject(TypedDict):
version: str


class _PyProject(TypedDict):
project: _PyProjectProject


def _load_pyproject_and_logconfig(path: Path) -> dict[str, object]:
"""Load a TOML file as a typed dictionary.

:param path: The path to the TOML file to load.
:type path: Path
:return: The contents of the TOML file as a dictionary.
:rtype: dict[str, object]
"""
with path.open("rb") as f:
return cast("dict[str, object]", toml_loads(f.read().decode("utf-8")))


def _read_project_version(pyproject: Path) -> str:
"""Read the package version from pyproject.toml.

:param pyproject: The path to the pyproject.toml file.
:type pyproject: Path
:return: The package version.
:rtype: str
"""
pyproject_config = cast("_PyProject", _load_pyproject_and_logconfig(pyproject))
return pyproject_config["project"]["version"]


try:
__version__ = version("language_tool_python")
Expand All @@ -28,26 +61,43 @@
except PackageNotFoundError:
project_root = Path(__file__).resolve().parent.parent
pyproject = project_root / "pyproject.toml"
with pyproject.open("rb") as f:
__version__ = toml_loads(f.read().decode("utf-8"))["project"]["version"]
__version__ = _read_project_version(pyproject)


logger = logging.getLogger(__name__)
with (
importlib.resources.as_file(
importlib.resources.files("language_tool_python").joinpath("logging.toml"),
) as config_path,
config_path.open("rb") as f,
):
log_config = toml_loads(f.read().decode("utf-8"))
with importlib.resources.as_file(
importlib.resources.files("language_tool_python").joinpath("logging.toml"),
) as config_path:
log_config = _load_pyproject_and_logconfig(config_path)
dictConfig(log_config)

RULE_RE: re.Pattern[str] = re.compile(r"[\w-]+")


def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
class CliArgs(argparse.Namespace):
"""Typed command-line arguments."""

files: list[str]
encoding: str | None
language: str | None
mother_tongue: str | None
disable: set[str]
enable: set[str]
enabled_only: bool
picky: bool
apply: bool
spell_check: bool
ignore_lines: str | None
remote_host: str | None
remote_port: str | None
verbose: bool


def parse_args(argv: Sequence[str] | None = None) -> CliArgs:
"""Parse command line arguments.

:return: parsed arguments
:rtype: argparse.Namespace
:rtype: CliArgs
"""
parser = argparse.ArgumentParser(
description=__doc__.strip() if __doc__ else None,
Expand All @@ -73,7 +123,7 @@ def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
metavar="RULES",
type=get_rules,
action=RulesAction,
default=set(),
default=set[str](),
help="list of rule IDs to be disabled",
)
parser.add_argument(
Expand All @@ -82,7 +132,7 @@ def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
metavar="RULES",
type=get_rules,
action=RulesAction,
default=set(),
default=set[str](),
help="list of rule IDs to be enabled",
)
parser.add_argument(
Expand Down Expand Up @@ -126,7 +176,8 @@ def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
parser.add_argument("--remote-port", help="port of the remote LanguageTool server")
parser.add_argument("--verbose", action="store_true", help="enable verbose output")

args = parser.parse_args(argv)
args = CliArgs()
parser.parse_args(argv, namespace=args)

if args.enabled_only:
if args.disable:
Expand Down Expand Up @@ -165,15 +216,21 @@ def __call__(
:param _parser: The ArgumentParser object which contains this action.
:type _parser: argparse.ArgumentParser
:param namespace: The namespace object that will be returned by parse_args().
:type namespace: argparse.Namespace
:type namespace: CliArgs
:param values: The argument values associated with the action.
:type values: str | Sequence[object] | None
:param _option_string: The option string that was used to invoke this action.
:type _option_string: str | None
"""
getattr(namespace, self.dest).update(
cast("set[str]", values),
)
cli_args = cast("CliArgs", namespace)
rule_values = cast("set[str]", values)
if self.dest == "disable":
cli_args.disable.update(rule_values)
elif self.dest == "enable":
cli_args.enable.update(rule_values)
else:
err = f"unexpected rules destination: {self.dest}"
raise ValueError(err)


def get_rules(rules: str) -> set[str]:
Expand All @@ -184,7 +241,8 @@ def get_rules(rules: str) -> set[str]:
:return: A set of rule IDs.
:rtype: set[str]
"""
return {rule.upper() for rule in re.findall(r"[\w\-]+", rules)}
rule_ids = cast("list[str]", RULE_RE.findall(rules))
return {rule.upper() for rule in rule_ids}


def get_text(
Expand Down Expand Up @@ -223,11 +281,11 @@ def print_exception(exc: Exception, debug: bool) -> None:
print(exc, file=sys.stderr)


def get_remote_server(args: argparse.Namespace) -> str | None:
def get_remote_server(args: CliArgs) -> str | None:
"""Build the remote server address from parsed arguments.

:param args: Parsed command-line arguments.
:type args: argparse.Namespace
:type args: CliArgs
:return: The remote server address in the format "host:port" or None if no remote
host is specified.
:rtype: str | None
Expand All @@ -242,18 +300,19 @@ def get_remote_server(args: argparse.Namespace) -> str | None:
return remote_server


def get_input_text(filename: str, args: argparse.Namespace) -> str:
def get_input_text(filename: str, args: CliArgs) -> str:
"""Read input text from a file or stdin.

:param filename: The name of the file to read or "-" for stdin.
:type filename: str
:param args: Parsed command-line arguments.
:type args: argparse.Namespace
:type args: CliArgs
:return: The input text as a string.
:rtype: str
"""
if filename == "-":
raw = sys.stdin.read()
stdin = cast("TextIO", sys.stdin)
raw = stdin.read()
if args.ignore_lines:
return "".join(
"\n" if re.match(args.ignore_lines, line) else line
Expand All @@ -267,15 +326,15 @@ def get_input_text(filename: str, args: argparse.Namespace) -> str:

def process_file(
filename: str,
args: argparse.Namespace,
args: CliArgs,
remote_server: str | None,
) -> int:
"""Check a single input file and return the resulting status.

:param filename: The name of the file to check or "-" for stdin.
:type filename: str
:param args: Parsed command-line arguments.
:type args: argparse.Namespace
:type args: CliArgs
:param remote_server: The remote server address or None.
:type remote_server: str | None
:return: The resulting status.
Expand Down
11 changes: 9 additions & 2 deletions language_tool_python/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,18 @@
- ``deprecated``: built-in ``warnings.deprecated`` on Python 3.13+, otherwise
``typing_extensions.deprecated``.
- ``toml_loads``: built-in ``tomllib.loads`` on Python 3.11+, otherwise
``tomli.loads``.
``tomli.loads``.
- ``TypeGuard``: built-in ``typing.TypeGuard`` on Python 3.10+, otherwise
``typing_extensions.TypeGuard``.
"""

import sys

if sys.version_info >= (3, 10):
from typing import TypeGuard
else:
from typing_extensions import TypeGuard

if sys.version_info >= (3, 11):
from tomllib import loads as toml_loads
else:
Expand All @@ -20,4 +27,4 @@
else:
from typing_extensions import deprecated

__all__ = ["deprecated", "toml_loads"]
__all__ = ["TypeGuard", "deprecated", "toml_loads"]
17 changes: 10 additions & 7 deletions language_tool_python/api_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

from __future__ import annotations

from typing import TypedDict
from typing import TYPE_CHECKING, TypedDict

if TYPE_CHECKING:
from ._compat import TypeGuard

__all__ = [
"Category",
Expand All @@ -29,13 +32,13 @@ class LanguageInfo(TypedDict):
name: str


def is_language_info(value: object) -> bool: # No TypeGuard because py3.9
def is_language_info(value: object) -> TypeGuard[LanguageInfo]:
"""Verify that a value is a LanguageInfo.

:param value: The value to check.
:type value: object
:return: True if the value is a LanguageInfo, False otherwise.
:rtype: bool
:return: TypeGuard indicating whether the value is a LanguageInfo.
:rtype: TypeGuard[LanguageInfo]
"""
if not isinstance(value, dict):
return False
Expand Down Expand Up @@ -139,13 +142,13 @@ class CheckResponse(TypedDict):
warnings: WarningInfo


def is_check_response(value: object) -> bool: # No TypeGuard because py3.9
def is_check_response(value: object) -> TypeGuard[CheckResponse]:
"""Verify that a value is a CheckResponse.

:param value: The value to check.
:type value: object
:return: True if the value is a CheckResponse, False otherwise.
:rtype: bool
:return: TypeGuard indicating whether the value is a CheckResponse.
:rtype: TypeGuard[CheckResponse]
"""
if not isinstance(value, dict):
return False
Expand Down
Loading
Loading