Source code for scrachy.middleware.selenium

#  Copyright 2023 Reid Swanson.
#
#  This file is part of scrachy.
#
#  scrachy is free software: you can redistribute it and/or modify
#  it under the terms of the GNU Lesser General Public License as published by
#  the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  scrachy is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU Lesser General Public License for more details.
#
#   You should have received a copy of the GNU Lesser General Public License
#   along with scrachy.  If not, see <https://www.gnu.org/licenses/>.

"""
Middleware for processing requests with Selenium.
"""

from __future__ import annotations

# Python Modules
import logging
import math
import os
import queue

from struct import pack, unpack
from sys import executable
from typing import Any, Optional

# 3rd Party Modules
import pickle

from scrapy import Spider
from scrapy import signals
from scrapy.crawler import Crawler
from scrapy.http import HtmlResponse, Request
from scrapy.settings import Settings
from selenium.webdriver.remote.webdriver import WebDriver
from twisted.internet import reactor
from twisted.internet.defer import Deferred
from twisted.internet.protocol import ProcessProtocol
from twisted.python import failure

# Project Modules
from scrachy import PROJECT_ROOT
from scrachy.http_ import SeleniumRequest
from scrachy.cli.webdriver_server import DEFAULT_BUFFER_SIZE, Message
from scrachy.settings.defaults.selenium import WebDriverName
from scrachy.utils.selenium import ShutdownRequest, initialize_driver
from scrachy.utils.selenium import process_request as process_request_helper

log = logging.getLogger(__name__)


[docs] class SeleniumMiddleware: """ A downloader middleware that uses a Selenium WebDriver to download the content and return an ``HtmlResponse`` if the incoming ``Response`` is an instance of :class:`~scrachy.http_.SeleniumRequest`. Otherwise, it returns ``None`` to let another downloader process it. """ webdriver_import_base = 'selenium.webdriver'
[docs] def __init__(self, settings: Settings, *args, **kwargs): super().__init__(*args, **kwargs) self.settings = settings self._driver = initialize_driver(self.driver_name, self.driver_options, self.driver_extensions)
# region Properties def get(self, name: str) -> Any: return self.settings.get(f'SCRACHY_SELENIUM_{name}') @property def driver_name(self) -> WebDriverName: return self.get('WEB_DRIVER') @property def driver_options(self) -> list[str]: return self.get('WEB_DRIVER_OPTIONS') @property def driver_extensions(self) -> list[str]: return self.get('WEB_DRIVER_EXTENSIONS') @property def driver(self) -> WebDriver: return self._driver # endregion Properties # region API @classmethod def from_crawler(cls, crawler: Crawler) -> SeleniumMiddleware: middleware = cls(crawler.settings) # See: https://docs.scrapy.org/en/latest/topics/signals.html crawler.signals.connect(middleware.spider_closed, signals.spider_closed) return middleware def process_request(self, request: Request, spider: Optional[Spider] = None) -> Optional[HtmlResponse]: return process_request_helper(self.driver, request) def spider_closed(self, spider: Optional[Spider] = None): self.driver.quit()
# endregion API
[docs] class AsyncSeleniumMiddleware: """ A downloader middleware that creates a pool of Selenium WebDrivers and sends any incoming :class:`SeleniumRequests <~scrachy.http_.SeleniumRequest>` to an available driver to be processed. """
[docs] def __init__(self, settings: Settings, *args, **kwargs): super().__init__(*args, **kwargs) self.settings = settings concurrent_requests: int = settings.getint('CONCURRENT_REQUESTS') log_file: str = settings.get('SCRACHY_SELENIUM_LOG_FILE') # Create a pool of drivers to increase the throughput. Since there # isn't actually any parallelism involved I don't think I have to # be all that careful with synchronization (e.g., locks). self.drivers = queue.Queue(maxsize=concurrent_requests) for driver in [WebDriverProtocol(i) for i in range(concurrent_requests)]: self.drivers.put(driver) args = ['python', '-m', 'scrachy.cli.webdriver_server'] args += ['-d', self.driver_name] args += [f'-o "{o}"' for o in self.driver_options] args += [f'-e "{e}"' for e in self.driver_extensions] if log_file: args += [f'-f "{log_file}"'] # noinspection PyUnresolvedReferences reactor.spawnProcess( driver, executable, args, path=PROJECT_ROOT, env=os.environ, )
# region Properties def get(self, name: str) -> Any: return self.settings.get(f'SCRACHY_SELENIUM_{name}') @property def driver_name(self) -> WebDriverName: return self.get('WEB_DRIVER') @property def driver_options(self) -> list[str]: return self.get('WEB_DRIVER_OPTIONS') @property def driver_extensions(self) -> list[str]: return self.get('WEB_DRIVER_EXTENSIONS') # endregion Properties # region API @classmethod def from_crawler(cls, crawler: Crawler) -> AsyncSeleniumMiddleware: middleware = cls(crawler.settings) # See: https://docs.scrapy.org/en/latest/topics/signals.html crawler.signals.connect(middleware.spider_closed, signals.spider_closed) return middleware def process_request(self, request: Request, spider: Optional[Spider] = None) -> Optional[Deferred[HtmlResponse]]: if not isinstance(request, SeleniumRequest): # Let some other downloader handle this request return None driver = self.drivers.get() d = driver.process_request(request) def enqueue_driver(r: HtmlResponse): self.drivers.put(driver) return r d.addCallback(enqueue_driver) return d def spider_closed(self, spider: Optional[Spider] = None): # Closing stdin should shut down the server while not self.drivers.empty(): driver = self.drivers.get(block=False) driver.shutdown()
# Uncommenting the following lines will allow any final messages # sent to stderr from the server just before exiting. # import time # from twisted.internet.threads import deferToThread # yield deferToThread(lambda: time.sleep(0.5)) # endregion API
[docs] class WebDriverProtocol(ProcessProtocol): # The number of bytes in the response message. response_header_size = 4
[docs] def __init__(self, id_: int, process_buffer_size: int = DEFAULT_BUFFER_SIZE): # An identifier for this process self.id = id_ # The size of the read buffer on the spawned process. We need to send # at lest this many bytes in order for the server's read buffer # to flush. Otherwise, the server will hang until it gets more data. self.process_buffer_size = process_buffer_size # Buffer to accumulate incoming messages self.buffer = b'' # The deferred object we will eventually return self.deferred_response: Optional[Deferred[HtmlResponse]] = None # This gets set once the shutdown message is sent and will be used # to prevent any further communication with the protocol. self.is_shutdown = False
# region Interface Methods
[docs] def connectionMade(self): log.debug(f"Connection made to: {self.id} with pid: {self.transport.pid}")
[docs] def outReceived(self, data: bytes): self.buffer += data self._extract_message()
[docs] def errReceived(self, data: bytes): log.error(f"Driver process error: {data.decode()}")
[docs] def inConnectionLost(self): log.debug(f"Lost stdin")
[docs] def outConnectionLost(self): log.debug(f"Lost stdout")
[docs] def errConnectionLost(self): log.debug(f"Lost stderr")
[docs] def processExited(self, reason: failure.Failure): log.info(f"Child process exited with exit code: {reason.value.exitCode}")
[docs] def processEnded(self, reason: failure.Failure): log.info(f"Child process ended: {reason.value.exitCode}")
# endregion Interface Methods def process_request(self, request: SeleniumRequest) -> Deferred[HtmlResponse]: if self.is_shutdown: raise ValueError("You cannot process requests after the server has been shut down.") # The original request has references to all sorts of unnecessary # and impossible to pickle objects. Just send over what we need. self._send_message( SeleniumRequest( url=request.url, wait_timeout=request.wait_timeout, wait_until=request.wait_until, screenshot=request.screenshot, script_executor=request.script_executor ) ) # We'll store the response here when it is ready. self.deferred_response = Deferred() return self.deferred_response def shutdown(self): self._send_message(ShutdownRequest()) self.transport.closeStdin() self.is_shutdown = True def _send_message(self, message: Message): message_data = pickle.dumps(message) # The number of bytes to encode the pickled data data_length = len(message_data) # The total number of bytes sent in the message (including the header # and padding) msg_length = self._get_message_length(data_length + 8) # The number of bytes to pad the message by. The sum of the header, # message, and padding should be an exact multiple of the process # buffer size. This is the difference between the total message length # and the data length and excluding the header. pad_length = (msg_length - data_length) - 8 data_field = pack('!I', data_length) msg_field = pack('!I', msg_length) self.transport.writeSequence([data_field, msg_field, message_data, b' ' * pad_length]) # noqa def _get_message_length(self, request_length: int) -> int: return self.process_buffer_size * math.ceil(request_length / self.process_buffer_size) def _extract_message(self): while len(self.buffer) >= self.response_header_size: msg_length = unpack('!I', self.buffer[:4])[0] if len(self.buffer) >= msg_length + 4: # Get the data from the buffer data = self.buffer[4:4+msg_length] # Remove the processed data from the buffer self.buffer = self.buffer[4+msg_length:] # Try to decode the message try: obj = pickle.loads(data) except pickle.PickleError as e: if self.deferred_response is not None: self.deferred_response.errback(e) else: log.error(f"There was a pickle error but the deferred response was not ready.") continue if self.deferred_response is None: log.error(f"Deferred response is not ready!") continue if not isinstance(obj, HtmlResponse): log.error(f"The message was not an HtmlResponse.") self.deferred_response.errback(obj) continue self.deferred_response.callback(obj) else: break # The message is not complete