diff --git a/odev/common/connectors/postgres.py b/odev/common/connectors/postgres.py index 92904bfa..7119c141 100644 --- a/odev/common/connectors/postgres.py +++ b/odev/common/connectors/postgres.py @@ -1,5 +1,7 @@ """PostgreSQL connector.""" +import os +import sys import textwrap from collections.abc import Mapping, MutableMapping, Sequence from contextlib import contextmanager, nullcontext @@ -24,6 +26,8 @@ DEFAULT_DATABASE = "postgres" +PG_VERSION_15 = 150000 +COLLATION_WHITELIST = ["postgres", "odev", "template1"] class Cursor(PsycopgCursor): @@ -48,6 +52,12 @@ def transaction(self): class PostgresConnector(Connector): """Connector class to interact with PostgreSQL.""" + _checking_collation: ClassVar[bool] = False + """Whether a collation check is currently in progress to avoid recursion.""" + + _has_collation_mismatch: ClassVar[bool] = False + """Whether a collation mismatch has been detected in any connection.""" + _fallback_database: ClassVar[str] = DEFAULT_DATABASE """The database to connect to if none is specified.""" @@ -80,10 +90,11 @@ def connect(self): if self._connection is None: self._connection = psycopg2.connect(database=self.database) # type: ignore [assignment] self._connection.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) - if self.cr is None: self.cr = Cursor(self._connection) + self._check_collation() + def disconnect(self): """Disconnect from the database engine.""" if self.cr is not None: @@ -124,13 +135,15 @@ def query( :param params: Additional parameters to pass to the cursor. :param transaction: Whether to execute the query in a transaction. """ + if self.__class__._has_collation_mismatch or (self._connection and self._connection.notices): + self._check_collation() + if self.cr is None: raise ConnectorError("The cursor is not initialized, connect first", self) query = textwrap.dedent(query).strip() query_lower = query.lower() is_select = query_lower.startswith("select") - expect_result = is_select or " returning " in query_lower if is_select and not self.__class__._nocache and (self.database, query) in self.__class__._query_cache: result = self.__class__._query_cache[(self.database, query)] @@ -166,7 +179,7 @@ def signal_handler_cancel_statement(*args, **kwargs): return False raise error from error - result = expect_result and self.cr.fetchall() + result = self.cr.fetchall() if self.cr.description else True if is_select and not self.__class__._nocache: if DEBUG_SQL: @@ -174,7 +187,7 @@ def signal_handler_cancel_statement(*args, **kwargs): console.code(string.indent(str(result), 4), "python") self.__class__._query_cache[(self.database, query)] = result - return result if expect_result else True + return result def create_database(self, database: str, template: str | None = None) -> bool: """Create a database. @@ -341,3 +354,101 @@ def create_column(self, table: str, column: str, attributes: str) -> bool: """ ) ) + + def _check_collation(self): + """Check for collation mismatch by comparing system glibc version with DB version.""" + if not self.odev or self.odev.in_test_mode or self.__class__._checking_collation: + return + + self.__class__._checking_collation = True + try: + if self.database not in COLLATION_WHITELIST and not self.table_exists("ir_module_module"): + return + + has_mismatch = self.__class__._has_collation_mismatch + if not has_mismatch and self._connection: + try: + with self._connection.cursor() as cr: + cr.execute("SELECT current_setting('server_version_num')::int") + if cr.fetchone()[0] < PG_VERSION_15: + return + + cr.execute("SELECT datcollversion FROM pg_database WHERE datname = current_database()") + db_version = cr.fetchone()[0] + + try: + sys_version = os.confstr("CS_GNU_LIBC_VERSION").split()[-1] + except (AttributeError, ValueError): + return + + if db_version and sys_version and db_version != sys_version: + has_mismatch = True + except (psycopg2.Error, RuntimeError): + return + + if has_mismatch: + self.__class__._has_collation_mismatch = False + logger.warning( + f"Database {string.stylize(self.database, 'repr.path')} has a collation version mismatch." + ) + + if self.odev.console.confirm( + "Do you want to refresh collation for all affected Odev databases?", + default=True, + ): + self._refresh_all_collations() + logger.info("Collations refreshed. Please restart your command.") + sys.exit(0) + else: + logger.error("Collation mismatch detected. Aborting.") + sys.exit(1) + finally: + self.__class__._checking_collation = False + + def _refresh_all_collations(self): + """Refresh collation version for local Odev databases with mismatches.""" + try: + sys_version = os.confstr("CS_GNU_LIBC_VERSION").split()[-1] + except (AttributeError, ValueError): + return + + target_psql = self + if self.database != "postgres": + target_psql = PostgresConnector("postgres") + target_psql.connect() + + try: + databases = target_psql.query( + f""" + SELECT datname + FROM pg_database + WHERE datistemplate = false + AND datcollversion IS NOT NULL + AND datcollversion <> {string.quote(sys_version, force_single=True)} + ORDER BY datname + """ + ) + + if not databases or isinstance(databases, bool): + return + + for (db_name,) in databases: + if db_name not in COLLATION_WHITELIST: + try: + with PostgresConnector(db_name) as db_psql: + if not db_psql.table_exists("ir_module_module"): + continue + except (psycopg2.Error, RuntimeError): + continue + + logger.info(f"Refreshing collation for database {string.stylize(db_name, 'repr.path')}...") + try: + target_psql.query( + f'ALTER DATABASE "{db_name}" REFRESH COLLATION VERSION', + transaction=False, + ) + except (psycopg2.Error, RuntimeError) as e: + logger.error(f"Failed to refresh collation for {db_name}: {e}") + finally: + if target_psql is not self: + target_psql.disconnect()