# Source code for scrachy.db.base

#  Copyright 2020 Reid Swanson.
#
#  This file is part of scrachy.
#
#  scrachy is free software: you can redistribute it and/or modify
#  it under the terms of the GNU Lesser General Public License as published by
#  the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  scrachy is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU Lesser General Public License for more details.
#
#   You should have received a copy of the GNU Lesser General Public License
#   along with scrachy.  If not, see <https://www.gnu.org/licenses/>.

"""
The basic data types and classes required to define the SqlAlchemy models.
"""

from __future__ import annotations

# Python Modules
import json
import logging

from json import JSONDecodeError
from typing import Annotated, Any, Optional, Sequence

# 3rd Party Modules
from sqlalchemy import MetaData, inspect, BigInteger, LargeBinary, SmallInteger
from sqlalchemy.orm import DeclarativeBase, NO_VALUE, QueryableAttribute, declared_attr

# Project Modules
from scrachy.settings import PROJECT_SETTINGS
from scrachy.utils.sqltypes import TimeStampTZ
from scrachy.utils.strings import camel_to_snake

# ``Annotated`` aliases used as the keys of ``Base.type_annotation_map``,
# which maps each alias to the SQLAlchemy column type it should produce.
# NOTE(review): ``smallint`` annotates the Python type ``int`` while the other
# three annotate the SQLAlchemy type itself -- confirm whether this
# inconsistency is intentional.
bigint = Annotated[BigInteger, 64]
binary = Annotated[LargeBinary, None]
smallint = Annotated[int, 16]
timestamp = Annotated[TimeStampTZ, None]

# Module-level logger named after this module.
log = logging.getLogger(__name__)


# The database schema name read from the project settings. It is passed to
# ``MetaData(schema=...)`` in ``Base``.
# NOTE(review): presumably ``None`` when the setting is absent -- confirm
# against the type of ``PROJECT_SETTINGS``.
schema = PROJECT_SETTINGS.get('SCRACHY_DB_SCHEMA')
# ``"<schema>."`` when a schema is configured, otherwise the empty string.
schema_prefix = f"{schema}." if schema else ""


class Base(DeclarativeBase):
    """
    The declarative base class for the models.

    It configures the shared :class:`~sqlalchemy.MetaData` (constraint naming
    convention and schema), the mapping from the annotated type aliases to
    column types, derives the table name from the class name and provides
    column based equality, hashing and dictionary/JSON serialization.
    """

    metadata = MetaData(
        naming_convention={
            "ix": "ix_%(column_0_label)s",
            "uq": "uq_%(table_name)s_%(column_0_name)s",
            "ck": "ck_%(table_name)s_%(constraint_name)s",
            "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
            "pk": "pk_%(table_name)s",
        },
        schema=schema,
    )

    type_annotation_map = {
        bigint: BigInteger,
        binary: LargeBinary,
        smallint: SmallInteger,
        timestamp: TimeStampTZ(timezone=True),
    }

    # noinspection PyMethodParameters
    @declared_attr
    def __tablename__(cls):
        """The table name is the class name passed through ``camel_to_snake``."""
        return camel_to_snake(cls.__name__)

    # Modified from https://stackoverflow.com/a/55749579/4971706
    def __repr__(self) -> str:
        """
        :return: The result of :meth:`to_dict` rendered as indented JSON with
                 sorted keys. Values json cannot encode are passed to ``str``.
        """
        return json.dumps(self.to_dict(), indent=2, sort_keys=True, default=str)

    def __eq__(self, other: object) -> bool:
        """
        A model is equal to another if all of its columns are equal.

        :param other: The object to compare against. It does not have to be
               a model. Any object missing one of this model's column
               attributes compares unequal.
        :return: ``True`` if, for every column of this model, ``other`` has an
                 attribute of the same name with an equal value.
        """
        for col in self.__table__.columns.keys():
            this_value = getattr(self, col)

            try:
                that_value = getattr(other, col)
            except AttributeError:
                return False

            if this_value != that_value:
                return False

        return True

    def __hash__(self) -> int:
        """
        :return: A hash of the ``(column name, value)`` pairs ordered by
                 column name, so it is consistent with :meth:`__eq__`.
        """
        names = sorted(self.__table__.columns.keys())

        return hash(tuple((name, getattr(self, name)) for name in names))

    def to_dict(
            self,
            hide_missing: bool = True,
            exclude_keys: Optional[Sequence[str]] = None
    ) -> dict[str, Any]:
        """
        Convert the model to a dictionary containing its columns,
        relationships and plain Python properties.

        :param hide_missing: If ``True`` entries whose value is ``None`` are
               left out of the result.
        :param exclude_keys: Qualified keys (``<path>.<attribute>`` where the
               path starts with the lower cased table name) to leave out.
        :return: The dictionary representation of the model.
        """
        # Adapted from: https://medium.com/@alanhamlett/part-1-sqlalchemy-models-to-json-de398bc2ef47
        # The only major change was to add a check to make sure the relationship
        # data is attached before trying to access it with getattr.
        path: str = self.__tablename__.lower()
        excluded: set[str] = set(exclude_keys) if exclude_keys else set()
        seen: set[Any] = {self}

        return self._to_dict(path, excluded, hide_missing, seen)

    def _to_dict(
            self,
            path: str,
            excluded: set[str],
            hide_missing: bool,
            seen: set[Any]
    ) -> dict[str, Any]:
        """
        The recursive worker behind :meth:`to_dict`.

        :param path: The qualified path of this object in the output.
        :param excluded: The qualified keys to skip. Relationship keys are
               added to this set as they are visited.
        :param hide_missing: If ``True`` omit ``None`` values.
        :param seen: The objects that have already been serialized.
        :return: The dictionary representation of this object.
        """
        columns: Sequence[str] = self.__table__.columns.keys()
        relationships: Sequence[str] = self.__mapper__.relationships.keys()
        properties: Sequence[str] = list(set(dir(self)) - set(columns) - set(relationships))

        result: dict[str, Any] = dict()

        # The columns
        self._columns_to_dict(columns, path, excluded, result, hide_missing)
        self._relationships_to_dict(relationships, path, excluded, result, hide_missing, seen)
        self._properties_to_dict(properties, path, excluded, result, hide_missing, seen)

        return result

    def _columns_to_dict(
            self,
            columns: Sequence[str],
            path: str,
            excluded: set[str],
            result: dict[str, Any],
            hide_missing: bool
    ):
        """
        Copy the public, non-excluded column values into ``result``.

        ``bytes`` values are stored as their hexadecimal string.
        """
        for key in columns:
            if self._is_private(key):
                continue

            # Some keys might be an SqlAlchemy subclass of str
            key = str(key)

            qualified_key = self._qualified_key(path, key)

            if qualified_key in excluded:
                continue

            value = getattr(self, key)

            if value is None and hide_missing:
                continue

            result[key] = value.hex() if isinstance(value, bytes) else value

    def _relationships_to_dict(
            self,
            relationships: Sequence[str],
            path: str,
            excluded: set[str],
            result: dict[str, Any],
            hide_missing: bool,
            seen: set[Any]
    ):
        """
        Serialize the public, non-excluded relationships into ``result``.

        Each visited relationship is added to ``excluded`` so the same
        qualified key is never expanded twice. A scalar relationship whose
        value has not been loaded is reported as ``'[DETACHED]'`` instead of
        being accessed.
        """
        for key in relationships:
            if self._is_private(key):
                continue

            qualified_key = self._qualified_key(path, key)

            if qualified_key in excluded:
                continue

            key = str(key)
            excluded.add(qualified_key)
            relationship = self.__mapper__.relationships[key]

            if relationship.uselist:
                # NOTE(review): the items are not checked against ``seen``, so
                # two list relationships that refer to each other can still
                # recurse without bound -- confirm whether that can occur.
                items: list[Base] = getattr(self, key)

                if relationship.query_class is not None:
                    if hasattr(items, "all"):
                        items = items.all()

                result[key] = [
                    i._to_dict(qualified_key.lower(), excluded, hide_missing, seen)
                    for i in items
                ]
            else:
                if relationship.query_class is not None or relationship.instrument_class is not None:
                    state = inspect(self)
                    loaded_value = state.attrs[key].loaded_value

                    # NO_VALUE is a sentinel, so compare by identity. An
                    # equality test would go through the column wise __eq__
                    # of a loaded model.
                    if loaded_value is NO_VALUE:
                        result[key] = '[DETACHED]'
                    else:
                        item: Base = getattr(self, key)

                        if item is not None and item not in seen:
                            seen.add(item)
                            result[key] = item._to_dict(
                                qualified_key.lower(), excluded, hide_missing, seen
                            )
                        elif not hide_missing:
                            result[key] = None
                else:
                    value = getattr(self, key)

                    if (value is not None or not hide_missing) and value not in seen:
                        seen.add(value)
                        result[key] = value

    def _properties_to_dict(
            self,
            properties: Sequence[str],
            path: str,
            excluded: set[str],
            result: dict[str, Any],
            hide_missing: bool,
            seen: Optional[set[Any]] = None
    ):
        """
        Serialize the public, non-excluded Python ``property`` attributes of
        the class into ``result``.

        A value that has a ``_to_dict`` method is serialized recursively.
        Any other value is round tripped through json and silently skipped
        if that fails.

        :param seen: The objects that have already been serialized. It is
               optional to remain compatible with callers that use the
               previous signature. When omitted only this object is
               considered seen.
        """
        if seen is None:
            seen = {self}

        mycls = type(self)

        for key in properties:
            if self._is_private(key):
                continue

            if not hasattr(mycls, key):
                continue

            attr = getattr(mycls, key)

            if not isinstance(attr, property) or isinstance(attr, QueryableAttribute):
                continue

            qualified_key = self._qualified_key(path, key)

            if qualified_key in excluded:
                continue

            key = str(key)
            value = getattr(self, key)

            if hasattr(value, '_to_dict'):
                # The nested call requires ``seen``. Guard against cycles the
                # same way the scalar relationships do.
                if value not in seen:
                    seen.add(value)
                    result[key] = value._to_dict(
                        qualified_key.lower(), excluded, hide_missing, seen
                    )
            else:
                try:
                    result[key] = json.loads(json.dumps(value, sort_keys=True))
                except (RecursionError, ValueError, TypeError, JSONDecodeError):
                    # Best effort: properties that json cannot encode are
                    # deliberately left out.
                    pass

    @staticmethod
    def _qualified_key(path: str, key: str) -> str:
        """
        :return: ``key`` prefixed with ``path`` and a dot.
        """
        return f"{path}.{key}"

    @staticmethod
    def _is_private(key: str) -> bool:
        """
        :return: ``True`` if ``key`` starts with an underscore.
        """
        return key.startswith('_')