Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 115 additions & 4 deletions odev/common/connectors/postgres.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""PostgreSQL connector."""

import os
import sys
import textwrap
from collections.abc import Mapping, MutableMapping, Sequence
from contextlib import contextmanager, nullcontext
Expand All @@ -24,6 +26,8 @@


DEFAULT_DATABASE = "postgres"
PG_VERSION_15 = 150000
COLLATION_WHITELIST = ["postgres", "odev", "template1"]


class Cursor(PsycopgCursor):
Expand All @@ -48,6 +52,12 @@ def transaction(self):
class PostgresConnector(Connector):
"""Connector class to interact with PostgreSQL."""

_checking_collation: ClassVar[bool] = False
"""Whether a collation check is currently in progress to avoid recursion."""

_has_collation_mismatch: ClassVar[bool] = False
"""Whether a collation mismatch has been detected in any connection."""

_fallback_database: ClassVar[str] = DEFAULT_DATABASE
"""The database to connect to if none is specified."""

Expand Down Expand Up @@ -80,10 +90,11 @@ def connect(self):
if self._connection is None:
self._connection = psycopg2.connect(database=self.database) # type: ignore [assignment]
self._connection.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)

if self.cr is None:
self.cr = Cursor(self._connection)

self._check_collation()

def disconnect(self):
"""Disconnect from the database engine."""
if self.cr is not None:
Expand Down Expand Up @@ -124,13 +135,15 @@ def query(
:param params: Additional parameters to pass to the cursor.
:param transaction: Whether to execute the query in a transaction.
"""
if self.__class__._has_collation_mismatch or (self._connection and self._connection.notices):
self._check_collation()

if self.cr is None:
raise ConnectorError("The cursor is not initialized, connect first", self)

query = textwrap.dedent(query).strip()
query_lower = query.lower()
is_select = query_lower.startswith("select")
expect_result = is_select or " returning " in query_lower

if is_select and not self.__class__._nocache and (self.database, query) in self.__class__._query_cache:
result = self.__class__._query_cache[(self.database, query)]
Expand Down Expand Up @@ -166,15 +179,15 @@ def signal_handler_cancel_statement(*args, **kwargs):
return False
raise error from error

result = expect_result and self.cr.fetchall()
result = self.cr.fetchall() if self.cr.description else True

if is_select and not self.__class__._nocache:
if DEBUG_SQL:
logger.debug(f"Caching PostgreSQL result for query against {self.database!r}:")
console.code(string.indent(str(result), 4), "python")
self.__class__._query_cache[(self.database, query)] = result

return result if expect_result else True
return result

def create_database(self, database: str, template: str | None = None) -> bool:
"""Create a database.
Expand Down Expand Up @@ -341,3 +354,101 @@ def create_column(self, table: str, column: str, attributes: str) -> bool:
"""
)
)

def _check_collation(self):
"""Check for collation mismatch by comparing system glibc version with DB version."""
if not self.odev or self.odev.in_test_mode or self.__class__._checking_collation:
return

self.__class__._checking_collation = True
try:
if self.database not in COLLATION_WHITELIST and not self.table_exists("ir_module_module"):
return

has_mismatch = self.__class__._has_collation_mismatch
if not has_mismatch and self._connection:
try:
with self._connection.cursor() as cr:
cr.execute("SELECT current_setting('server_version_num')::int")
if cr.fetchone()[0] < PG_VERSION_15:
return

cr.execute("SELECT datcollversion FROM pg_database WHERE datname = current_database()")
db_version = cr.fetchone()[0]

try:
sys_version = os.confstr("CS_GNU_LIBC_VERSION").split()[-1]
except (AttributeError, ValueError):
return

if db_version and sys_version and db_version != sys_version:
has_mismatch = True
except (psycopg2.Error, RuntimeError):
return

if has_mismatch:
self.__class__._has_collation_mismatch = False
logger.warning(
f"Database {string.stylize(self.database, 'repr.path')} has a collation version mismatch."
)

if self.odev.console.confirm(
"Do you want to refresh collation for all affected Odev databases?",
default=True,
):
self._refresh_all_collations()
logger.info("Collations refreshed. Please restart your command.")
sys.exit(0)
else:
logger.error("Collation mismatch detected. Aborting.")
sys.exit(1)
finally:
self.__class__._checking_collation = False

def _refresh_all_collations(self):
"""Refresh collation version for local Odev databases with mismatches."""
try:
sys_version = os.confstr("CS_GNU_LIBC_VERSION").split()[-1]
except (AttributeError, ValueError):
return

target_psql = self
if self.database != "postgres":
target_psql = PostgresConnector("postgres")
target_psql.connect()

try:
databases = target_psql.query(
f"""
SELECT datname
FROM pg_database
WHERE datistemplate = false
AND datcollversion IS NOT NULL
AND datcollversion <> {string.quote(sys_version, force_single=True)}
ORDER BY datname
"""
)

if not databases or isinstance(databases, bool):
return

for (db_name,) in databases:
if db_name not in COLLATION_WHITELIST:
try:
with PostgresConnector(db_name) as db_psql:
if not db_psql.table_exists("ir_module_module"):
continue
except (psycopg2.Error, RuntimeError):
continue

logger.info(f"Refreshing collation for database {string.stylize(db_name, 'repr.path')}...")
try:
target_psql.query(
f'ALTER DATABASE "{db_name}" REFRESH COLLATION VERSION',
transaction=False,
)
except (psycopg2.Error, RuntimeError) as e:
logger.error(f"Failed to refresh collation for {db_name}: {e}")
finally:
if target_psql is not self:
target_psql.disconnect()
Loading