This repository was archived by the owner on Sep 4, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path__init__.py
More file actions
118 lines (80 loc) · 3.3 KB
/
__init__.py
File metadata and controls
118 lines (80 loc) · 3.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""Dialect class for SQLAlchemy that uses sqlean.py as the DBAPI."""
from __future__ import annotations

import typing as t
import uuid
from importlib.metadata import version

import sqlalchemy.types as sqltypes
from sqlalchemy.dialects.sqlite.base import SQLiteTypeCompiler
from sqlalchemy.dialects.sqlite.pysqlite import SQLiteDialect_pysqlite
from sqlalchemy.sql.functions import GenericFunction

from sqlean_driver.custom_types import UUID

# Names used only inside type annotations in this module. Because of the
# ``from __future__ import annotations`` above, annotations are not evaluated
# at runtime, so these imports are needed only by static type checkers.
if t.TYPE_CHECKING:
    from types import ModuleType

    from sqlalchemy.engine.url import URL
    from sqlalchemy.sql.type_api import TypeEngine

    from sqlean_driver.custom_types import IPAddress, IPNetwork

# Package version, looked up from installed distribution metadata.
# NOTE(review): this assumes an installed distribution whose name
# importlib.metadata can resolve from ``__package__``; otherwise version()
# raises PackageNotFoundError at import time (e.g. an uninstalled source
# checkout) — confirm this is intended.
__version__ = version(__package__)
class SQLeanTypeCompiler(SQLiteTypeCompiler):
    """SQLite type compiler extended with the column types used alongside sqlean.py.

    Every visitor here renders its type simply as the type's own keyword.
    """

    def visit_INET(self, type_: TypeEngine[IPAddress], **kw: t.Any) -> str:  # noqa: ARG002, PLR6301
        """Render the DDL keyword for an INET (IP address) column."""
        return "INET"

    def visit_CIDR(self, type_: TypeEngine[IPNetwork], **kw: t.Any) -> str:  # noqa: ARG002, PLR6301
        """Render the DDL keyword for a CIDR (IP network) column."""
        return "CIDR"

    def visit_UUID(self, type_: TypeEngine[uuid.UUID], **kw: t.Any) -> str:  # noqa: ARG002, PLR6301
        """Render the DDL keyword for a UUID column."""
        return "UUID"
class uuid4(GenericFunction[uuid.UUID]):  # noqa: N801
    """SQL function ``uuid4()``: produce a new random (version 4) UUID as a string.

    ``gen_random_uuid()`` is offered as an alias for PostgreSQL compatibility.
    """

    # Opt this construct into SQLAlchemy's SQL compilation cache.
    inherit_cache = True
    # Result values are handled by the project's UUID column type.
    type = UUID()
    name = "uuid4"
class gen_random_uuid(uuid4):  # noqa: N801
    """Generates a version 4 (random) UUID as a string.

    PostgreSQL-compatible alias of ``uuid4()``; the return type is inherited
    from that class.
    """

    name = "gen_random_uuid"
    # SQLAlchemy looks up ``inherit_cache`` in each class's own ``__dict__``,
    # so the value declared on ``uuid4`` does not carry over to this subclass.
    # Without it the function is excluded from the SQL compilation cache and
    # SQLAlchemy emits a warning about it.
    inherit_cache = True
class uuid_str(GenericFunction[uuid.UUID]):  # noqa: N801
    """SQL function ``uuid_str(X)``: normalize a UUID into a well-formed UUID string.

    The argument `X` may be supplied either as a string or as a blob.
    """

    # Opt this construct into SQLAlchemy's SQL compilation cache.
    inherit_cache = True
    # Result values are handled by the project's UUID column type.
    type = UUID()
    name = "uuid_str"
class uuid_blob(GenericFunction[bytes]):  # noqa: N801
    """Converts a UUID `X` into its 16-byte binary (BLOB) form.

    `X` can be either a string or a blob.
    """

    name = "uuid_blob"
    # The SQL function returns raw bytes, so the result is typed as a BLOB.
    type = sqltypes.BLOB()
    inherit_cache = True
class SQLeanDialect(SQLiteDialect_pysqlite):
    """A dialect for SQLite that uses sqlean.py as the DBAPI."""

    driver = "sqlean"
    supports_statement_cache = True
    type_compiler = SQLeanTypeCompiler

    @classmethod
    def dbapi(cls) -> ModuleType:  # type: ignore[override]
        """Return the DBAPI module.

        NOTE: This is a legacy method that will stop being used by SQLAlchemy
        at some point; it simply delegates to ``import_dbapi()``.
        """
        return cls.import_dbapi()

    @classmethod
    def import_dbapi(cls) -> ModuleType:
        """Import and return the ``sqlean`` module used as the DBAPI."""
        # Deferred import: sqlean is only loaded when the DBAPI is requested.
        import sqlean  # noqa: PLC0415

        return sqlean  # type: ignore[no-any-return]

    @staticmethod
    def _requested_extensions(url: URL) -> list[str]:
        """Return extension names from the URL's ``extensions`` query parameter.

        Accepts a comma-separated value (``?extensions=uuid,crypto``), a
        repeated parameter (``?extensions=uuid&extensions=crypto``), or a mix
        of both. Surrounding whitespace is stripped and empty entries are
        dropped. Returns an empty list when the parameter is absent.
        """
        raw = url.query.get("extensions", ())
        # A single query value arrives as ``str`` and a repeated one as a tuple
        # of ``str``; normalize both to a sequence of raw values so every
        # value is comma-split the same way.
        values = (raw,) if isinstance(raw, str) else raw
        names: list[str] = []
        for value in values:
            names.extend(part.strip() for part in value.split(","))
        return [name for name in names if name]

    def on_connect_url(self, url: URL) -> t.Callable[[t.Any], t.Any] | None:
        """Enable the sqlean extensions requested by *url*, then defer to the parent.

        The special name ``all`` enables every extension; otherwise only the
        listed extensions are enabled.

        Returns:
            Whatever the parent dialect's ``on_connect_url`` returns: a callable
            to execute on connect, or ``None``.
        """
        extensions = self._requested_extensions(url)
        # NOTE(review): ``self.dbapi`` is expected to hold the loaded sqlean
        # module here (an instance attribute shadowing the ``dbapi``
        # classmethod) — confirm for every supported SQLAlchemy version.
        if "all" in extensions:
            self.dbapi.extensions.enable_all()  # type: ignore[attr-defined]
        else:
            self.dbapi.extensions.enable(*extensions)  # type: ignore[attr-defined]
        return super().on_connect_url(url)