Source code for scrachy.db.repositories

#  Copyright 2023 Reid Swanson.
#
#  This file is part of scrachy.
#
#  scrachy is free software: you can redistribute it and/or modify
#  it under the terms of the GNU Lesser General Public License as published by
#  the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  scrachy is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU Lesser General Public License for more details.
#
#   You should have received a copy of the GNU Lesser General Public License
#   along with scrachy.  If not, see <https://www.gnu.org/licenses/>.
# Loosely based on: https://hackernoon.com/building-a-to-do-list-app-with-python-data-access-layer-with-sqlalchemy

"""
The Data Access Layer.
"""

# Python Modules
import abc
import datetime
import logging

from typing import Generic, Iterable, Optional, Sequence, Type, TypeVar

# 3rd Party Modules
from sqlalchemy import insert, select
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session, load_only, selectinload

# Project Modules
from scrachy.db.base import Base
from scrachy.db.models import Response, ScrapeHistory
from scrachy.settings.defaults.storage import RetrievalMethod

# Type variable for the model class handled by a repository. It is bound to
# the project's ``Base`` and is used to parameterize ``BaseRepository``.
BaseT = TypeVar('BaseT', bound=Base)


# Module-level logger, named after this module.
log = logging.getLogger(__name__)


class BaseRepository(abc.ABC, Generic[BaseT]):
    """
    Shared data-access helpers for the model-specific repositories.

    A repository pairs a SQLAlchemy ``Session`` with the model class it
    manages. It also records the name of the database dialect in use so that
    subclasses can pick a dialect-specific ``INSERT`` construct.
    """

    def __init__(self, model: Type[BaseT], session: Session):
        """
        Create a repository for ``model`` that operates through ``session``.

        :param model: The model class this repository reads and writes.
        :param session: The session used to execute every statement.
        """
        self.model = model
        self.session = session

        # ``Session.bind`` is ``None`` when the session was not constructed
        # with a single explicit bind (for example when it was configured
        # with ``binds=``). In that case ask the session which bind it would
        # use for this model instead of failing on ``None.dialect``.
        bind = session.bind
        if bind is None:
            bind = session.get_bind(mapper=model)

        self.dialect = bind.dialect.name
        self.upsert_fn = self._get_upsert_fn()

    def find_all(self) -> Sequence[BaseT]:
        """
        Return every stored instance of the repository's model.

        :return: All rows of ``self.model`` as model instances.
        """
        stmt = select(self.model)
        return self.session.scalars(stmt).all()

    def insert(self, obj: BaseT):
        """
        Add ``obj`` to the session.

        :param obj: The model instance to add.
        """
        self.session.add(obj)

    def insert_all(self, objs: Iterable[BaseT]):
        """
        Add every object in ``objs`` to the session.

        :param objs: The model instances to add.
        """
        self.session.add_all(objs)

    def _get_upsert_fn(self):
        """
        Select the ``insert`` construct that matches ``self.dialect``.

        :return: The SQLite or PostgreSQL specific ``insert`` function when
                 the dialect is one of those two, otherwise the generic
                 SQLAlchemy ``insert``.
        """
        # NOTE: inside this method ``insert`` resolves to the module-level
        # SQLAlchemy function, not to the ``insert`` method defined above.
        dialect_inserts = {
            'sqlite': sqlite_insert,
            'postgresql': pg_insert,
        }

        return dialect_inserts.get(self.dialect, insert)
class ResponseRepository(BaseRepository[Response]):
    """
    Repository for ``Response`` rows, which are looked up by fingerprint.
    """

    def __init__(self, session: Session):
        """
        :param session: The session used to execute every statement.
        """
        super().__init__(Response, session=session)

    def find_timestamp_by_fingerprint(self, fingerprint: bytes) -> Optional[datetime.datetime]:
        """
        Return the scrape timestamp stored for ``fingerprint``.

        :param fingerprint: The fingerprint to look up.
        :return: The ``scrape_timestamp`` of the first matching row, or
                 ``None`` if no row matches.
        """
        stmt = select(
            Response.scrape_timestamp
        ).where(
            Response.fingerprint == fingerprint
        )

        return self.session.scalars(stmt).first()

    def find_by_fingerprint(
            self,
            fingerprint: bytes,
            retrieval_method: RetrievalMethod = 'full'
    ) -> Optional[Response]:
        """
        Return the response stored for ``fingerprint``.

        :param fingerprint: The fingerprint to look up.
        :param retrieval_method: How much of the row to load:
               ``'minimal'`` loads only the body, ``'standard'`` loads the
               body, headers and status, and ``'full'`` loads the whole
               response together with its scrape history.
        :return: The matching response, or ``None`` if there is none.
        :raises ValueError: If ``retrieval_method`` is not one of the three
                recognized values.
        """
        if retrieval_method == 'minimal':
            return self._find_minimal(fingerprint)

        if retrieval_method == 'standard':
            return self._find_standard(fingerprint)

        if retrieval_method == 'full':
            return self._find_full(fingerprint)

        raise ValueError(f"Unknown retrieval method: {retrieval_method}")

    def upsert(self, response: Response, returning: bool = False) -> Response:
        """
        Insert ``response``, or update the stored row that has the same
        fingerprint.

        :param response: The response holding the values to store.
        :param returning: Only used for the SQLite and PostgreSQL dialects.
               If ``True`` the stored row is fetched back via ``RETURNING``
               and returned; otherwise ``response`` itself is returned.
        :return: The stored response.
        """
        # If the dialect is not postgresql or sqlite, we first need to
        # query for any existing items. If one exists we should perform
        # an update. Otherwise, we can use the upsert capabilities of the
        # specific dialects.
        if self.dialect in ('sqlite', 'postgresql'):
            return self._upsert_on_conflict(response, returning)

        return self._multi_query_upsert(response)

    # region Utility Methods
    def _find_minimal(self, fingerprint: bytes) -> Optional[Response]:
        """
        Find the response for ``fingerprint``, loading only its body.

        :param fingerprint: The fingerprint to look up.
        :return: The matching response, or ``None`` if there is none.
        """
        stmt = select(
            Response
        ).options(
            load_only(
                Response.body
            )
        ).where(
            Response.fingerprint == fingerprint
        )

        # This should be unique
        return self.session.scalars(stmt).one_or_none()

    def _find_standard(self, fingerprint: bytes) -> Optional[Response]:
        """
        Find the response for ``fingerprint``, loading only its body,
        headers and status.

        :param fingerprint: The fingerprint to look up.
        :return: The matching response, or ``None`` if there is none.
        """
        stmt = (
            select(
                Response
            ).options(
                load_only(
                    Response.body,
                    Response.headers,
                    Response.status
                )
            ).where(
                Response.fingerprint == fingerprint
            )
        )

        return self.session.scalars(stmt).one_or_none()

    def _find_full(self, fingerprint: bytes) -> Optional[Response]:
        """
        Find the response for ``fingerprint`` and load its scrape history
        along with it.

        :param fingerprint: The fingerprint to look up.
        :return: The matching response, or ``None`` if there is none.
        """
        stmt = select(
            Response
        ).options(
            selectinload(Response.scrape_history)
        ).where(
            Response.fingerprint == fingerprint,
        )

        return self.session.scalars(stmt).one_or_none()

    def _upsert_on_conflict(self, response: Response, returning: bool) -> Response:
        """
        Upsert ``response`` with a single ``INSERT ... ON CONFLICT DO UPDATE``
        statement keyed on the fingerprint.

        :param response: The response holding the values to store.
        :param returning: If ``True``, add ``RETURNING`` to the statement and
               return the row it yields; otherwise return ``response``.
        :return: The stored response.
        """
        columns = response.__table__.columns.keys()

        # Insert every column except the primary key. Attributes that are
        # missing on the instance are sent as ``None``.
        stmt = self.upsert_fn(
            Response
        ).values(
            **{
                c: getattr(response, c, None)
                for c in columns
                if c != 'id'
            }
        )

        # On a fingerprint conflict, overwrite everything except the primary
        # key and the fingerprint with the values that were to be inserted.
        update_stmt = stmt.on_conflict_do_update(
            index_elements=[Response.fingerprint],
            set_={
                c: stmt.excluded[c]
                for c in columns
                if c not in ('id', 'fingerprint')
            }
        )

        if returning:
            update_stmt = update_stmt.returning(Response)

            # ``populate_existing`` makes the returned values overwrite the
            # state of an instance that is already loaded in the session.
            result = self.session.scalars(
                update_stmt,
                execution_options={'populate_existing': True}
            )

            # The statement either inserts or updates, so it is expected to
            # yield exactly one row.
            return result.one_or_none()

        self.session.execute(update_stmt, execution_options={'populate_existing': True})

        return response

    def _multi_query_upsert(self, response: Response) -> Response:
        """
        Upsert ``response`` for dialects without ``ON CONFLICT`` support by
        querying for an existing row first.

        :param response: The response holding the values to store.
        :return: ``response`` itself when it was newly added to the session,
                 otherwise the already stored instance after its columns were
                 overwritten with the values of ``response``.
        """
        existing_response = self.find_by_fingerprint(response.fingerprint)

        if existing_response is None:
            self.session.add(response)
            return response

        # Copy every column except the primary key and the fingerprint onto
        # the instance that is already stored.
        update_columns = [
            c for c in response.__table__.columns.keys()
            if c not in ('id', 'fingerprint')
        ]

        for col in update_columns:
            new_value = getattr(response, col)
            setattr(existing_response, col, new_value)

        return existing_response
    # endregion Utility Methods
class ScrapeHistoryRepository(BaseRepository[ScrapeHistory]):
    """
    Repository for ``ScrapeHistory`` rows.

    It adds nothing beyond the operations inherited from ``BaseRepository``.
    """

    def __init__(self, session: Session):
        """
        :param session: The session used to execute every statement.
        """
        super().__init__(model=ScrapeHistory, session=session)