diff --git a/pyproject.toml b/pyproject.toml index e037674..5edbe6e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ optional-dependencies.testing = [ ] optional-dependencies.typing = [ "mypy>=1", + "typing-extensions; python_version<'3.10'", ] urls.Changelog = "https://github.com/edgarrmondragon/sqlean-driver/blob/main/CHANGELOG.md" urls.Documentation = "https://github.com/edgarrmondragon/sqlean-driver#readme" @@ -244,6 +245,7 @@ lint.ignore = [ "ANN101", # missing-type-self "ANN102", # missing-type-cls "COM812", # missing-trailing-comma + "FIX002", # line-contains-todo "ISC001", # single-line-implicit-string-concatenation ] lint.per-file-ignores."tests/**/*" = [ @@ -252,6 +254,12 @@ lint.per-file-ignores."tests/**/*" = [ "S101", "TID252", ] +lint.unfixable = [ + "ERA", + "F401", +] +lint.flake8-annotations.allow-star-arg-any = true + lint.flake8-import-conventions.banned-from = [ "typing", ] @@ -263,6 +271,9 @@ lint.isort.known-first-party = [ lint.isort.required-imports = [ "from __future__ import annotations", ] +lint.pep8-naming.ignore-names = [ + "visit_*", +] # Tests can use magic values, assertions, and relative imports lint.pydocstyle.convention = "google" lint.preview = true diff --git a/src/sqlean_driver/__init__.py b/src/sqlean_driver/__init__.py index 07fc722..87697f6 100644 --- a/src/sqlean_driver/__init__.py +++ b/src/sqlean_driver/__init__.py @@ -3,24 +3,92 @@ from __future__ import annotations import typing as t +import uuid from importlib.metadata import version +import sqlalchemy.types as sqltypes +from sqlalchemy.dialects.sqlite.base import SQLiteTypeCompiler from sqlalchemy.dialects.sqlite.pysqlite import SQLiteDialect_pysqlite +from sqlalchemy.sql.functions import GenericFunction + +from sqlean_driver.custom_types import UUID if t.TYPE_CHECKING: from types import ModuleType from sqlalchemy.engine.url import URL + from sqlalchemy.sql.type_api import TypeEngine + from sqlean_driver.custom_types import IPAddress, IPNetwork __version__ = version(__package__) +class SQLeanTypeCompiler(SQLiteTypeCompiler): + """A type compiler for SQLite that uses sqlean.py as the DBAPI.""" + + def visit_INET( # noqa: PLR6301 + self, + type_: TypeEngine[IPAddress], # noqa: ARG002 + **kw: t.Any, # noqa: ARG002 + ) -> str: + """Visit an INET node.""" + return "INET" + + def visit_CIDR(self, type_: TypeEngine[IPNetwork], **kw: t.Any) -> str: # noqa: ARG002, PLR6301 + """Visit a CIDR nodes.""" + return "CIDR" + + def visit_UUID(self, type_: TypeEngine[uuid.UUID], **kw: t.Any) -> str: # noqa: ARG002, PLR6301 + """Visit a UUID node.""" + return "UUID" + + +class uuid4(GenericFunction[uuid.UUID]): # noqa: N801 + """Generates a version 4 (random) UUID as a string. + + Aliased as gen_random_uuid() for PostgreSQL compatibility. + """ + + name = "uuid4" + type = UUID() + inherit_cache = True + + +class gen_random_uuid(uuid4): # noqa: N801 + """Generates a version 4 (random) UUID as a string.""" + + name = "gen_random_uuid" + + +class uuid_str(GenericFunction[uuid.UUID]): # noqa: N801 + """Converts a UUID `X` into a well-formed UUID string. + + `X` can be either a string or a blob. + """ + + name = "uuid_str" + type = UUID() + inherit_cache = True + + +class uuid_blob(GenericFunction[bytes]): # noqa: N801 + """Converts a UUID `X` into a well-formed UUID string. + + `X` can be either a string or a blob. + """ + + name = "uuid_blob" + type = sqltypes.BLOB() + inherit_cache = True + + class SQLeanDialect(SQLiteDialect_pysqlite): """A dialect for SQLite that uses sqlean.py as the DBAPI.""" driver = "sqlean" supports_statement_cache = True + type_compiler = SQLeanTypeCompiler @classmethod def dbapi(cls) -> ModuleType: # type: ignore[override] diff --git a/src/sqlean_driver/custom_types.py b/src/sqlean_driver/custom_types.py new file mode 100644 index 0000000..830d25f --- /dev/null +++ b/src/sqlean_driver/custom_types.py @@ -0,0 +1,186 @@ +"""Custom SQLAlchemy types.""" + +from __future__ import annotations + +import ipaddress +import typing as t +import uuid + +import sqlalchemy.types as sqltypes +from sqlalchemy.sql.functions import GenericFunction + +if t.TYPE_CHECKING: + import sys + + if sys.version_info < (3, 10): + from typing_extensions import TypeAlias + else: + from typing import TypeAlias # noqa: ICN003 + + from sqlalchemy.engine.interfaces import Dialect + from sqlalchemy.sql.type_api import _BindProcessorType, _ResultProcessorType + + +IPAddress: TypeAlias = t.Union[ipaddress.IPv4Address, ipaddress.IPv6Address] +IPNetwork: TypeAlias = t.Union[ipaddress.IPv4Network, ipaddress.IPv6Network] + + +def none_or_str(value: t.Any | None) -> str | None: # noqa: ANN401 + """Return the value or None.""" + return str(value) if value is not None else None + + +def none_or_ip_interface( + value: t.Any | None, # noqa: ANN401 +) -> ipaddress.IPv4Interface | ipaddress.IPv6Interface | None: + """Return the value or None.""" + return ipaddress.ip_interface(value) if value is not None else None + + +def none_or_ip_network( + value: t.Any | None, # noqa: ANN401 +) -> ipaddress.IPv4Network | ipaddress.IPv6Network | None: + """Return the value or None.""" + return ipaddress.ip_network(value) if value is not None else None + + +def none_or_uuid( + value: t.Any | None, # noqa: ANN401 +) -> uuid.UUID | None: + """Return the value or None.""" + return uuid.UUID(value) if value is not None else None + + +class INET(sqltypes.TypeEngine[IPAddress]): + """An INET type.""" + + __visit_name__ = "INET" + + def bind_processor( # noqa: PLR6301 + self, + _dialect: Dialect, + ) -> _BindProcessorType[IPAddress] | None: + """Return a bind processor.""" + return none_or_str + + def result_processor( # noqa: PLR6301 + self, + _dialect: Dialect, + _coltype: object, + ) -> _ResultProcessorType[IPAddress] | None: + """Return a result processor.""" + return none_or_ip_interface + + # TODO(edgarrmondragon): Add missing type parameters: + # > sqltypes.Indexable.Comparator[IPAddress] + # > sqltypes.Concatenable.Comparator[IPAddress] + # https://github.com/edgarrmondragon/sqlean-driver/issues/37 + class Comparator( + sqltypes.Indexable.Comparator, # type: ignore[type-arg] + sqltypes.Concatenable.Comparator, # type: ignore[type-arg] + ): + """Comparator for the INET type.""" + + def ipfamily(self) -> _IPAddrIPFamilyFunction: + """Return the IP family.""" + return _IPAddrIPFamilyFunction(self.expr) + + def iphost(self) -> _IPAddrIPHostFunction: + """Return the IP host.""" + return _IPAddrIPHostFunction(self.expr) + + def ipmasklen(self) -> _IPAddrIPMaskLenFunction: + """Return the IP mask length.""" + return _IPAddrIPMaskLenFunction(self.expr) + + def ipnetwork(self) -> _IPAddrIPNetworkFunction: + """Return the IP network.""" + return _IPAddrIPNetworkFunction(self.expr) + + def ipcontains(self, other: IPAddress | str) -> _IPAddrIPContainsFunction: + """Return whether the IP address contains another IP address.""" + return _IPAddrIPContainsFunction(self.expr, other) + + comparator_factory = Comparator + + +class CIDR(sqltypes.TypeEngine[IPNetwork]): + """A CIDR type.""" + + __visit_name__ = "CIDR" + + def bind_processor( # noqa: PLR6301 + self, + _dialect: Dialect, + ) -> _BindProcessorType[IPNetwork] | None: + """Return a bind processor.""" + return none_or_str + + def result_processor( # noqa: PLR6301 + self, + _dialect: Dialect, + _coltype: object, + ) -> _ResultProcessorType[IPNetwork] | None: + """Return a result processor.""" + return none_or_ip_network + + +class UUID(sqltypes.TypeEngine[uuid.UUID]): + """A UUID type.""" + + __visit_name__ = "UUID" + + def bind_processor( # noqa: PLR6301 + self, + _dialect: Dialect, + ) -> _BindProcessorType[uuid.UUID] | None: + """Return a bind processor.""" + return none_or_str + + def result_processor( # noqa: PLR6301 + self, + _dialect: Dialect, + _coltype: object, + ) -> _ResultProcessorType[uuid.UUID] | None: + """Return a result processor.""" + return none_or_uuid + + +class _IPAddrIPFamilyFunction(GenericFunction[int]): + """Returns the family of a specified IP address.""" + + name = "ipfamily" + type = sqltypes.Integer() + inherit_cache = True + + +class _IPAddrIPHostFunction(GenericFunction[str]): + """Returns the host part of an IP address.""" + + name = "iphost" + type = sqltypes.String() + inherit_cache = True + + +class _IPAddrIPMaskLenFunction(GenericFunction[int]): + """Returns the prefix length of an IP address.""" + + name = "ipmasklen" + type = sqltypes.Integer() + inherit_cache = True + + +class _IPAddrIPNetworkFunction(GenericFunction[IPNetwork]): + """Returns the network part of an IP address.""" + + name = "ipnetwork" + type = CIDR() + inherit_cache = True + + +class _IPAddrIPContainsFunction(GenericFunction[bool]): + """Returns whether an IP address contains another IP address.""" + + name = "ipcontains" + type = sqltypes.Boolean() + inherit_cache = True diff --git a/tests/test_types.py b/tests/test_types.py new file mode 100644 index 0000000..360d43c --- /dev/null +++ b/tests/test_types.py @@ -0,0 +1,183 @@ +"""Test the custom types.""" + +from __future__ import annotations + +import sys +import typing as t +import uuid +from ipaddress import IPv4Interface, IPv4Network, IPv6Interface + +import pytest +from sqlalchemy import ( + Column, + Integer, + MetaData, + Table, + create_engine, + func, + select, +) + +from sqlean_driver.custom_types import CIDR, INET, UUID + +if t.TYPE_CHECKING: + from sqlalchemy import Select + +metadata = MetaData() +table = Table( + "test_table", + metadata, + Column("id", Integer, primary_key=True), + Column("ip", INET), + Column("cidr", CIDR), + Column("uuid_col", UUID), +) + + +@pytest.mark.xfail( + sys.platform == "win32", + reason="'ipaddr' extension not available on Windows", +) +@pytest.mark.parametrize( + ("data", "query", "expected"), + [ + pytest.param( + [{"ip": None}], + select(table.c.ip, table.c.ip.ipnetwork()), + (None, None), + id="nullable", + ), + pytest.param( + [{"cidr": None}], + select(table.c.cidr), + (None,), + id="nullable_cidr", + ), + pytest.param( + [{"cidr": IPv4Network("192.168.16.3/32")}], + select(table.c.cidr), + (IPv4Network("192.168.16.3/32"),), + id="cidr", + ), + pytest.param( + [{"ip": IPv4Network("192.168.1.1")}], + select(func.ipfamily(table.c.ip), table.c.ip.ipfamily()), + (4, 4), + id="ipfamily", + ), + pytest.param( + [{"ip": IPv6Interface("2001:db8::123/64")}], + select(func.iphost(table.c.ip), table.c.ip.iphost()), + ("2001:db8::123", "2001:db8::123"), + id="iphost", + ), + pytest.param( + [{"ip": IPv4Interface("192.168.16.12/24")}], + select(func.ipmasklen(table.c.ip), table.c.ip.ipmasklen()), + (24, 24), + id="ipmasklen", + ), + pytest.param( + [{"ip": IPv4Interface("192.168.16.12/24")}], + select(func.ipnetwork(table.c.ip), table.c.ip.ipnetwork()), + ( + IPv4Network("192.168.16.0/24"), + IPv4Network("192.168.16.0/24"), + ), + id="ipnetwork", + ), + pytest.param( + [{"ip": IPv4Interface("192.168.16.0/24")}], + select( + func.ipcontains(table.c.ip, "192.168.16.3"), + table.c.ip.ipcontains("192.168.16.3"), + ), + (True, True), + id="ipcontains_lhs", + ), + pytest.param( + [{"ip": IPv4Interface("192.168.16.3")}], + select( + func.ipcontains("192.168.16.0/24", table.c.ip), + ), + (True,), + id="ipcontains_rhs", + ), + ], +) +def test_ipaddr_types( + data: list[dict[str, t.Any]], + query: Select[t.Any], + expected: tuple[t.Any, ...], +) -> None: + """Test that the types work.""" + engine = create_engine("sqlite+sqlean:///:memory:?extensions=ipaddr") + metadata.create_all(engine) + with engine.connect() as conn, conn.begin(): + conn.execute(table.insert(), data) + result = conn.execute(query) + assert result.fetchone() == expected + + +@pytest.mark.parametrize( + ("data", "query", "expected"), + [ + pytest.param( + [{"uuid_col": None}], + select(table.c.uuid_col), + (None,), + id="nullable", + ), + ], +) +def test_uuid_types( + data: list[dict[str, t.Any]], + query: Select[t.Any], + expected: tuple[t.Any, ...], +) -> None: + """Test that the types work.""" + engine = create_engine("sqlite+sqlean:///:memory:?extensions=uuid") + metadata.create_all(engine) + with engine.connect() as conn, conn.begin(): + conn.execute(table.insert(), data) + result = conn.execute(query) + assert result.fetchone() == expected + + +def test_function_uuid4() -> None: + """Test that the function works.""" + engine = create_engine("sqlite+sqlean:///:memory:?extensions=uuid") + metadata.create_all(engine) + with engine.connect() as conn: + result = conn.execute(select(func.uuid4())) + row = result.fetchone() + assert row is not None + assert isinstance(row[0], uuid.UUID) + + +def test_function_uuid_str() -> None: + """Test that the function works.""" + engine = create_engine("sqlite+sqlean:///:memory:?extensions=uuid") + metadata.create_all(engine) + with engine.connect() as conn: + result = conn.execute(select(func.uuid_str("8d144638-3baf-4901-a554-b541142c152b"))) + row = result.fetchone() + assert row is not None + assert row[0] == uuid.UUID("8d144638-3baf-4901-a554-b541142c152b") + + +def test_function_uuid_blob() -> None: + """Test that the function works.""" + engine = create_engine("sqlite+sqlean:///:memory:?extensions=uuid") + metadata.create_all(engine) + with engine.connect() as conn: + result = conn.execute( + select( + func.uuid_blob("8d144638-3baf-4901-a554-b541142c152b"), + func.uuid_blob(func.uuid4()), + ), + ) + row = result.fetchone() + assert row is not None + assert isinstance(row[0], bytes) + assert isinstance(row[1], bytes)