Skip to content
This repository was archived by the owner on Sep 4, 2025. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ optional-dependencies.testing = [
]
optional-dependencies.typing = [
"mypy>=1",
"typing-extensions; python_version<'3.10'",
]
urls.Changelog = "https://github.com/edgarrmondragon/sqlean-driver/blob/main/CHANGELOG.md"
urls.Documentation = "https://github.com/edgarrmondragon/sqlean-driver#readme"
Expand Down Expand Up @@ -244,6 +245,7 @@ lint.ignore = [
"ANN101", # missing-type-self
"ANN102", # missing-type-cls
"COM812", # missing-trailing-comma
"FIX002", # line-contains-todo
"ISC001", # single-line-implicit-string-concatenation
]
lint.per-file-ignores."tests/**/*" = [
Expand All @@ -252,6 +254,12 @@ lint.per-file-ignores."tests/**/*" = [
"S101",
"TID252",
]
lint.unfixable = [
"ERA",
"F401",
]
lint.flake8-annotations.allow-star-arg-any = true

lint.flake8-import-conventions.banned-from = [
"typing",
]
Expand All @@ -263,6 +271,9 @@ lint.isort.known-first-party = [
lint.isort.required-imports = [
"from __future__ import annotations",
]
lint.pep8-naming.ignore-names = [
"visit_*",
]
# Tests can use magic values, assertions, and relative imports
lint.pydocstyle.convention = "google"
lint.preview = true
Expand Down
68 changes: 68 additions & 0 deletions src/sqlean_driver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,92 @@
from __future__ import annotations

import typing as t
import uuid
from importlib.metadata import version

import sqlalchemy.types as sqltypes
from sqlalchemy.dialects.sqlite.base import SQLiteTypeCompiler
from sqlalchemy.dialects.sqlite.pysqlite import SQLiteDialect_pysqlite
from sqlalchemy.sql.functions import GenericFunction

from sqlean_driver.custom_types import UUID

if t.TYPE_CHECKING:
from types import ModuleType

from sqlalchemy.engine.url import URL
from sqlalchemy.sql.type_api import TypeEngine

from sqlean_driver.custom_types import IPAddress, IPNetwork

__version__ = version(__package__)


class SQLeanTypeCompiler(SQLiteTypeCompiler):
"""A type compiler for SQLite that uses sqlean.py as the DBAPI."""

def visit_INET( # noqa: PLR6301
self,
type_: TypeEngine[IPAddress], # noqa: ARG002
**kw: t.Any, # noqa: ARG002
) -> str:
"""Visit an INET node."""
return "INET"

def visit_CIDR(self, type_: TypeEngine[IPNetwork], **kw: t.Any) -> str: # noqa: ARG002, PLR6301
"""Visit a CIDR nodes."""
return "CIDR"

def visit_UUID(self, type_: TypeEngine[uuid.UUID], **kw: t.Any) -> str: # noqa: ARG002, PLR6301
"""Visit a UUID node."""
return "UUID"


class uuid4(GenericFunction[uuid.UUID]): # noqa: N801
"""Generates a version 4 (random) UUID as a string.

Aliased as gen_random_uuid() for PostgreSQL compatibility.
"""

name = "uuid4"
type = UUID()
inherit_cache = True


class gen_random_uuid(uuid4): # noqa: N801
"""Generates a version 4 (random) UUID as a string."""

name = "gen_random_uuid"


class uuid_str(GenericFunction[uuid.UUID]): # noqa: N801
"""Converts a UUID `X` into a well-formed UUID string.

`X` can be either a string or a blob.
"""

name = "uuid_str"
type = UUID()
inherit_cache = True


class uuid_blob(GenericFunction[bytes]): # noqa: N801
"""Converts a UUID `X` into a well-formed UUID string.

`X` can be either a string or a blob.
"""

name = "uuid_blob"
type = sqltypes.BLOB()
inherit_cache = True


class SQLeanDialect(SQLiteDialect_pysqlite):
"""A dialect for SQLite that uses sqlean.py as the DBAPI."""

driver = "sqlean"
supports_statement_cache = True
type_compiler = SQLeanTypeCompiler

@classmethod
def dbapi(cls) -> ModuleType: # type: ignore[override]
Expand Down
186 changes: 186 additions & 0 deletions src/sqlean_driver/custom_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
"""Custom SQLAlchemy types."""

from __future__ import annotations

import ipaddress
import typing as t
import uuid

import sqlalchemy.types as sqltypes
from sqlalchemy.sql.functions import GenericFunction

if t.TYPE_CHECKING:
import sys

if sys.version_info < (3, 10):
from typing_extensions import TypeAlias
else:
from typing import TypeAlias # noqa: ICN003

from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql.type_api import _BindProcessorType, _ResultProcessorType


IPAddress: TypeAlias = t.Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
IPNetwork: TypeAlias = t.Union[ipaddress.IPv4Network, ipaddress.IPv6Network]


def none_or_str(value: t.Any | None) -> str | None: # noqa: ANN401
"""Return the value or None."""
return str(value) if value is not None else None


def none_or_ip_interface(
value: t.Any | None, # noqa: ANN401
) -> ipaddress.IPv4Interface | ipaddress.IPv6Interface | None:
"""Return the value or None."""
return ipaddress.ip_interface(value) if value is not None else None


def none_or_ip_network(
value: t.Any | None, # noqa: ANN401
) -> ipaddress.IPv4Network | ipaddress.IPv6Network | None:
"""Return the value or None."""
return ipaddress.ip_network(value) if value is not None else None


def none_or_uuid(
value: t.Any | None, # noqa: ANN401
) -> uuid.UUID | None:
"""Return the value or None."""
return uuid.UUID(value) if value is not None else None


class INET(sqltypes.TypeEngine[IPAddress]):
"""An INET type."""

__visit_name__ = "INET"

def bind_processor( # noqa: PLR6301
self,
_dialect: Dialect,
) -> _BindProcessorType[IPAddress] | None:
"""Return a bind processor."""
return none_or_str

def result_processor( # noqa: PLR6301
self,
_dialect: Dialect,
_coltype: object,
) -> _ResultProcessorType[IPAddress] | None:
"""Return a result processor."""
return none_or_ip_interface

# TODO(edgarrmondragon): Add missing type parameters:
# > sqltypes.Indexable.Comparator[IPAddress]
# > sqltypes.Concatenable.Comparator[IPAddress]
# https://github.com/edgarrmondragon/sqlean-driver/issues/37
class Comparator(
sqltypes.Indexable.Comparator, # type: ignore[type-arg]
sqltypes.Concatenable.Comparator, # type: ignore[type-arg]
):
"""Comparator for the INET type."""

def ipfamily(self) -> _IPAddrIPFamilyFunction:
"""Return the IP family."""
return _IPAddrIPFamilyFunction(self.expr)

def iphost(self) -> _IPAddrIPHostFunction:
"""Return the IP host."""
return _IPAddrIPHostFunction(self.expr)

def ipmasklen(self) -> _IPAddrIPMaskLenFunction:
"""Return the IP mask length."""
return _IPAddrIPMaskLenFunction(self.expr)

def ipnetwork(self) -> _IPAddrIPNetworkFunction:
"""Return the IP network."""
return _IPAddrIPNetworkFunction(self.expr)

def ipcontains(self, other: IPAddress | str) -> _IPAddrIPContainsFunction:
"""Return whether the IP address contains another IP address."""
return _IPAddrIPContainsFunction(self.expr, other)

comparator_factory = Comparator


class CIDR(sqltypes.TypeEngine[IPNetwork]):
"""A CIDR type."""

__visit_name__ = "CIDR"

def bind_processor( # noqa: PLR6301
self,
_dialect: Dialect,
) -> _BindProcessorType[IPNetwork] | None:
"""Return a bind processor."""
return none_or_str

def result_processor( # noqa: PLR6301
self,
_dialect: Dialect,
_coltype: object,
) -> _ResultProcessorType[IPNetwork] | None:
"""Return a result processor."""
return none_or_ip_network


class UUID(sqltypes.TypeEngine[uuid.UUID]):
"""A UUID type."""

__visit_name__ = "UUID"

def bind_processor( # noqa: PLR6301
self,
_dialect: Dialect,
) -> _BindProcessorType[uuid.UUID] | None:
"""Return a bind processor."""
return none_or_str

def result_processor( # noqa: PLR6301
self,
_dialect: Dialect,
_coltype: object,
) -> _ResultProcessorType[uuid.UUID] | None:
"""Return a result processor."""
return none_or_uuid


class _IPAddrIPFamilyFunction(GenericFunction[int]):
"""Returns the family of a specified IP address."""

name = "ipfamily"
type = sqltypes.Integer()
inherit_cache = True


class _IPAddrIPHostFunction(GenericFunction[str]):
"""Returns the host part of an IP address."""

name = "iphost"
type = sqltypes.String()
inherit_cache = True


class _IPAddrIPMaskLenFunction(GenericFunction[int]):
"""Returns the prefix length of an IP address."""

name = "ipmasklen"
type = sqltypes.Integer()
inherit_cache = True


class _IPAddrIPNetworkFunction(GenericFunction[IPNetwork]):
"""Returns the network part of an IP address."""

name = "ipnetwork"
type = CIDR()
inherit_cache = True


class _IPAddrIPContainsFunction(GenericFunction[bool]):
"""Returns whether an IP address contains another IP address."""

name = "ipcontains"
type = sqltypes.Boolean()
inherit_cache = True
Loading