79672099

Date: 2025-06-19 13:12:59
Score: 1.5
Natty:
Report link

Here is an expanded version of @Mikko Ohtamaa's answer, fixing a couple bugs there* and adding checks for relationships, nullable columns and foreign keys. If you are looking for the main is_sane_database function, it is on the bottom.

*basically, it assumed that all the models defined corresponds directly to a table in the database with all the columns matching, but a model can be from a view, or multiple tables joined together (I encountered this through inheritance). Type checkers also complained a little bit.

from __future__ import annotations

import logging
from typing import Any, cast

from sqlalchemy import Engine, Inspector, Table, inspect, text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.orm import ColumnProperty, DeclarativeBase, Mapper, RelationshipProperty

# noinspection PyProtectedMember
from sqlalchemy.orm.clsregistry import ClsRegistryToken, _ModuleMarker  # pyright: ignore[reportPrivateUsage]

logger = logging.getLogger(__name__)


# Handle some common type variations
type_mapping = {
    "integer": ["int", "integer", "int4"],
    "bigint": ["bigint", "int8"],
    "smallint": ["smallint", "int2"],
    "string": ["string", "varchar", "text"],
    "boolean": ["boolean", "bool"],
    "float": ["float", "real", "float4"],
    "double": ["double", "float8"],
    "json": ["json", "jsonb"],
}


def normalize_type(type_name: str) -> str:
    for base_type, variants in type_mapping.items():
        if any(variant in type_name for variant in variants):
            return base_type
    return type_name


class DatabaseSchema:
    """A class to hold database schema information."""

    def __init__(self, inspector: Inspector):
        logger.info("Getting table names from database %s", inspector.engine.url)
        self.tables = inspector.get_table_names()
        self.columns: dict[str, dict[str, Any]] = {}

        for table in self.tables:
            logging.info("Loading information from table %s", table)
            self.columns[table] = {c["name"]: c for c in inspector.get_columns(table)}


def check_relationship_property(
    column_prop: RelationshipProperty, schema: DatabaseSchema, klass: type[DeclarativeBase], engine: Engine
) -> bool:
    """Check if a relationship property is valid."""

    errors = False

    if column_prop.secondary is not None:
        # Additional checks for many-to-many relationships
        if not isinstance(column_prop.secondary, Table):
            logger.info(
                "Skipping relationship %s in model %s because secondary is not a Table object", column_prop.key, klass
            )
            return errors

        # Check secondary table exists
        if column_prop.secondary.name not in schema.tables:
            logger.error(
                "Model %s declares many-to-many relationship %s with secondary table %s which does not exist in database %s",
                klass,
                column_prop.key,
                column_prop.secondary.name,
                engine.url,
            )
            errors = True

    if not isinstance(column_prop.target, Table):
        logger.info("Skipping relationship %s in model %s because target is not a Table object", column_prop.key, klass)
        return errors

    target_table = column_prop.target.name
    if target_table not in schema.tables:
        logger.error(
            "Model %s declares relationship %s to table %s which does not exist in database %s",
            klass,
            column_prop.key,
            target_table,
            engine.url,
        )
        errors = True

    return errors


def check_column_property(
    column_prop: ColumnProperty, schema: DatabaseSchema, klass: type[DeclarativeBase], engine: Engine
) -> bool:
    """Check if a column property is valid."""
    # TODO: unique constraints
    errors = False

    # We cannot assume that all columns of the model are actual from that model itself, because it may inherit from another model.
    # So the following line is wrong. Instead, we need to get the table from the column itself.
    # table = klass.__tablename__

    for column in column_prop.columns:
        if not column.table._is_table:
            logger.info(
                "Skipping column %s in model %s because it originates from a non-table object (%s)",
                type(column.table).__name__,
            )
            continue
        else:
            assert isinstance(column.table, Table), "Expected column.table to be a Table instance"
            table = column.table.name
        # Check column exists
        if column.key not in schema.columns[table]:
            logger.error(
                "Model %s declares column %s which does not exist in database %s",
                klass,
                column.key,
                engine.url,
            )
            errors = True
            continue

        # Check column type
        db_column = schema.columns[table][column.key]
        model_type = column.type
        db_type = db_column["type"]

        # Compare type names, handling some common type variations
        model_type_name = str(model_type).lower()
        db_type_name = str(db_type).lower()

        if normalize_type(model_type_name) != normalize_type(db_type_name):
            logger.error(
                "Model %s column %s has type %s but database has type %s",
                klass,
                column.key,
                model_type,
                db_type,
            )
            errors = True

        # Check foreign key constraints
        if column.foreign_keys:
            for fk in column.foreign_keys:
                target_table = fk.column.table.name
                if target_table not in schema.tables:
                    logger.error(
                        "Model %s declares foreign key %s to table %s which does not exist in database %s",
                        klass,
                        column.key,
                        target_table,
                        engine.url,
                    )
                    errors = True
                else:
                    if fk.column.key not in schema.columns[target_table]:
                        logger.error(
                            "Model %s declares foreign key %s to column %s in table %s which does not exist in database %s",
                            klass,
                            column.key,
                            fk.column.key,
                            target_table,
                            engine.url,
                        )
                        errors = True

        # Check if the column is nullable
        if not column.nullable and db_column["nullable"]:
            logger.error(
                "Model %s declares column %s as non-nullable but database has it as nullable",
                klass,
                column.key,
            )
            errors = True

        if column.nullable and not db_column["nullable"]:
            logger.error(
                "Model %s declares column %s as nullable but database has it as non-nullable",
                klass,
                column.key,
            )
            errors = True

    return errors


def is_sane_database(base_cls: type[DeclarativeBase], engine: Engine) -> bool:
    """Check whether the current database matches the models declared in model base.

    Checks that:
    * All tables exist with all columns
    * Column types match between model and database
    * All relationships exist and are properly configured

    Args:
        base_cls (type[DeclarativeBase]): The SQLAlchemy declarative base class containing the models to check.
        engine: The SQLAlchemy engine or connection to the database.

    Returns:
        bool: True if all declared models have corresponding tables, columns, and relationships.

    Raises:
        TypeError: If the provided engine is an AsyncEngine instead of a synchronous Engine.

    References:
        https://stackoverflow.com/questions/30428639/check-database-schema-matches-sqlalchemy-models-on-application-startup
    """
    if isinstance(engine, AsyncEngine):
        raise TypeError("The engine must be a synchronous SQLAlchemy Engine, not an AsyncEngine.")

    logger.debug("starting validation")
    inspector = inspect(engine)
    schema = DatabaseSchema(inspector)

    # Run an empty query to ensure the connection is valid and all the models are defined correctly.
    # If this doesn't work, all queries will fail later anyway, so we don't suppress errors raised here.
    with engine.connect() as conn:
        conn.execute(text("SELECT 1"))

    errors = False

    # Go through all SQLAlchemy models and do the following checks:
    # - Check if the table exists in the database
    # For each attribute in the model:
    #     If it is a relationship:
    #         - Check if the secondary table exists (if applicable)
    #         - Check if the target table exists
    #     If it is a column:
    #         - Check if the column exists in the table
    #         - Check if the column type matches the model type
    #         - Check if the foreign key constraints are valid
    #         - Check if the column is nullable
    #
    # noinspection PyProtectedMember
    for name, klass in base_cls.registry._class_registry.items():  # pyright: ignore[reportPrivateUsage]
        logger.debug("Checking model %s (%s)", name, klass)
        if isinstance(klass, _ModuleMarker):
            logger.debug("Skipping module marker %s", name)
            continue
        if isinstance(klass, ClsRegistryToken):
            logger.debug("Skipping ClsRegistryToken %s", name)
            continue
        if not issubclass(klass, DeclarativeBase):
            logger.warning(
                "Cannot determine whether %s is actually a model because it is not a subclass of DeclarativeBase. "
                "If you use the declarative_base(), it dynamically generates a new class that cannot be determined."
                "We are assuming it is a model, but this may not be the case.",
                klass,
            )
            klass = cast(type[DeclarativeBase], klass)

        table: str = getattr(klass, "__tablename__")
        if not table:
            logger.error("Model %s does not have a __tablename__ attribute", klass)
            errors = True
            continue
        if table not in schema.tables:
            logger.error("Model %s declares table %s which does not exist in database %s", klass, table, engine.url)
            errors = True
            continue

        mapper = inspect(klass)
        assert isinstance(mapper, Mapper), "Expected mapper to be an instance of Mapper (uncertain)"

        try:  # If any error occurs during inspection, it will be caught, and errors will be set to True
            for column_prop in mapper.attrs:
                if isinstance(column_prop, RelationshipProperty):
                    if check_relationship_property(column_prop, schema, klass, engine):
                        errors = True
                elif isinstance(column_prop, ColumnProperty):
                    if check_column_property(column_prop, schema, klass, engine):
                        errors = True
                else:
                    logging.info(
                        "Encountered unexpected property %s in model %s with type %s",
                        column_prop.key,
                        klass.__name__,
                        type(column_prop),
                    )

        except SQLAlchemyError as e:
            logger.error("Error inspecting model %s: %s", klass.__name__, e)
            errors = True

    return not errors
Reasons:
  • Blacklisted phrase (1): stackoverflow
  • Long answer (-1):
  • Has code block (-0.5):
  • User mentioned (1): @Mikko
  • Low reputation (1):
Posted by: papeto