diff --git a/language_tool_python/__main__.py b/language_tool_python/__main__.py index b94ccb9..8b03086 100644 --- a/language_tool_python/__main__.py +++ b/language_tool_python/__main__.py @@ -8,11 +8,10 @@ import re import sys import traceback -from collections.abc import Sequence from importlib.metadata import PackageNotFoundError, version from logging.config import dictConfig from pathlib import Path -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, TypedDict, cast from ._compat import toml_loads from .exceptions import LanguageToolError @@ -20,6 +19,40 @@ if TYPE_CHECKING: from collections.abc import Sequence + from typing import TextIO + + +class _PyProjectProject(TypedDict): + version: str + + +class _PyProject(TypedDict): + project: _PyProjectProject + + +def _load_pyproject_and_logconfig(path: Path) -> dict[str, object]: + """Load a TOML file as a typed dictionary. + + :param path: The path to the TOML file to load. + :type path: Path + :return: The contents of the TOML file as a dictionary. + :rtype: dict[str, object] + """ + with path.open("rb") as f: + return cast("dict[str, object]", toml_loads(f.read().decode("utf-8"))) + + +def _read_project_version(pyproject: Path) -> str: + """Read the package version from pyproject.toml. + + :param pyproject: The path to the pyproject.toml file. + :type pyproject: Path + :return: The package version. + :rtype: str + """ + pyproject_config = cast("_PyProject", _load_pyproject_and_logconfig(pyproject)) + return pyproject_config["project"]["version"] + try: __version__ = version("language_tool_python") @@ -28,26 +61,43 @@ except PackageNotFoundError: project_root = Path(__file__).resolve().parent.parent pyproject = project_root / "pyproject.toml" - with pyproject.open("rb") as f: - __version__ = toml_loads(f.read().decode("utf-8"))["project"]["version"] + __version__ = _read_project_version(pyproject) logger = logging.getLogger(__name__) -with ( - importlib.resources.as_file( - importlib.resources.files("language_tool_python").joinpath("logging.toml"), - ) as config_path, - config_path.open("rb") as f, -): - log_config = toml_loads(f.read().decode("utf-8")) +with importlib.resources.as_file( + importlib.resources.files("language_tool_python").joinpath("logging.toml"), +) as config_path: + log_config = _load_pyproject_and_logconfig(config_path) dictConfig(log_config) +RULE_RE: re.Pattern[str] = re.compile(r"[\w-]+") + -def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: +class CliArgs(argparse.Namespace): + """Typed command-line arguments.""" + + files: list[str] + encoding: str | None + language: str | None + mother_tongue: str | None + disable: set[str] + enable: set[str] + enabled_only: bool + picky: bool + apply: bool + spell_check: bool + ignore_lines: str | None + remote_host: str | None + remote_port: str | None + verbose: bool + + +def parse_args(argv: Sequence[str] | None = None) -> CliArgs: """Parse command line arguments. :return: parsed arguments - :rtype: argparse.Namespace + :rtype: CliArgs """ parser = argparse.ArgumentParser( description=__doc__.strip() if __doc__ else None, @@ -73,7 +123,7 @@ def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: metavar="RULES", type=get_rules, action=RulesAction, - default=set(), + default=set[str](), help="list of rule IDs to be disabled", ) parser.add_argument( @@ -82,7 +132,7 @@ def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: metavar="RULES", type=get_rules, action=RulesAction, - default=set(), + default=set[str](), help="list of rule IDs to be enabled", ) parser.add_argument( @@ -126,7 +176,8 @@ def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: parser.add_argument("--remote-port", help="port of the remote LanguageTool server") parser.add_argument("--verbose", action="store_true", help="enable verbose output") - args = parser.parse_args(argv) + args = CliArgs() + parser.parse_args(argv, namespace=args) if args.enabled_only: if args.disable: @@ -165,15 +216,21 @@ def __call__( :param _parser: The ArgumentParser object which contains this action. :type _parser: argparse.ArgumentParser :param namespace: The namespace object that will be returned by parse_args(). - :type namespace: argparse.Namespace + :type namespace: CliArgs :param values: The argument values associated with the action. :type values: str | Sequence[object] | None :param _option_string: The option string that was used to invoke this action. :type _option_string: str | None """ - getattr(namespace, self.dest).update( - cast("set[str]", values), - ) + cli_args = cast("CliArgs", namespace) + rule_values = cast("set[str]", values) + if self.dest == "disable": + cli_args.disable.update(rule_values) + elif self.dest == "enable": + cli_args.enable.update(rule_values) + else: + err = f"unexpected rules destination: {self.dest}" + raise ValueError(err) def get_rules(rules: str) -> set[str]: @@ -184,7 +241,8 @@ def get_rules(rules: str) -> set[str]: :return: A set of rule IDs. :rtype: set[str] """ - return {rule.upper() for rule in re.findall(r"[\w\-]+", rules)} + rule_ids = cast("list[str]", RULE_RE.findall(rules)) + return {rule.upper() for rule in rule_ids} def get_text( @@ -223,11 +281,11 @@ def print_exception(exc: Exception, debug: bool) -> None: print(exc, file=sys.stderr) -def get_remote_server(args: argparse.Namespace) -> str | None: +def get_remote_server(args: CliArgs) -> str | None: """Build the remote server address from parsed arguments. :param args: Parsed command-line arguments. - :type args: argparse.Namespace + :type args: CliArgs :return: The remote server address in the format "host:port" or None if no remote host is specified. :rtype: str | None @@ -242,18 +300,19 @@ def get_remote_server(args: argparse.Namespace) -> str | None: return remote_server -def get_input_text(filename: str, args: argparse.Namespace) -> str: +def get_input_text(filename: str, args: CliArgs) -> str: """Read input text from a file or stdin. :param filename: The name of the file to read or "-" for stdin. :type filename: str :param args: Parsed command-line arguments. - :type args: argparse.Namespace + :type args: CliArgs :return: The input text as a string. :rtype: str """ if filename == "-": - raw = sys.stdin.read() + stdin = cast("TextIO", sys.stdin) + raw = stdin.read() if args.ignore_lines: return "".join( "\n" if re.match(args.ignore_lines, line) else line @@ -267,7 +326,7 @@ def get_input_text(filename: str, args: argparse.Namespace) -> str: def process_file( filename: str, - args: argparse.Namespace, + args: CliArgs, remote_server: str | None, ) -> int: """Check a single input file and return the resulting status. @@ -275,7 +334,7 @@ def process_file( :param filename: The name of the file to check or "-" for stdin. :type filename: str :param args: Parsed command-line arguments. - :type args: argparse.Namespace + :type args: CliArgs :param remote_server: The remote server address or None. :type remote_server: str | None :return: The resulting status. diff --git a/language_tool_python/_compat.py b/language_tool_python/_compat.py index 5baad9d..fae7ba9 100644 --- a/language_tool_python/_compat.py +++ b/language_tool_python/_compat.py @@ -5,11 +5,18 @@ - ``deprecated``: built-in ``warnings.deprecated`` on Python 3.13+, otherwise ``typing_extensions.deprecated``. - ``toml_loads``: built-in ``tomllib.loads`` on Python 3.11+, otherwise -``tomli.loads``. + ``tomli.loads``. +- ``TypeGuard``: built-in ``typing.TypeGuard`` on Python 3.10+, otherwise + ``typing_extensions.TypeGuard``. """ import sys +if sys.version_info >= (3, 10): + from typing import TypeGuard +else: + from typing_extensions import TypeGuard + if sys.version_info >= (3, 11): from tomllib import loads as toml_loads else: @@ -20,4 +27,4 @@ else: from typing_extensions import deprecated -__all__ = ["deprecated", "toml_loads"] +__all__ = ["TypeGuard", "deprecated", "toml_loads"] diff --git a/language_tool_python/api_types.py b/language_tool_python/api_types.py index be04427..65d244c 100644 --- a/language_tool_python/api_types.py +++ b/language_tool_python/api_types.py @@ -2,7 +2,10 @@ from __future__ import annotations -from typing import TypedDict +from typing import TYPE_CHECKING, TypedDict + +if TYPE_CHECKING: + from ._compat import TypeGuard __all__ = [ "Category", @@ -29,13 +32,13 @@ class LanguageInfo(TypedDict): name: str -def is_language_info(value: object) -> bool: # No TypeGuard because py3.9 +def is_language_info(value: object) -> TypeGuard[LanguageInfo]: """Verify that a value is a LanguageInfo. :param value: The value to check. :type value: object - :return: True if the value is a LanguageInfo, False otherwise. - :rtype: bool + :return: TypeGuard indicating whether the value is a LanguageInfo. + :rtype: TypeGuard[LanguageInfo] """ if not isinstance(value, dict): return False @@ -139,13 +142,13 @@ class CheckResponse(TypedDict): warnings: WarningInfo -def is_check_response(value: object) -> bool: # No TypeGuard because py3.9 +def is_check_response(value: object) -> TypeGuard[CheckResponse]: """Verify that a value is a CheckResponse. :param value: The value to check. :type value: object - :return: True if the value is a CheckResponse, False otherwise. - :rtype: bool + :return: TypeGuard indicating whether the value is a CheckResponse. + :rtype: TypeGuard[CheckResponse] """ if not isinstance(value, dict): return False diff --git a/language_tool_python/config_file.py b/language_tool_python/config_file.py index 68afd8e..d2c374f 100644 --- a/language_tool_python/config_file.py +++ b/language_tool_python/config_file.py @@ -5,19 +5,18 @@ import atexit import logging import tempfile -from collections.abc import Iterable, Mapping +from collections.abc import Callable, Iterable, Mapping from dataclasses import dataclass from os import PathLike from pathlib import Path -from typing import Callable, Generic, TypeVar, Union, cast +from typing import TypeVar, Union, cast from .exceptions import PathError from .utils import SupportsBool # Union here because | not supported by PathLike in py3.9 ConfigValue = Union[PathLike[str], SupportsBool, str, int, float, Iterable[str]] - -ConfigValueT_contra = TypeVar("ConfigValueT_contra", contravariant=True) +ConfigValueT = TypeVar("ConfigValueT", bound=ConfigValue) logger = logging.getLogger(__name__) @@ -47,7 +46,7 @@ def _reject_line_breaks(field_name: str, value: str) -> None: @dataclass(frozen=True) -class OptionSpec(Generic[ConfigValueT_contra]): +class OptionSpec: """Specification for a configuration option. This class defines the structure and behavior of a configuration option, including @@ -57,16 +56,29 @@ class OptionSpec(Generic[ConfigValueT_contra]): constant throughout the application lifecycle. """ - py_types: type | tuple[type, ...] + py_types: type[object] | tuple[type[object], ...] """The Python type(s) that this option accepts.""" - encoder: Callable[[ConfigValueT_contra], str] + encoder: Callable[[ConfigValue], str] """A callable that converts the option value to its string representation.""" - validator: Callable[[ConfigValueT_contra], None] | None = None + validator: Callable[[ConfigValue], None] | None = None """An optional validator function for the option value.""" +def _option_spec( + py_types: type[object] | tuple[type[object], ...], + encoder: Callable[[ConfigValueT], str], + validator: Callable[[ConfigValueT], None] | None = None, +) -> OptionSpec: + """Create a schema entry for a runtime-checked configuration option.""" + return OptionSpec( + py_types=py_types, + encoder=cast("Callable[[ConfigValue], str]", encoder), + validator=cast("Callable[[ConfigValue], None] | None", validator), + ) + + def _bool_encoder(v: SupportsBool) -> str: """Encode a value as a lowercase boolean string. @@ -141,37 +153,34 @@ def _path_validator(v: PathLike[str] | str) -> None: raise PathError(err) -CONFIG_SCHEMA = cast( - "dict[str, OptionSpec[ConfigValue]]", - { - "maxTextLength": OptionSpec(int, _int_encoder), - "maxTextHardLength": OptionSpec(int, _int_encoder), - "maxCheckTimeMillis": OptionSpec(int, _int_encoder), - "maxErrorsPerWordRate": OptionSpec((int, float), _number_encoder), - "maxSpellingSuggestions": OptionSpec(int, _int_encoder), - "maxCheckThreads": OptionSpec(int, _int_encoder), - "cacheSize": OptionSpec(int, _int_encoder), - "cacheTTLSeconds": OptionSpec(int, _int_encoder), - "requestLimit": OptionSpec(int, _int_encoder), - "requestLimitInBytes": OptionSpec(int, _int_encoder), - "timeoutRequestLimit": OptionSpec(int, _int_encoder), - "requestLimitPeriodInSeconds": OptionSpec(int, _int_encoder), - "languageModel": OptionSpec((str, Path), _path_encoder, _path_validator), - "fasttextModel": OptionSpec((str, Path), _path_encoder, _path_validator), - "fasttextBinary": OptionSpec((str, Path), _path_encoder, _path_validator), - "maxWorkQueueSize": OptionSpec(int, _int_encoder), - "rulesFile": OptionSpec((str, Path), _path_encoder, _path_validator), - "blockedReferrers": OptionSpec((str, list, tuple, set), _comma_list_encoder), - "premiumOnly": OptionSpec((bool, int), _bool_encoder), - "disabledRuleIds": OptionSpec((str, list, tuple, set), _comma_list_encoder), - "pipelineCaching": OptionSpec((bool, int), _bool_encoder), - "maxPipelinePoolSize": OptionSpec(int, _int_encoder), - "pipelineExpireTimeInSeconds": OptionSpec(int, _int_encoder), - "pipelinePrewarming": OptionSpec((bool, int), _bool_encoder), - "trustXForwardForHeader": OptionSpec((bool, int), _bool_encoder), - "suggestionsEnabled": OptionSpec((bool, int), _bool_encoder), - }, -) +CONFIG_SCHEMA: dict[str, OptionSpec] = { + "maxTextLength": _option_spec(int, _int_encoder), + "maxTextHardLength": _option_spec(int, _int_encoder), + "maxCheckTimeMillis": _option_spec(int, _int_encoder), + "maxErrorsPerWordRate": _option_spec((int, float), _number_encoder), + "maxSpellingSuggestions": _option_spec(int, _int_encoder), + "maxCheckThreads": _option_spec(int, _int_encoder), + "cacheSize": _option_spec(int, _int_encoder), + "cacheTTLSeconds": _option_spec(int, _int_encoder), + "requestLimit": _option_spec(int, _int_encoder), + "requestLimitInBytes": _option_spec(int, _int_encoder), + "timeoutRequestLimit": _option_spec(int, _int_encoder), + "requestLimitPeriodInSeconds": _option_spec(int, _int_encoder), + "languageModel": _option_spec((str, Path), _path_encoder, _path_validator), + "fasttextModel": _option_spec((str, Path), _path_encoder, _path_validator), + "fasttextBinary": _option_spec((str, Path), _path_encoder, _path_validator), + "maxWorkQueueSize": _option_spec(int, _int_encoder), + "rulesFile": _option_spec((str, Path), _path_encoder, _path_validator), + "blockedReferrers": _option_spec((str, list, tuple, set), _comma_list_encoder), + "premiumOnly": _option_spec((bool, int), _bool_encoder), + "disabledRuleIds": _option_spec((str, list, tuple, set), _comma_list_encoder), + "pipelineCaching": _option_spec((bool, int), _bool_encoder), + "maxPipelinePoolSize": _option_spec(int, _int_encoder), + "pipelineExpireTimeInSeconds": _option_spec(int, _int_encoder), + "pipelinePrewarming": _option_spec((bool, int), _bool_encoder), + "trustXForwardForHeader": _option_spec((bool, int), _bool_encoder), + "suggestionsEnabled": _option_spec((bool, int), _bool_encoder), +} def _is_lang_key(key: str) -> bool: diff --git a/language_tool_python/download_lt.py b/language_tool_python/download_lt.py index 5defef8..717f02e 100755 --- a/language_tool_python/download_lt.py +++ b/language_tool_python/download_lt.py @@ -16,7 +16,7 @@ from functools import total_ordering from pathlib import Path from shutil import which -from typing import IO, TYPE_CHECKING +from typing import IO, TYPE_CHECKING, cast from urllib.parse import urljoin from warnings import warn @@ -35,6 +35,7 @@ ) if TYPE_CHECKING: + from collections.abc import Mapping from types import NotImplementedType from .config_file import LanguageToolConfig @@ -76,13 +77,50 @@ DOWNLOAD_CHUNK_BYTES = 1024 * 1024 _SAFE_ZIP_EXTRACTOR = SafeZipExtractor() + +def _loads_manifest(raw_manifest: str) -> object: + """Load the integrity manifest from a raw TOML string. + + :param raw_manifest: The raw TOML string containing the integrity manifest. + :type raw_manifest: str + :return: The parsed manifest as a Python object. + :rtype: object + """ + return cast("object", toml_loads(raw_manifest)) + + +def _load_expected_download_sha256(raw_manifest: str) -> dict[str, str]: + """Load and validate the bundled download checksum manifest. + + :param raw_manifest: The raw TOML string containing the integrity manifest. + :type raw_manifest: str + :return: A dictionary mapping version names to their expected SHA-256 hashes. + :rtype: dict[str, str] + """ + parsed = _loads_manifest(raw_manifest) + if not isinstance(parsed, dict): + err = "Invalid integrity manifest: expected a TOML table." + raise PathError(err) + + manifest = cast("Mapping[object, object]", parsed) + expected_hashes: dict[str, str] = {} + for version_name, checksum in manifest.items(): + if not isinstance(version_name, str) or not isinstance(checksum, str): + err = "Invalid integrity manifest: expected string keys and values." + raise PathError(err) + expected_hashes[version_name] = checksum + return expected_hashes + + with ( importlib.resources.as_file( importlib.resources.files("language_tool_python").joinpath("integrity.toml"), ) as hashes_path, hashes_path.open("rb") as f, ): - EXPECTED_DOWNLOAD_SHA256 = toml_loads(f.read().decode("utf-8")) + EXPECTED_DOWNLOAD_SHA256 = _load_expected_download_sha256( + f.read().decode("utf-8"), + ) JAVA_VERSION_REGEX = re.compile( r'^(?:java|openjdk) version "(?P\d+)(|\.(?P\d+)\.[^"]+)"', @@ -310,7 +348,11 @@ def http_get( :raises PathError: If the download fails or checksum validation fails. """ version_match = re.search(r"LanguageTool-(.+)\.zip", url) - version_name = version_match.group(1) if version_match else LTP_DOWNLOAD_VERSION + if version_match: + version_start, version_end = version_match.span(1) + version_name = url[version_start:version_end] + else: + version_name = LTP_DOWNLOAD_VERSION # Normalize snapshot-style version names (e.g. "6.8-SNAPSHOT", "latest-snapshot") if version_name.lower().endswith("-snapshot"): diff --git a/language_tool_python/language_tag.py b/language_tool_python/language_tag.py index 6130db9..8bac161 100644 --- a/language_tool_python/language_tag.py +++ b/language_tool_python/language_tag.py @@ -131,7 +131,9 @@ def _normalize(self, tag: str) -> str: err = "tag does not match pattern" raise AttributeError(err) from e logger.debug("Regex match groups: %s", match.groups()) - return languages[match.group(1).lower()] + language_start, language_end = match.span(1) + language = tag[language_start:language_end].lower() + return languages[language] except (KeyError, AttributeError) as e: err = f"unsupported language: {tag!r}" raise ValueError(err) from e diff --git a/language_tool_python/match.py b/language_tool_python/match.py index abd239f..d95cd83 100644 --- a/language_tool_python/match.py +++ b/language_tool_python/match.py @@ -7,7 +7,7 @@ from collections import OrderedDict from collections import OrderedDict as OrderedDictType from functools import total_ordering -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union from ._compat import deprecated from .utils import SupportsFloat, SupportsInt @@ -22,6 +22,7 @@ UTF8_4_BYTE_LENGTH = 4 CONTEXT_PREFIX_SUFFIX_LENGTH = 3 CONTEXT_WITH_ADDITIONS_MIN_LENGTH = 6 +MatchValue = Union[str, int, list[str]] # | operator not fully supported by py3.9 def get_match_ordered_dict() -> OrderedDictType[str, type]: @@ -205,6 +206,9 @@ class Match: # noqa: PLW1641 # Doesn't implement hash because it's mutable rule_issue_type: str """The issue type of the rule that was violated.""" + sentence: str + """The sentence that contains the rule violation.""" + def __init__(self, attrib: CheckMatch, text: str) -> None: """Initialize a Match object with the given attributes. @@ -244,6 +248,21 @@ def __init__(self, attrib: CheckMatch, text: str) -> None: 1 for pos in Match.FOUR_BYTES_POSITIONS if pos < self.offset ) + def _ordered_items(self) -> list[tuple[str, MatchValue]]: + """Return public match attributes in the documented order.""" + return [ + ("rule_id", self.rule_id), + ("message", self.message), + ("replacements", self.replacements), + ("offset_in_context", self.offset_in_context), + ("context", self.context), + ("offset", self.offset), + ("error_length", self.error_length), + ("category", self.category), + ("rule_issue_type", self.rule_issue_type), + ("sentence", self.sentence), + ] + def __repr__(self) -> str: """Return a string representation of the object. @@ -267,14 +286,10 @@ def _ordered_dict_repr() -> str: dictionary format. :rtype: str """ - slots = list(get_match_ordered_dict()) - slots += list(set(self.__dict__).difference(slots)) - attrs = [ - slot - for slot in slots - if slot in self.__dict__ and not slot.startswith("_") - ] - return f"{{{', '.join([f'{attr!r}: {getattr(self, attr)!r}' for attr in attrs])}}}" # noqa: E501 # Difficult to break this line in python 3.9 + items = ", ".join( + f"{attr!r}: {value!r}" for attr, value in self._ordered_items() + ) + return f"{{{items}}}" return f"{self.__class__.__name__}({_ordered_dict_repr()})" @@ -376,7 +391,7 @@ def __lt__(self, other: object) -> bool: return NotImplemented return list(self) < list(other) - def __iter__(self) -> Iterator[str | int | list[str]]: + def __iter__(self) -> Iterator[MatchValue]: """Return an iterator over the attributes of the match object. This method allows the match object to be iterated over, yielding the values of @@ -385,9 +400,9 @@ def __iter__(self) -> Iterator[str | int | list[str]]: :return: An iterator over the attribute values of the match object. :rtype: Iterator[str | int | list[str]] """ - return iter(getattr(self, attr) for attr in get_match_ordered_dict()) + return iter(value for _, value in self._ordered_items()) - def __setattr__(self, key: str, value: str | int | list[str]) -> None: + def __setattr__(self, key: str, value: MatchValue) -> None: """Set an attribute on the instance. This method overrides the default behavior of setting an attribute. It attempts diff --git a/language_tool_python/server.py b/language_tool_python/server.py index 9e18471..84da34a 100644 --- a/language_tool_python/server.py +++ b/language_tool_python/server.py @@ -7,15 +7,15 @@ import http.client import json import logging -import os import random import re import socket import subprocess +import sys import time import urllib.parse import warnings -from typing import TYPE_CHECKING, ClassVar, Literal, cast +from typing import TYPE_CHECKING, ClassVar, Literal import psutil import requests @@ -43,16 +43,16 @@ parse_url, ) -startupinfo: object = None -if os.name == "nt": - from .utils import startupinfo +startupinfo: object | None = None +if sys.platform == "win32": + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW if TYPE_CHECKING: from collections.abc import Mapping from pathlib import Path from types import TracebackType - from .api_types import CheckResponse, LanguageInfo from .config_file import ConfigValue logger = logging.getLogger(__name__) @@ -82,6 +82,19 @@ def _kill_processes(processes: list[subprocess.Popen[str]]) -> None: p.wait(timeout=5) +def _match_offset(match: Match) -> int: + """Return a match offset for sorting.""" + return match.offset + + +def _decode_response_content(response: requests.Response) -> str: + """Decode response content from bytes to text.""" + content: object = response.content + if isinstance(content, bytes): + return content.decode() + return str(content) + + class LanguageTool: """Interact with the LanguageTool server for text checking and correction. @@ -651,17 +664,13 @@ def check(self, text: str) -> list[Match]: """ url = urllib.parse.urljoin(self._url, "check") logger.debug("Sending text to LanguageTool server at %s", url) - raw_response = self._query_server(url, self._create_params(text), method="post") - if raw_response is None: + response = self._query_server(url, self._create_params(text), method="post") + if response is None: err = "No response received from the LanguageTool server." raise ServerError(err) - if not is_check_response(raw_response): - err = ( - f"Invalid response received from the " - f"LanguageTool server: {raw_response}" - ) + if not is_check_response(response): + err = f"Invalid response received from the LanguageTool server: {response}" raise ServerError(err) - response = cast("CheckResponse", raw_response) matches = response["matches"] return [Match(match, text) for match in matches] @@ -703,7 +712,7 @@ def check_matching_regions( all_matches.extend(region_matches) - return sorted(all_matches, key=lambda m: m.offset) + return sorted(all_matches, key=_match_offset) def _create_params(self, text: str) -> dict[str, str]: """Create a dictionary of parameters for the language tool server request. @@ -902,14 +911,13 @@ def _get_languages(self) -> set[str]: ) raise ServerError(err) if isinstance(raw_languages_response, list): - for raw_lang in raw_languages_response: - if not is_language_info(raw_lang): + for lang in raw_languages_response: + if not is_language_info(lang): err = ( "Unexpected response format when fetching languages from the " "LanguageTool server." ) raise ServerError(err) - lang = cast("LanguageInfo", raw_lang) languages.add(lang["code"]) languages.add(lang["longCode"]) else: @@ -998,7 +1006,9 @@ def _query_server( "LanguageTool API. Please try again later." ) raise RateLimitError(err) from e - raise LanguageToolError(response.content.decode()) from e + raise LanguageToolError( + _decode_response_content(response), + ) from e else: return data except (OSError, http.client.HTTPException) as e: # noqa: PERF203 # it is intentional to catch exceptions in a loop, to retry the request in case of transient errors diff --git a/language_tool_python/utils.py b/language_tool_python/utils.py index 38e41b5..4dae297 100644 --- a/language_tool_python/utils.py +++ b/language_tool_python/utils.py @@ -7,7 +7,6 @@ import logging import math import os -import subprocess import urllib.parse from enum import Enum from pathlib import Path @@ -37,15 +36,6 @@ # Directory containing the LanguageTool jar file: LTP_JAR_DIR_PATH_ENV_VAR = "LTP_JAR_DIR_PATH" -if os.name == "nt": - # Gets STARTUPINFO dynamically to avoid issues on non-Windows platforms - startupinfo_cls = getattr(subprocess, "STARTUPINFO", None) - if startupinfo_cls is not None: - si = startupinfo_cls() - # STARTF_USESHOWWINDOW also dynamically retrieved - si.dwFlags |= getattr(subprocess, "STARTF_USESHOWWINDOW", 0) - startupinfo = si - def parse_url(url_str: str) -> str: """Parse the given URL string and ensure it has a scheme. @@ -376,13 +366,15 @@ def get_jar_info() -> tuple[Path, Path]: # Use the env var to the jar directory if it is defined # otherwise look in the download directory - jar_dir_name = os.environ.get( - LTP_JAR_DIR_PATH_ENV_VAR, - get_language_tool_directory(), + configured_jar_dir = os.environ.get(LTP_JAR_DIR_PATH_ENV_VAR) + jar_dir_path = ( + Path(configured_jar_dir) + if configured_jar_dir is not None + else get_language_tool_directory() ) jar_path = None for jar_name in JAR_NAMES: - for jar_path in Path(jar_dir_name).glob(jar_name): + for jar_path in jar_dir_path.glob(jar_name): if jar_path.is_file(): logger.debug("Found LanguageTool JAR: %s", jar_path) break @@ -391,7 +383,7 @@ def get_jar_info() -> tuple[Path, Path]: if jar_path: break else: - err = f"can't find languagetool-standalone in {jar_dir_name!r}" + err = f"can't find languagetool-standalone in {jar_dir_path!r}" raise PathError(err) return java_path, jar_path diff --git a/pyproject.toml b/pyproject.toml index d263f86..e54ce22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,9 @@ dependencies = [ "packaging", "psutil", "tomli; python_version < '3.11'", # only needed for py < 3.11 because tomllib added in 3.11 is used in the codebase, needs a fallback - "typing_extensions; python_version < '3.13'", # only needed for py < 3.13 because warnings.deprecated added in 3.13 is used in the codebase + # needed for py < 3.13 because warnings.deprecated added in 3.13 is used in the codebase + # also needed for py < 3.10 because of the usage of typing.TypeGuard in the codebase, which is backported to typing_extensions for py < 3.10 + "typing_extensions; python_version < '3.13'", ] [project.urls] @@ -149,7 +151,7 @@ ignore = [ [tool.mypy] files = ["language_tool_python", "tests"] -# disallow_any_expr = true +disallow_any_expr = true disallow_any_generics = true disallow_any_unimported = true disallow_subclassing_any = true diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index 4deebd8..4d312f2 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -20,7 +20,7 @@ def old_function() -> str: with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - result = old_function() + result: str = old_function() assert len(w) == 1 assert issubclass(w[0].category, DeprecationWarning) @@ -37,7 +37,7 @@ def old_function() -> int: with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - result = old_function() + result: int = old_function() assert len(w) == 1 assert issubclass(w[0].category, UserWarning) @@ -92,7 +92,9 @@ def complex_function( with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - result = complex_function(1, 2, 3, 4, c=5, d=6, e=7) + result: tuple[int, int, tuple[int, ...], int | None, dict[str, int]] = ( + complex_function(1, 2, 3, 4, c=5, d=6, e=7) + ) assert len(w) == 1 assert result == (1, 2, (3, 4), 5, {"d": 6, "e": 7}) diff --git a/tests/test_download.py b/tests/test_download.py index 04b283a..b73bc13 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -12,7 +12,7 @@ from contextlib import contextmanager from datetime import datetime, timedelta, timezone from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest @@ -50,6 +50,17 @@ def iter_content(self, chunk_size: int) -> Iterator[bytes]: yield self.payload[index : index + chunk_size] +class FixedDatetime: + """Datetime replacement returning a configurable UTC datetime.""" + + current_datetime = datetime(2024, 5, 14, tzinfo=timezone.utc) + + @classmethod + def now(cls, _tz: timezone) -> datetime: + """Return the configured test datetime.""" + return cls.current_datetime + + def make_zip_payload(files: dict[str, bytes]) -> bytes: """Create an in-memory ZIP payload for download tests.""" buffer = io.BytesIO() @@ -59,6 +70,10 @@ def make_zip_payload(files: dict[str, bytes]) -> bytes: return buffer.getvalue() +def skip_java_compatibility_check(_language_tool_version: str) -> None: + """Skip Java compatibility checks in download-only tests.""" + + @contextmanager def workspace_temp_dir() -> Iterator[Path]: """Create a temporary directory inside the repository workspace.""" @@ -121,8 +136,7 @@ def test_http_get_403_forbidden() -> None: :raises AssertionError: If PathError is not raised for a 403 status code. """ - mock_response = MagicMock() - mock_response.status_code = 403 + mock_response = MockDownloadResponse(b"", status_code=403) mock_response.headers = {} out_file = io.BytesIO() @@ -148,8 +162,7 @@ def test_http_get_other_error_codes() -> None: error_codes = [500, 502, 503, 504] for error_code in error_codes: - mock_response = MagicMock() - mock_response.status_code = error_code + mock_response = MockDownloadResponse(b"", status_code=error_code) mock_response.headers = {} out_file = io.BytesIO() @@ -194,12 +207,9 @@ def test_max_download_bytes_uses_env_override( env.setenv( LTP_MAX_DOWNLOAD_BYTES_ENV_VAR, str(EXPECTED_DOWNLOAD_BYTES_OVERRIDE) ) - reloaded_download_lt = importlib.reload(download_lt) + importlib.reload(download_lt) - assert ( - reloaded_download_lt.MAX_DOWNLOAD_BYTES - == EXPECTED_DOWNLOAD_BYTES_OVERRIDE - ) + assert download_lt.MAX_DOWNLOAD_BYTES == EXPECTED_DOWNLOAD_BYTES_OVERRIDE finally: importlib.reload(download_lt) @@ -274,9 +284,9 @@ def test_latest_snapshot_uses_latest_download_url_and_current_date( "https://example.test/snapshots/", ) - with patch("language_tool_python.download_lt.datetime") as datetime_mock: - datetime_mock.now.return_value.strftime.return_value = "20240514" - local_language_tool = LocalLanguageTool.from_version_name("latest") + FixedDatetime.current_datetime = datetime(2024, 5, 14, tzinfo=timezone.utc) + monkeypatch.setattr(download_lt, "datetime", FixedDatetime) + local_language_tool = LocalLanguageTool.from_version_name("latest") assert local_language_tool.version_name == "20240514" assert ( @@ -475,7 +485,11 @@ def test_snapshot_download_renames_archive_root_to_requested_date( {"LanguageTool-6.9-SNAPSHOT/languagetool-server.jar": b"jar"}, ) local_language_tool = LocalLanguageTool.from_version_name(requested_snapshot) - monkeypatch.setattr(download_lt, "confirm_java_compatibility", lambda _: None) + monkeypatch.setattr( + download_lt, + "confirm_java_compatibility", + skip_java_compatibility_check, + ) with ( workspace_temp_dir() as temp_dir, @@ -510,10 +524,14 @@ def test_latest_snapshot_download_renames_archive_root_to_current_date( payload = make_zip_payload( {"LanguageTool-6.9-SNAPSHOT/languagetool-server.jar": b"jar"}, ) - with patch("language_tool_python.download_lt.datetime") as datetime_mock: - datetime_mock.now.return_value.strftime.return_value = current_snapshot_date - local_language_tool = LocalLanguageTool.from_version_name("latest") - monkeypatch.setattr(download_lt, "confirm_java_compatibility", lambda _: None) + FixedDatetime.current_datetime = datetime(2024, 5, 14, tzinfo=timezone.utc) + monkeypatch.setattr(download_lt, "datetime", FixedDatetime) + local_language_tool = LocalLanguageTool.from_version_name("latest") + monkeypatch.setattr( + download_lt, + "confirm_java_compatibility", + skip_java_compatibility_check, + ) with ( workspace_temp_dir() as temp_dir, diff --git a/tests/test_match.py b/tests/test_match.py index 448ccf1..a81713d 100644 --- a/tests/test_match.py +++ b/tests/test_match.py @@ -1,6 +1,6 @@ """Tests for the Match functionality of LanguageTool.""" -from typing import cast +from typing import TypedDict import language_tool_python @@ -8,6 +8,21 @@ EXPECTED_CORRECTED_MATCH_COUNT = 4 +class ExpectedMatch(TypedDict): + """Expected values for a LanguageTool match.""" + + rule_id: str + message: str + replacements: list[str] + offset_in_context: int + context: str + offset: int + error_length: int + category: str + rule_issue_type: str + sentence: str + + def test_langtool_load() -> None: """Test the basic functionality of LanguageTool and Match object attributes. @@ -22,7 +37,7 @@ def test_langtool_load() -> None: with language_tool_python.LanguageTool("en-US") as tool: matches = tool.check("ain't nothin but a thang") - expected_matches: list[dict[str, str | list[str] | int]] = [ + expected_matches: list[ExpectedMatch] = [ { "rule_id": "UPPERCASE_SENTENCE_START", "message": "This sentence does not start with an uppercase letter.", @@ -80,29 +95,20 @@ def test_langtool_load() -> None: assert len(matches) == len(expected_matches) for match_i, match in enumerate(matches): assert isinstance(match, language_tool_python.Match) - for key in [ - "rule_id", - "message", - "offset_in_context", - "context", - "offset", - "error_length", - "category", - "rule_issue_type", - "sentence", - ]: - assert expected_matches[match_i][key] == getattr(match, key) + expected_match = expected_matches[match_i] + assert expected_match["rule_id"] == match.rule_id + assert expected_match["message"] == match.message + assert expected_match["offset_in_context"] == match.offset_in_context + assert expected_match["context"] == match.context + assert expected_match["offset"] == match.offset + assert expected_match["error_length"] == match.error_length + assert expected_match["category"] == match.category + assert expected_match["rule_issue_type"] == match.rule_issue_type + assert expected_match["sentence"] == match.sentence # For replacements we allow some flexibility in the order # of the suggestions depending on the version of LT. - for key in [ - "replacements", - ]: - expected_replacements = cast( - "list[str]", - expected_matches[match_i][key], - ) - assert set(expected_replacements) == set(getattr(match, key)) + assert set(expected_match["replacements"]) == set(match.replacements) def test_match() -> None: diff --git a/tests/test_safe_zip.py b/tests/test_safe_zip.py index 7df4699..a003ccb 100644 --- a/tests/test_safe_zip.py +++ b/tests/test_safe_zip.py @@ -128,8 +128,8 @@ def test_safe_zip_limits_use_env_overrides( str(EXPECTED_MAX_TOTAL_COMPRESSION_RATIO), ) - reloaded_safe_zip = importlib.reload(safe_zip) - limits = reloaded_safe_zip.SafeZipLimits() + importlib.reload(safe_zip) + limits = safe_zip.SafeZipLimits() assert limits.max_archive_bytes == EXPECTED_MAX_ARCHIVE_BYTES assert limits.max_extracted_bytes == EXPECTED_MAX_EXTRACTED_BYTES