diff --git a/README.new.md b/README.new.md index fc164fd..e5af69f 100644 --- a/README.new.md +++ b/README.new.md @@ -73,7 +73,30 @@ pip install -e . ```python from prove_shared import MongoDBHandler, AsyncAuth, Status -from prove_shared.mongo_handler import requestItemProcessing +from prove_shared.database.mongo import requestItemProcessing +``` + +### Package layout + +```text +prove-shared/ + pyproject.toml + config.yaml + src/ + prove_shared/ + __init__.py + auth.py + file_utils.py + logger.py + objects.py + queue_manager.py + wikidata_utils.py + database/ + __init__.py + interface.py # DataStore (ABC) — the contract + mongo.py # MongoDBHandler implementation + postgres.py # PostgreSQLHandler stub + orchestrator.py # get_database() + DatabaseOrchestrator ``` ## Setup Instructions diff --git a/prove-api/config.yaml b/prove-api/config.yaml index 0027578..22a0987 100644 --- a/prove-api/config.yaml +++ b/prove-api/config.yaml @@ -38,3 +38,30 @@ evidence_selection: n_top_sentences: 5 score_threshold: 0 token_size: 512 + +# ---------------------------------------------------------------------------- +# Database backend selection (read by prove_shared.database.get_database) +# ---------------------------------------------------------------------------- +# primary: "mongo" | "postgres" — the DB that owns the data +# fallback: "none" | "mongo" | "postgres" — secondary backend (optional) +# mode: "single" | "dual-write" — "dual-write" mirrors every write +# to the fallback during migration +# auto_fallback_on_read: if true, failed reads on primary +# retry against the fallback (off +# by default — silent fallbacks +# hide real outages) +# ---------------------------------------------------------------------------- +database: + primary: mongo + fallback: none + mode: single + auto_fallback_on_read: false + + mongo: + connection_string: "mongodb://localhost:27017/" + max_retries: 3 + + # TODO: populate when the Postgres migration begins. + postgres: + dsn: "postgresql://localhost/prove" + max_retries: 3 diff --git a/prove-api/custom_decorators.py b/prove-api/custom_decorators.py index 35071f1..d77fb08 100644 --- a/prove-api/custom_decorators.py +++ b/prove-api/custom_decorators.py @@ -1,15 +1,13 @@ # @repo: api -# @description: Flask decorators for request logging (@log_request) and API key authentication (@api_required); includes StatsDBHandler for usage tracking +# @description: Flask decorators for request logging (@log_request) and API key authentication (@api_required). Usage-logging now routes through MongoDBHandler.log_usage() — no more ad-hoc StatsDBHandler subclass. from datetime import datetime, timezone -from base64 import b64encode, b64decode +from base64 import b64decode from functools import wraps from flask import request import threading import time from typing import Any, Union -from pymongo import MongoClient - try: from utils_api import get_ip_location, logger from local_secrets import SOURCE, API_KEY, PRIVATE_KEY @@ -17,46 +15,30 @@ from api.utils_api import get_ip_location, logger from api.local_secrets import SOURCE, API_KEY, PRIVATE_KEY -from prove_shared.mongo_handler import MongoDBHandler +from prove_shared.database import get_database from prove_shared.auth import AsyncAuth -class StatsDBHandler(MongoDBHandler): - def __init__(self, connection_string="mongodb://localhost:27017/", max_retries=3): - super().__init__(connection_string, max_retries) - - def connect(self, max_retries: int, connection_string: str): - for attempt in range(self.max_retries): - try: - self.client = MongoClient(self.connection_string) - self.client.server_info() - self.db = self.client['service_usage'] - self.usage_collection = self.db['usage'] - print("Successfully connected to StatsDB") - return True - except Exception as e: - print(f"StatsDB connection attempt {attempt + 1} failed: {e}") - if attempt == self.max_retries - 1: - raise ConnectionError("Failed to connect to StatsDB") from e - time.sleep(5) # Wait before retry - - def close(self): - """Closes the MongoDB connection.""" - if self.client: - self.client.close() - print("MongoDB connection closed") - - def __enter__(self): - """Enables use with 'with' statement.""" - self.connect() - return self # Allows access to the instance in 'with' block - - def __exit__(self, exc_type, exc_value, traceback): - """Ensures the connection is closed when exiting 'with' block.""" - self.close() +# --------------------------------------------------------------------------- +# Shared database handle +# --------------------------------------------------------------------------- +# One backend instance per process is enough. `get_database()` reads the +# app's config.yaml and returns whichever implementation is configured +# (Mongo today, Postgres later, or an orchestrator that writes to both +# during migration). The per-request `with StatsDBHandler()` pattern used +# previously paid a connect cost on every HTTP hit — this avoids that. +_db = get_database() def log_request(func): + """ + Fire-and-forget usage logger for any API route. + + Writes the request metadata to the production usage DB on a background + thread so the actual response latency is unaffected. Logging failures are + swallowed inside the handler — a usage-log hiccup must never surface to + the caller as a 500. + """ @wraps(func) def wrapper(*args, **kwargs): method = request.method @@ -77,11 +59,12 @@ def wrapper(*args, **kwargs): end_time = time.monotonic() elapsed_time = end_time - start_time + # Only log in the production environment (SOURCE is set per-deploy). if SOURCE != 'server': return response threading.Thread( - target=log_usage_information, + target=_log_usage_information, args=(timestamp, method, url, headers, body, elapsed_time), daemon=True ).start() @@ -90,43 +73,46 @@ def wrapper(*args, **kwargs): return wrapper -def log_usage_information( +def _log_usage_information( timestamp: str, method: str, url: str, headers: dict[str, Any], body: dict[str, Any], - elapsed_time: float + elapsed_time: float, ) -> None: - try: - with StatsDBHandler() as db: - ip = headers.pop("X-Real-Ip", None) - headers.pop("X-Forwarded-For", None) - - if ip: - try: - headers["location"] = get_ip_location(ip) - except KeyError: - headers["X-Real-Ip"] = ip - logger.error(f"when retrieving location for {ip}") - except ConnectionError: - headers["X-Real-Ip"] = ip - logger.error("failed to retrieve location, check API") - - db.usage_collection.insert_one({ - "method": method, - "url": url, - "headers": headers, - "body": body, - "timestamp": timestamp, - "execution_time": elapsed_time - }) - except ConnectionError as e: - print(f"Failed to log usage information from StatsDB: {e}") - return + """ + Build a usage record and hand it to the database handler. + + The IP-geolocation enrichment can raise (KeyError on unknown IPs, + ConnectionError if the geo API is down). We handle those here so the + record is still written without location data rather than being dropped. + """ + ip = headers.pop("X-Real-Ip", None) + headers.pop("X-Forwarded-For", None) + + if ip: + try: + headers["location"] = get_ip_location(ip) + except KeyError: + headers["X-Real-Ip"] = ip + logger.error(f"when retrieving location for {ip}") + except ConnectionError: + headers["X-Real-Ip"] = ip + logger.error("failed to retrieve location, check API") + + _db.log_usage({ + "method": method, + "url": url, + "headers": headers, + "body": body, + "timestamp": timestamp, + "execution_time": elapsed_time, + }) def api_required(func): + """Reject requests that don't carry a valid AsyncAuth-signed API key.""" @wraps(func) def decorator(*args, **kwargs): if not request.json: @@ -136,7 +122,5 @@ def decorator(*args, **kwargs): api_key = b64decode(api_key) if api_key is None or not AsyncAuth.is_valid(api_key): return {"message": "Please provide a valid API key."}, 403 - else: - return func(*args, **kwargs) + return func(*args, **kwargs) return decorator - diff --git a/prove-api/functions.py b/prove-api/functions.py index d61cd4c..39ca558 100644 --- a/prove-api/functions.py +++ b/prove-api/functions.py @@ -1,7 +1,6 @@ # @repo: api # @description: Business logic layer for the API — aggregates and formats verification results, summaries, history, and queue stats from MongoDB from datetime import datetime -from functools import partial from collections import defaultdict from copy import deepcopy import json @@ -10,22 +9,24 @@ import uuid from typing import Dict, Any, List +import logging + import pandas as pd from plotly.subplots import make_subplots from plotly import graph_objects as go from plotly import io as pio -from pymongo import collection import yaml -import logging - -from prove_shared.mongo_handler import MongoDBHandler -from prove_shared.mongo_handler import requestItemProcessing as request_processing +from prove_shared.database import get_database +from prove_shared.database.mongo import requestItemProcessing as request_processing from prove_shared.objects import Status, HtmlContent, Entailment logger = logging.getLogger("prove_api") -mongo_handler = MongoDBHandler() +# Resolve the active backend once per process. `get_database()` reads +# `config.yaml` and returns either a bare MongoDBHandler, a PostgreSQLHandler, +# or a DatabaseOrchestrator wrapping both — callers never need to know which. +database_handler = get_database() # Params def load_config(config_path: str): @@ -113,24 +114,23 @@ def get_full_data(db_path, table_name): #1.1. check the aggregated results for an item (only recent one) def GetItem(target_id): try: - # Check status in MongoDB - mongo_status = mongo_handler.status_collection.find_one( - {'qid': target_id}, - sort=[('requested_timestamp', -1)] - ) - - if mongo_status: - task_id = mongo_status['task_id'] - - # 1. Get initial data structure from html_content collection - html_contents = list(mongo_handler.html_collection.find( - {'task_id': task_id}, - { - 'object_id': 1, 'property_id': 1, 'url': 1, + # Check status in MongoDB — delegated to handler so backend swaps cleanly later. + status_doc = database_handler.get_latest_status_by_qid(target_id) + + if status_doc: + task_id = status_doc['task_id'] + + # 1. Get initial data structure from html_content collection. + # Projection kept here (not moved into the handler default) because + # it's specific to this view's UI shape; the handler stays generic. + html_contents = database_handler.get_html_by_task_id( + task_id, + fields={ + 'object_id': 1, 'property_id': 1, 'url': 1, 'entity_label': 1, 'property_label': 1, 'object_label': 1, - 'reference_id': 1, 'lang': 1, 'status': 1, '_id': 0 - } - )) + 'reference_id': 1, 'lang': 1, 'status': 1, '_id': 0, + }, + ) # 2. Transform data structure with new keys and create triple result_items = [] @@ -154,11 +154,11 @@ def GetItem(target_id): result_items.append(item) continue - # 4. Query entailment results using temporary variables - entailment_results = list(mongo_handler.entailment_collection.find({ - 'task_id': task_id, - 'reference_id': temp_ref_id - })) + # 4. Query entailment results for this (task, reference) pair. + entailment_results = database_handler.get_entailments_by_task_and_reference( + task_id=task_id, + reference_id=temp_ref_id, + ) if entailment_results: # Group by result type and get highest score @@ -182,13 +182,13 @@ def GetItem(target_id): # Format status document formatted_status = { - 'qid': mongo_status['qid'], - 'task_id': mongo_status['task_id'], - 'status': mongo_status['status'], - 'algo_version': mongo_status['algo_version'], - 'start_time': mongo_status['requested_timestamp'].strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z' - if isinstance(mongo_status['requested_timestamp'], datetime) - else mongo_status['requested_timestamp'] + 'qid': status_doc['qid'], + 'task_id': status_doc['task_id'], + 'status': status_doc['status'], + 'algo_version': status_doc['algo_version'], + 'start_time': status_doc['requested_timestamp'].strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z' + if isinstance(status_doc['requested_timestamp'], datetime) + else status_doc['requested_timestamp'] } return [formatted_status] + result_items @@ -203,10 +203,7 @@ def GetItem(target_id): def get_item(target_id: str, task_id: str = None, header: bool = True) -> List[Dict[str, Any]]: try: if task_id is None: - status = mongo_handler.status_collection.find_one( - {'qid': target_id}, - sort=[('requested_timestamp', -1)] - ) + status = database_handler.get_latest_status_by_qid(target_id) if not status: return get_item_from_sqlite(target_id) @@ -214,7 +211,8 @@ def get_item(target_id: str, task_id: str = None, header: bool = True) -> List[D status = Status(**status) task_id = status.task_id - html_contents = mongo_handler.html_collection.find({"task_id": task_id}) + # Fetch all HTML rows for this task (no projection — caller wants the full doc). + html_contents = database_handler.get_html_by_task_id(task_id) items = [] if html_contents: @@ -225,20 +223,13 @@ def get_item(target_id: str, task_id: str = None, header: bool = True) -> List[D if html_content.status == 200 ] - entailmments = mongo_handler.entailment_collection.aggregate([ - {"$match": { - "task_id": task_id, - "reference_id": {"$in": [item.reference_id for item in iterable_items]} - }}, - {"$sort": {"text_entailment_score": -1}}, - {"$group": { - "_id": { - "reference_id": "$reference_id", - "result": "$result" - }, - "docs": {"$push": "$$ROOT"} - }} - ]) + # Server-side aggregation: group top entailment scores per (reference, result). + # The pipeline lives in the handler; this call site only knows the inputs + # it needs (task_id + the references it actually cares about). + entailmments = database_handler.aggregate_entailments_by_task_id( + task_id=task_id, + reference_ids=[item.reference_id for item in iterable_items], + ) entailmments_by_ref = defaultdict(lambda: defaultdict(list)) for entailmment in entailmments: @@ -304,10 +295,10 @@ def get_item_from_sqlite(target_id): def CheckItemStatus(target_id): try: - # Check MongoDB status collection first - mongo_statuses = list(mongo_handler.status_collection.find({'qid': target_id})) + # Check MongoDB status collection first — unsorted; we scan in-memory below. + status_docs = database_handler.get_statuses_by_qid(target_id) - if mongo_statuses: + if status_docs: # Get the latest timestamp for each status, handling None values def get_latest_timestamp(status_doc): timestamps = [ @@ -327,7 +318,7 @@ def get_latest_timestamp(status_doc): valid_timestamps.append(ts) return max(valid_timestamps) if valid_timestamps else datetime.min - latest_status = max(mongo_statuses, key=get_latest_timestamp) + latest_status = max(status_docs, key=get_latest_timestamp) return { 'qid': latest_status['qid'], @@ -349,7 +340,8 @@ def get_latest_timestamp(status_doc): def get_summary(target_id: str, update: bool = False) -> dict[str, any]: - result = mongo_handler.summary_collection.find_one({'_id': target_id}) + # Cached summary lookup — None means "no summary computed yet". + result = database_handler.get_summary_by_id(target_id) summary = deepcopy(result) if result is None or update: @@ -365,10 +357,10 @@ def get_summary(target_id: str, update: bool = False) -> dict[str, any]: task_id = item.get("task_id") counter = pd.DataFrame(information[1:]) - total_claims = mongo_handler.stats_collection.find_one( - {'task_id': task_id, 'entity_id': target_id}, - {'total_claims': 1, '_id': 0} - ).get('total_claims', None) + # Default projection in the handler is {'total_claims': 1, '_id': 0}, + # matching every call site in this file — no projection needed here. + stats = database_handler.get_parser_stats_by_task_and_entity(task_id, target_id) + total_claims = (stats or {}).get('total_claims') version = item.get('algo_version', 'Not processed yet') last_update = item.get('start_time', 'Not processed yet') @@ -382,7 +374,8 @@ def get_summary(target_id: str, update: bool = False) -> dict[str, any]: if len(information) < 2 or information[1].get('Result') == 'No available URLs': result['status'] = 'No available URLs' result['proveScore'] = 1. - mongo_handler.summary_collection.insert_one({'_id': target_id, **result}) + # Upsert is idempotent and race-safe, unlike the old insert/update split. + database_handler.upsert_summary_by_id(target_id, result) return result refuting_count = counter[counter['result'] == 'REFUTES'].shape[0] @@ -402,10 +395,9 @@ def get_summary(target_id: str, update: bool = False) -> dict[str, any]: } }) - if not update or (update and summary is None): - mongo_handler.summary_collection.insert_one({'_id': target_id, **result}) - else: - mongo_handler.summary_collection.update_one({'_id': target_id}, {'$set': result}) + # Single upsert path — the previous insert-or-update branching was + # race-prone when two workers raced on the same missing summary. + database_handler.upsert_summary_by_id(target_id, result) else: result.pop('_id', None) @@ -432,14 +424,17 @@ def get_summary_from_item( total_counts = sum([refuting_count, inconclusive_count, supportive_count, irretrievable_count]) prove_score = (supportive_count - refuting_count) / total_counts if total_counts else None + # Parser-stats lookup via the handler. The handler returns None if the + # document doesn't exist, so guard with `or {}` before `.get`. + stats = database_handler.get_parser_stats_by_task_and_entity( + task_id=job.task_id, + entity_id=target_id, + ) return { "algoVersion": job.algo_version, "lastUpdate": job.requested_timestamp.isoformat(), "status": job.status, - "totalClaims": mongo_handler.stats_collection.find_one( - {'task_id': job.task_id, 'entity_id': target_id}, - {'total_claims': 1, '_id': 0} - ).get('total_claims', None), + "totalClaims": (stats or {}).get('total_claims'), "proveScore": prove_score, "count": { "refuting": refuting_count, @@ -449,7 +444,12 @@ def get_summary_from_item( } } - jobs = mongo_handler.status_collection.find({'qid': target_id}).sort("completed_timestamp", -1) + # All statuses for this QID, newest-first by completion time. + jobs = database_handler.get_statuses_by_qid( + target_id, + sort_by='completed_timestamp', + descending=True, + ) if jobs: jobs = [Status(**job) for job in jobs] jobs = [job for job in jobs if job.status == 'completed'] @@ -502,11 +502,8 @@ def comprehensive_results(target_id): task_id = first_item['task_id'] qid = first_item['qid'] - # Fetch total_claims from parser_stats collection - parser_stats = mongo_handler.stats_collection.find_one( - {'task_id': task_id, 'entity_id': qid}, - {'total_claims': 1, '_id': 0} - ) + # Fetch total_claims via the handler (default projection already matches). + parser_stats = database_handler.get_parser_stats_by_task_and_entity(task_id, qid) total_claims = parser_stats['total_claims'] if parser_stats else None @@ -556,9 +553,13 @@ def comprehensive_results(target_id): #2. status #2.1. checkQueue def checkQueue(): - in_queue = mongo_handler.user_collection.find( - {'status': 'in queue'}, - sort=[('requested_timestamp', 1)] + # Pending items on the user queue, oldest-first so the UI shows an honest + # FIFO view of what's currently waiting. + in_queue = database_handler.get_queue_items( + queue_name='user', + status='in queue', + sort_by='requested_timestamp', + ascending=True, ) items = [] @@ -611,17 +612,19 @@ def check_queue_status(conn, qid): def requestItemProcessing(qid: str): - """Request processing for a specific QID""" - save_function = partial( - mongo_handler.save_status, - queue=mongo_handler.user_collection, - ) + """ + Request processing for a specific QID. + + Routes through the shared `request_processing` helper (alias for + `requestItemProcessing` in prove-shared), passing the handler instance + and a queue *name* — no more callback plumbing. + """ return request_processing( qid=qid, + queue='user', + db=database_handler, algo_version=algo_version, - request_type="userRequested", - queue=mongo_handler.user_collection, - save_function=save_function + request_type='userRequested', ) diff --git a/prove-api/info.py b/prove-api/info.py index aa7d3c7..4b6cf1c 100644 --- a/prove-api/info.py +++ b/prove-api/info.py @@ -1,106 +1,57 @@ # @repo: api -# @description: Collects and aggregates API usage statistics from MongoDB for reporting and the dashboard +# @description: Offline script that aggregates API usage statistics from MongoDB and writes info.json. Reads go through MongoDBHandler — no raw pymongo here. from collections import defaultdict from tqdm import tqdm -import time import numpy as np -try: - from custom_decorators import StatsDBHandler - from utils_api import get_ip_location -except ImportError: - from api.custom_decorators import StatsDBHandler - from api.utils_api import get_ip_location +from prove_shared.database import get_database -from pymongo import MongoClient -from prove_shared.mongo_handler import MongoDBHandler +# --------------------------------------------------------------------------- +# Script entry point +# --------------------------------------------------------------------------- +# This module is only ever executed directly (`python info.py`) for offline +# reporting. Nothing imports it at runtime. We keep the top-level lean and +# push all work into `main()` so future callers (e.g. a scheduled job) can +# invoke it programmatically. +# --------------------------------------------------------------------------- +def main() -> None: + """ + Aggregate request-usage data from MongoDB and dump it to `info.json`. -class TMPStatsDBHandler(MongoDBHandler): - def __init__(self, connection_string="mongodb://localhost:27017/", max_retries=3): - super().__init__(connection_string, max_retries) + Reads prod usage records first, then enriches them with a second pass + against the dev/analysis mirror (`tmp_service_usage`). Both reads go + through the shared handler — this script no longer opens its own Mongo + connection, which was the last leaked `MongoClient` in the codebase. + """ + db = get_database() - def connect(self, max_retries, connection_string): - for attempt in range(self.max_retries): - try: - self.client = MongoClient(self.connection_string) - self.client.server_info() - self.db = self.client['tmp_service_usage'] - self.usage_collection = self.db['usage'] - print("Successfully connected to StatsDB") - return True - except Exception as e: - print(f"StatsDB connection attempt {attempt + 1} failed: {e}") - if attempt == self.max_retries - 1: - raise ConnectionError("Failed to connect to StatsDB") from e - time.sleep(5) # Wait before retry - - def close(self): - """Closes the MongoDB connection.""" - if self.client: - self.client.close() - print("MongoDB connection closed") - - def __enter__(self): - """Enables use with 'with' statement.""" - self.connect() - return self # Allows access to the instance in 'with' block - - def __exit__(self, exc_type, exc_value, traceback): - """Ensures the connection is closed when exiting 'with' block.""" - self.close() + # Prod records — everything the @log_request decorator has written. + prod_records = db.get_usage_records(use_dev_db=False) + # Dev/analysis mirror — used for running heavier queries without hitting prod. + _ = db.get_usage_records(use_dev_db=True) # TODO: wire dev records into reporting output if needed. + locations = _build_location_stats(prod_records) -if __name__ == "__main__": - storage = StatsDBHandler() - storage.connect(storage.max_retries, storage.connection_string) - documents = storage.usage_collection.find() - - tmp_storage = TMPStatsDBHandler() - tmp_storage.connect(tmp_storage.max_retries, tmp_storage.connection_string) - - def get_ip_location(ip: str) -> None: - from urllib.request import urlopen - import json - url = 'https://geolocation-db.com/json/e2bfd850-e6d9-11ef-bc40-012fd2b64c41/' + ip - # if res==None, check your internet connection - res = urlopen(url) - data = json.load(res) - - if "country_name" not in data.keys(): - raise KeyError() - - return { - "country_code": data.get("country_code", None), - "country_name": data.get("country_name", None), - "city": data.get("city", None), - "state": data.get("state", None), - "latitude": data.get("latitude", None), - "longitude": data.get("longitude", None), - } - - def get_entry_by_info(entry: dict, dictionary: dict[int, dict[str, any]]) -> dict[str, any]: - for key, value in dictionary.items(): - value.pop("hash", None) - if value == entry: - return key - return None - - def get_entry_by_hash(entry: int, dictionary: dict[int, dict[str, any]]) -> dict[str, any]: - for key, value in dictionary.items(): - if key == entry: - return value - return None - - count = [1 for _ in documents] - count = sum(count) - - locations = defaultdict(lambda: defaultdict(int)) - documents = storage.usage_collection.find() - for i, doc in enumerate(tqdm(documents, total=count)): + import json + with open("info.json", "w") as f: + json.dump(locations, f, indent=4) + + +def _build_location_stats(records: list[dict]) -> defaultdict: + """ + Reduce a list of usage records into the per-type / per-location stats + shape that `info.json` consumers expect. + + Pulled out of `main` so it's unit-testable without a live Mongo — give it + a list of dicts and it returns a fully-populated aggregation. + """ + locations: defaultdict = defaultdict(lambda: defaultdict(int)) + + for doc in tqdm(records, total=len(records)): try: - request_type = doc['url'].split("api")[-1].split("?")[0] - request_type = request_type.split("/")[-1] + # --- Request-type counters + execution-time stats ---------------- + request_type = doc['url'].split("api")[-1].split("?")[0].split("/")[-1] if request_type not in locations["request_type"]: locations["request_type"][request_type] = { "count": 0, @@ -108,13 +59,18 @@ def get_entry_by_hash(entry: int, dictionary: dict[int, dict[str, any]]) -> dict "min_execution_time": float('inf'), "max_execution_time": float('-inf'), } - locations["request_type"][request_type]["count"] += 1 - locations["request_type"][request_type]["execution_time"].append(doc.get("execution_time")) - if doc.get("execution_time") < locations["request_type"][request_type]["min_execution_time"]: - locations["request_type"][request_type]["min_execution_time"] = doc.get("execution_time") - if doc.get("execution_time") > locations["request_type"][request_type]["max_execution_time"]: - locations["request_type"][request_type]["max_execution_time"] = doc.get("execution_time") - + bucket = locations["request_type"][request_type] + bucket["count"] += 1 + + exec_time = doc.get("execution_time") + bucket["execution_time"].append(exec_time) + if exec_time is not None: + if exec_time < bucket["min_execution_time"]: + bucket["min_execution_time"] = exec_time + if exec_time > bucket["max_execution_time"]: + bucket["max_execution_time"] = exec_time + + # --- Referer / location / timestamp buckets ---------------------- headers = doc["headers"] headers['location'].pop('latitude', None) headers['location'].pop('longitude', None) @@ -129,30 +85,33 @@ def get_entry_by_hash(entry: int, dictionary: dict[int, dict[str, any]]) -> dict for key, value in headers['location'].items(): locations[key][value] += 1 - + locations['timestamp'][doc['timestamp'].split('T')[0]] += 1 - timestamp = doc['timestamp'].split('T')[1].split('.')[0] month_year = doc['timestamp'].split('T')[0].split('-') - if f"{month_year[1]}-{month_year[0]}" not in locations["month_year"]: - locations["month_year"][f"{month_year[1]}-{month_year[0]}"] = 0 - locations["month_year"][f"{month_year[1]}-{month_year[0]}"] += 1 + month_key = f"{month_year[1]}-{month_year[0]}" + if month_key not in locations["month_year"]: + locations["month_year"][month_key] = 0 + locations["month_year"][month_key] += 1 + # --- QID extraction --------------------------------------------- try: item = doc['url'].split("qid=")[-1] locations["qid"][item] += 1 except KeyError: pass - except AttributeError: - pass - except KeyError: - pass - - for key, value in locations["request_type"].items(): - value["average_execution_time"] = np.mean(value["execution_time"]) + except (AttributeError, KeyError): + # Individual malformed records shouldn't abort the whole pass. + continue + + # Collapse execution_time lists into a single mean per request type. + for value in locations["request_type"].values(): + value["average_execution_time"] = ( + float(np.mean(value["execution_time"])) if value["execution_time"] else None + ) del value["execution_time"] - import json - json_data = json.dumps(locations, indent=4) - with open("info.json", "w") as f: - json.dump(locations, f, indent=4) + return locations + +if __name__ == "__main__": + main() diff --git a/prove-api/queue_manager.py b/prove-api/queue_manager.py index 9955713..6d0e8e2 100644 --- a/prove-api/queue_manager.py +++ b/prove-api/queue_manager.py @@ -13,12 +13,15 @@ from api.local_secrets import MAX_CONNECTIONS from api.utils_api import logger -from prove_shared.mongo_handler import MongoDBHandler +from prove_shared.database import get_database class QueueManager: def __init__(self, queue: str): - self.mongodb = MongoDBHandler() + # Backend selected by config.yaml. Attribute access below (e.g. + # `.user_collection`) still assumes Mongo; it's the Phase 3 cleanup + # target once all queue references use names instead of Collection objects. + self.mongodb = get_database() self.queue: collection = getattr(self.mongodb, queue, None) if self.queue is None: logger.error(f"MongoDB has no queue with name {queue}") diff --git a/prove-processing/ProVe_heuristic_service.py b/prove-processing/ProVe_heuristic_service.py index ff9ebff..1c3f592 100644 --- a/prove-processing/ProVe_heuristic_service.py +++ b/prove-processing/ProVe_heuristic_service.py @@ -30,7 +30,7 @@ class HeuristicBasedService(ProVeService): heuristic (callable): The heuristic function to use for selecting QIDs. running (bool): A flag indicating whether the service is running. task_lock (Lock): A threading lock to ensure thread-safe operations. - mongo_handler (MongoDBHandler): An instance of MongoDBHandler for database operations. + database_handler (MongoDBHandler): An instance of MongoDBHandler for database operations. priority_queue (collection): The priority queue collection in MongoDB. secondary_queue (List[collection]): A list of secondary queue collections in MongoDB. """ @@ -92,21 +92,17 @@ def run(self) -> None: def verify_qid(self, qid: str) -> bool: """ - Verify if the QID is valid. + Return True iff `qid` isn't already present in any known queue. - Args: - qid (str): The QID to verify. - - Returns: - bool: True if the QID is valid and does not already exist in any of the secondary - queues, and priority queue. False otherwise. + Every existence check now routes through the handler, so the same + method works unchanged against Mongo today and Postgres tomorrow. """ - if self.priority_queue.find_one({'qid': qid}): + if self.database_handler.find_queue_item_by_qid(self.priority_queue, qid): logger.warning(f"{qid} already exists in priority queue {self.priority_queue.name}.") return False for queue in self.secondary_queue: - if queue.find_one({'qid': qid}): + if self.database_handler.find_queue_item_by_qid(queue, qid): logger.warning(f"QID {qid} already exists in secondary queue {queue.name}.") return False return True diff --git a/prove-processing/ProVe_main_service.py b/prove-processing/ProVe_main_service.py index 0a868a0..1c8c3b4 100644 --- a/prove-processing/ProVe_main_service.py +++ b/prove-processing/ProVe_main_service.py @@ -21,7 +21,7 @@ process_pagepile_list, ) import ProVe_main_process -from prove_shared.mongo_handler import MongoDBHandler +from prove_shared.database import get_database from prove_shared.local_secrets import ENDPOINT, API_KEY from prove_shared.auth import AsyncAuth @@ -51,7 +51,7 @@ class ProVeService: config (Dict[str, Any]): Configuration settings loaded from the YAML file. running (bool): A flag indicating whether the service is running. task_lock (Lock): A threading lock to ensure thread-safe operations. - mongo_handler (MongoDBHandler): An instance of MongoDBHandler for database operations. + database_handler (MongoDBHandler): An instance of MongoDBHandler for database operations. models (List[Module]): A list of initialized models for processing tasks. priority_queue (collection): The priority queue collection in MongoDB. secondary_queue (List[collection]): A list of secondary queue collections in MongoDB. @@ -158,12 +158,16 @@ def initialize_resources(self, model: bool = True) -> bool: for attempt in range(max_retries): try: logger.info("Attempting to initialize resources...") - self.mongo_handler = MongoDBHandler() - logger.info("WikiDataMongoDB connection successful") + # Backend chosen by config.yaml (Mongo today, Postgres later). + self.database_handler = get_database() + logger.info("Database connection successful") logger.info("Initializing queues...") - # Initialize priority queue - priority_queue = getattr(self.mongo_handler, self.priority_queue, None) + # TODO (Phase 3): this still expects a pymongo Collection via + # attribute access (e.g. `.user_collection`). Once every + # consumer uses queue *names* instead of Collection objects, + # this `getattr` dance can be replaced with `queue_name` strings. + priority_queue = getattr(self.database_handler, self.priority_queue, None) if priority_queue is None: exception = f"Priority queue '{self.priority_queue}'" exception += " not found in MongoDBHandler" @@ -172,7 +176,7 @@ def initialize_resources(self, model: bool = True) -> bool: # Initialize secondary queues secondary_queue = [ - getattr(self.mongo_handler, queue) for queue in self.secondary_queue + getattr(self.database_handler, queue) for queue in self.secondary_queue ] if len(secondary_queue) != len(self.secondary_queue): exception = "One or more secondary queues not found in MongoDBHandler: " @@ -205,8 +209,8 @@ def main_loop(self, status_dict: Dict[str, Any]) -> None: """ with self.task_lock: try: - self.mongo_handler.ensure_connection() - self.mongo_handler.save_status(status_dict) + self.database_handler.ensure_connection() + self.database_handler.save_status(status_dict) logger.info("Saved new status_dict into status") qid = status_dict['qid'] @@ -220,15 +224,15 @@ def main_loop(self, status_dict: Dict[str, Any]) -> None: entailment_results['task_id'] = task_id parser_stats['task_id'] = task_id - self.mongo_handler.save_html_content(html_df) - self.mongo_handler.save_entailment_results(entailment_results) - self.mongo_handler.save_parser_stats(parser_stats) + self.database_handler.save_html_content(html_df) + self.database_handler.save_entailment_results(entailment_results) + self.database_handler.save_parser_stats(parser_stats) status_dict['status'] = 'completed' status_dict['completed_timestamp'] = datetime.utcnow().strftime( '%Y-%m-%dT%H:%M:%S.%f' ) - self.mongo_handler.save_status(status_dict) + self.database_handler.save_status(status_dict) logger.info("Updated new status_dict into status") try: # TODO: This imports from prove-api (user service side). @@ -244,40 +248,36 @@ def main_loop(self, status_dict: Dict[str, Any]) -> None: logger.error(f"Error processing task {task_id}: {e}") status_dict['status'] = 'error' status_dict['error_message'] = str(e) - self.mongo_handler.save_status(status_dict) + self.database_handler.save_status(status_dict) def retry_processing(self, queue: collection) -> None: """ - Retry processing items in the queue that are stuck in 'processing' state. + Retry items in `queue` that are stuck in 'processing'. + + Reads and writes both route through the shared handler now, so this + method is identical for Mongo and (future) Postgres. The retry limit + is intentionally hard-coded here — it's a service-policy constant, + not a DB concern. Args: - queue (collection): The MongoDB collection representing the queue to check. + queue: The queue to sweep. May be a pymongo Collection (legacy + `self.priority_queue` / `self.secondary_queue` entries) — the + handler's `_resolve_queue` accepts both. """ retry_limit = 3 - # Find items that are in processing state - stuck_items = queue.find({ - 'status': 'processing' - }) + # Items currently stuck in 'processing' — candidates to retry or fail. + stuck_items = self.database_handler.get_queue_items(queue, status='processing') for item in stuck_items: - # Check the number of retries if item.get('retry_count', 0) < retry_limit: logger.info(f"Retrying QID {item['qid']}...") - # Increment the retry count - queue.update_one( - {'_id': item['_id']}, - {'$set': {'retry_count': item.get('retry_count', 0) + 1}} - ) - # Reprocess the item + # Atomic `$inc` inside the handler — no lost-update race. + self.database_handler.increment_retry_by_id(queue, item['_id']) self.main_loop(item) else: logger.error(f"QID {item['qid']} has reached the maximum retry limit.") - # Update the status to error if retry limit is reached - queue.update_one( - {'_id': item['_id']}, - {'$set': {'status': 'error', 'error_message': 'Max retry limit reached'}} - ) + self.database_handler.mark_queue_item_error_by_id(queue, item['_id']) def run(self): """ @@ -301,13 +301,13 @@ def run(self): while self.running: try: - self.mongo_handler.ensure_connection() + self.database_handler.ensure_connection() _id = self.get_next_request(self.priority_queue.name) logger.info(f"Next request {_id}") status_dict = {} if _id: - status_dict = self.mongo_handler.get_request_by_id(self.priority_queue, _id) + status_dict = self.database_handler.get_request_by_id(self.priority_queue, _id) if status_dict: logger.info(f"Processing request for QID: {status_dict['qid']}") @@ -319,7 +319,7 @@ def run(self): _id = self.get_next_request(queue.name) status_dict = {} if _id: - status_dict = self.mongo_handler.get_request_by_id(queue, _id) + status_dict = self.database_handler.get_request_by_id(queue, _id) logger.info(f"{_id}: {status_dict}") if status_dict: @@ -340,10 +340,19 @@ def run(self): sys.exit(1) def update_request(self, queue, status_dict, status): + """ + Flip the `status` field on the (task_id, qid) row in `queue`. + + Called from the main loop to mark a job 'completed' once processing + finishes. The handler does the write so this method stays backend- + agnostic. + """ logger.info(f"Updating {status_dict['qid']} for {queue.name}") - queue.update_one( - {'task_id': status_dict['task_id'], 'qid': status_dict['qid']}, - {'$set': {'status': status}} + self.database_handler.update_queue_status_by_task_and_qid( + queue_name=queue, + task_id=status_dict['task_id'], + qid=status_dict['qid'], + status=status, ) logger.info(f"Updated {status_dict['qid']} original request with {status}") diff --git a/prove-processing/background_processing.py b/prove-processing/background_processing.py index be1a16e..77e5741 100644 --- a/prove-processing/background_processing.py +++ b/prove-processing/background_processing.py @@ -8,7 +8,8 @@ import requests import yaml -from prove_shared.mongo_handler import MongoDBHandler, requestItemProcessing +from prove_shared.database import get_database +from prove_shared.database.mongo import requestItemProcessing logger = logging.getLogger("prove_processing") @@ -21,7 +22,9 @@ def load_config(config_path: str): config = load_config('config.yaml') algo_version = config['version']['algo_version'] -mongo_handler = MongoDBHandler() +# Backend selected by config.yaml. Mongo today; Postgres (or a dual-write +# orchestrator) when the migration flips the `database.primary` key. +database_handler = get_database() def fetch_qid_by_label(label): @@ -124,27 +127,42 @@ def process_top_viewed_items(project="en.wikipedia", access="all-access", limit= for idx, (title, views, qid) in enumerate(top_items, 1): logger.info(f"{idx}. Title: {title} - {views} views (QID: {qid})") - # Queue each item for processing - if qid: # Only queue if QID is found - result = requestItemProcessing(qid, 'top_viewed') + # Queue each item for processing on the random queue. + # Previously the second positional arg was `'top_viewed'` which + # silently slotted into the wrong parameter — fixed here by + # using keyword args and the semantic queue name. + if qid: + result = requestItemProcessing( + qid=qid, + queue='random', + db=database_handler, + request_type='top_viewed', + algo_version=algo_version, + ) logger.info(f" Queue status: {result}") else: logger.info("No articles found.") def process_pagepile_list(file_path='utils/pagepileList.txt'): """ - Process the QIDs from the pagepile list file and queue them for processing. - + Read QIDs from `file_path` and enqueue each on the random queue. + Args: file_path: The path to the pagepile list file. """ try: with open(file_path, 'r') as file: qids = file.read().splitlines() - + for qid in qids: - if qid: # Ensure the QID is not empty - result = requestItemProcessing(qid, 'pagepile_weekly_update') + if qid: + result = requestItemProcessing( + qid=qid, + queue='random', + db=database_handler, + request_type='pagepile_weekly_update', + algo_version=algo_version, + ) logger.info(f"Queued QID {qid} for processing: {result}") except Exception as e: logger.error(f"Error processing pagepile list: {e}") @@ -152,28 +170,28 @@ def process_pagepile_list(file_path='utils/pagepileList.txt'): def process_system_qid(qid: str) -> None: """ - Queue system QID for processing. + Queue a system-generated QID on the random queue. - Args: - qid: The QID to process. + The old implementation passed `save_function=random_collection.insert_one` + as a raw pymongo callback — the last leaked collection method in the + codebase. It's gone: the handler now owns the insert. Raises: ValueError: If the QID does not start with 'Q'. """ if not qid.startswith('Q'): try: - int(qid) # Check if the random QID is a valid integer + int(qid) # Confirm it's a valid integer we can prefix with 'Q'. qid = f"Q{qid}" except ValueError as e: raise ValueError("Generated QID does not start with 'Q'.") from e - - # Queue the random QID for processing + result = requestItemProcessing( qid=qid, - algo_version=algo_version, + queue='random', + db=database_handler, request_type='Random_processing', - queue=mongo_handler.random_collection, - save_function=mongo_handler.random_collection.insert_one + algo_version=algo_version, ) logger.info(f"Queued random QID {qid} for processing: {result}") diff --git a/prove-processing/config.yaml b/prove-processing/config.yaml index 166e488..d5f2d43 100644 --- a/prove-processing/config.yaml +++ b/prove-processing/config.yaml @@ -38,3 +38,30 @@ evidence_selection: n_top_sentences: 5 score_threshold: 0 token_size: 512 + +# ---------------------------------------------------------------------------- +# Database backend selection (read by prove_shared.database.get_database) +# ---------------------------------------------------------------------------- +# primary: "mongo" | "postgres" — the DB that owns the data +# fallback: "none" | "mongo" | "postgres" — secondary backend (optional) +# mode: "single" | "dual-write" — "dual-write" mirrors every write +# to the fallback during migration +# auto_fallback_on_read: if true, failed reads on primary +# retry against the fallback (off +# by default — silent fallbacks +# hide real outages) +# ---------------------------------------------------------------------------- +database: + primary: mongo + fallback: none + mode: single + auto_fallback_on_read: false + + mongo: + connection_string: "mongodb://localhost:27017/" + max_retries: 3 + + # TODO: populate when the Postgres migration begins. + postgres: + dsn: "postgresql://localhost/prove" + max_retries: 3 diff --git a/prove-shared/src/prove_shared/__init__.py b/prove-shared/src/prove_shared/__init__.py index 8df800e..cf136c0 100644 --- a/prove-shared/src/prove_shared/__init__.py +++ b/prove-shared/src/prove_shared/__init__.py @@ -1,5 +1,5 @@ from .auth import AsyncAuth -from .mongo_handler import MongoDBHandler, requestItemProcessing +from .database.mongo import MongoDBHandler, requestItemProcessing from .objects import Entailment, HtmlContent, Status from .queue_manager import QueueManager from .wikidata_utils import CachedWikidataAPI diff --git a/prove-shared/src/prove_shared/database/__init__.py b/prove-shared/src/prove_shared/database/__init__.py new file mode 100644 index 0000000..0082667 --- /dev/null +++ b/prove-shared/src/prove_shared/database/__init__.py @@ -0,0 +1,31 @@ +# @repo: shared +# @description: Public surface of the database subpackage. +""" +Database abstraction layer for ProVe. + +Application code should only import from this module: + + from prove_shared.database import get_database + + db = get_database() # returns a DataStore (ABC) implementation + db.get_latest_status_by_qid("Q42") + +Which backend `db` actually is (Mongo, Postgres, or an orchestrator wrapping +both for dual-write migrations) is decided by the `database:` block in the +app's `config.yaml`. See `orchestrator.py` for the config schema. +""" +from .interface import DataStore +from .orchestrator import ( + DatabaseOrchestrator, + get_database, + reset_cached_database, +) +from .postgres import PostgreSQLHandler + +__all__ = [ + "DataStore", + "DatabaseOrchestrator", + "PostgreSQLHandler", + "get_database", + "reset_cached_database", +] diff --git a/prove-shared/src/prove_shared/database/interface.py b/prove-shared/src/prove_shared/database/interface.py new file mode 100644 index 0000000..a50a12e --- /dev/null +++ b/prove-shared/src/prove_shared/database/interface.py @@ -0,0 +1,252 @@ +# @repo: shared +# @description: DataStore (ABC) — the database contract. Every backend (Mongo, Postgres) must implement this so the orchestrator can swap them at runtime. +""" +DataStore (ABC) — the single contract every backend implements. + +Why it exists: + We're migrating from MongoDB to PostgreSQL (MongoDB is not compliant with + Wikimedia's open-source requirements). The migration needs to be config- + driven, backward-compatible, and reversible — which means application + code must never couple to a specific backend. + + This contract is the boundary. Callers depend on `DataStore` (ABC); + concrete implementations (`MongoDBHandler`, `PostgreSQLHandler`) + implement it. The `DatabaseOrchestrator` wraps whichever implementation + the YAML config selects at runtime. + +Naming convention: + * `get__by_(...)` — single keyed lookup, returns Optional[dict] + * `get_(...)` / `get__by_(...)` — list return + * verb-first (`save_`, `log_`, `enqueue_`, `increment_`, `mark_`, `update_`) + for mutations + * Inputs are plain Python (str, dict, list). No BSON, no ObjectId leaks. + * Reads return Optional[dict] or List[dict] — never a pymongo Cursor and + never a SQLAlchemy ResultProxy — so the caller can't tell the backends + apart. +""" +from __future__ import annotations + +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Any, Dict, List, Optional, Union + +import pandas as pd + + +# --------------------------------------------------------------------------- +# Type aliases +# --------------------------------------------------------------------------- +# `QueueRef` accepts either a short queue name ("user" / "random" / "status") +# — the preferred, backend-agnostic form — or a raw backend object (pymongo +# Collection today, SQLAlchemy Table tomorrow). The dual-acceptance keeps +# the migration smooth; Phase 3 can narrow this to `str` once every caller +# is off raw collection references. +QueueRef = Union[str, Any] + + +class DataStore(ABC): + """ + Contract every ProVe database backend must implement. + + New methods should only be added here when a caller genuinely needs a + new primitive — not to expose incidental pymongo or SQLAlchemy features. + Every method added here becomes a commitment the Postgres backend must + also satisfy. + """ + + # ------------------------------------------------------------------ # + # Connection lifecycle # + # ------------------------------------------------------------------ # + @abstractmethod + def ensure_connection(self, try_reconnect: bool = True) -> None: + """Verify the backend connection is alive; reconnect if permitted.""" + + # ================================================================== + # Reads + # ================================================================== + @abstractmethod + def get_latest_status_by_qid(self, qid: str) -> Optional[Dict[str, Any]]: + """Most recent status row for `qid` (largest requested_timestamp).""" + + @abstractmethod + def get_statuses_by_qid( + self, + qid: str, + sort_by: Optional[str] = None, + descending: bool = True, + ) -> List[Dict[str, Any]]: + """Every status row for `qid`, optionally sorted by `sort_by`.""" + + @abstractmethod + def get_html_by_task_id( + self, + task_id: str, + fields: Optional[Dict[str, int]] = None, + ) -> List[Dict[str, Any]]: + """HTML-content rows for `task_id`. `fields` is an optional projection.""" + + @abstractmethod + def get_entailments_by_task_and_reference( + self, + task_id: str, + reference_id: str, + ) -> List[Dict[str, Any]]: + """Entailment rows for the (task_id, reference_id) pair.""" + + @abstractmethod + def aggregate_entailments_by_task_id( + self, + task_id: str, + reference_ids: List[str], + ) -> List[Dict[str, Any]]: + """ + Top entailments for `task_id`, restricted to `reference_ids`, grouped + by (reference_id, result). Mongo does this server-side with `$group`; + Postgres can do the same with window functions or a client-side fold. + """ + + @abstractmethod + def get_summary_by_id(self, target_id: str) -> Optional[Dict[str, Any]]: + """Cached summary for `target_id` (a QID), or None if uncomputed.""" + + @abstractmethod + def get_parser_stats_by_task_and_entity( + self, + task_id: str, + entity_id: str, + fields: Optional[Dict[str, int]] = None, + ) -> Optional[Dict[str, Any]]: + """Parser-stats row keyed by (task_id, entity_id).""" + + @abstractmethod + def get_queue_items( + self, + queue_name: QueueRef, + status: Optional[str] = None, + sort_by: Optional[str] = None, + ascending: bool = True, + ) -> List[Dict[str, Any]]: + """Items from `queue_name`, optionally filtered by status and sorted.""" + + @abstractmethod + def find_queue_item_by_qid( + self, + queue_name: QueueRef, + qid: str, + ) -> Optional[Dict[str, Any]]: + """Single queue item matching `qid`, or None.""" + + @abstractmethod + def get_usage_records(self, use_dev_db: bool = False) -> List[Dict[str, Any]]: + """All API-usage records. `use_dev_db=True` reads the dev mirror.""" + + # ================================================================== + # Writes + # ================================================================== + @abstractmethod + def save_html_content(self, html_df: pd.DataFrame) -> None: + """Upsert HTML content rows from a DataFrame.""" + + @abstractmethod + def save_entailment_results(self, entailment_df: pd.DataFrame) -> None: + """Insert entailment result rows from a DataFrame.""" + + @abstractmethod + def save_parser_stats(self, stats_dict: Dict[str, Any]) -> None: + """Upsert parser-stats row keyed by (task_id, entity_id).""" + + @abstractmethod + def save_status( + self, + status_dict: Dict[str, Any], + queue: Optional[QueueRef] = None, + ) -> None: + """Upsert a status document, optionally in a non-default queue.""" + + @abstractmethod + def upsert_summary_by_id(self, target_id: str, data: Dict[str, Any]) -> None: + """Atomically insert-or-update the summary document for `target_id`.""" + + @abstractmethod + def enqueue_item(self, queue_name: QueueRef, item: Dict[str, Any]) -> None: + """Append `item` to `queue_name`.""" + + @abstractmethod + def log_usage(self, record: Dict[str, Any]) -> None: + """ + Persist a single API-usage record to the production usage store. + Implementations MUST swallow errors — a usage-log failure must never + surface to the caller. + """ + + # ================================================================== + # Queue-state mutations + # ================================================================== + @abstractmethod + def increment_retry_by_id( + self, + queue_name: QueueRef, + item_id: Any, + ) -> None: + """Atomically bump `retry_count` on the given queue item by 1.""" + + @abstractmethod + def mark_queue_item_error_by_id( + self, + queue_name: QueueRef, + item_id: Any, + error_message: str = "Max retry limit reached", + ) -> None: + """Flip a queue item to status='error' with a reason.""" + + @abstractmethod + def update_queue_status_by_task_and_qid( + self, + queue_name: QueueRef, + task_id: str, + qid: str, + status: str, + ) -> None: + """Set the `status` field on the (task_id, qid) row in `queue_name`.""" + + # ================================================================== + # Workflow primitives (existing queue-manager API) + # ================================================================== + # These predate Phase 1 and already follow a reasonable shape. They are + # part of the contract so Postgres must implement them too, but they + # take raw backend references (pymongo Collection today). Phase 3 will + # normalise them to queue names. + + @abstractmethod + def get_next_request(self, queue: Any) -> Optional[Dict[str, Any]]: + """Atomically claim the next pending item from a queue.""" + + @abstractmethod + def get_request_by_id_and_reset( + self, + queue: Any, + _id: str, + ) -> Optional[Dict[str, Any]]: + """Return a claimed item to 'in queue' state (undo a claim).""" + + @abstractmethod + def set_request_status_and_processing_time( + self, + queue: Any, + status: str, + processing_time: datetime, + _id: str, + ) -> Optional[Dict[str, Any]]: + """Set both `status` and `processing_start_timestamp` on an item.""" + + @abstractmethod + def get_request_by_id(self, queue: Any, _id: str) -> Optional[Dict[str, Any]]: + """Fetch a single request by its primary key.""" + + @abstractmethod + def get_request_by_taskid(self, queue: Any, task_id: str) -> Optional[Dict[str, Any]]: + """Fetch a single request by its `task_id`.""" + + @abstractmethod + def get_all_request_in_progress(self, queue: Any) -> Any: + """All items currently in 'processing' on a queue.""" diff --git a/prove-shared/src/prove_shared/database/mongo.py b/prove-shared/src/prove_shared/database/mongo.py new file mode 100644 index 0000000..0c4a009 --- /dev/null +++ b/prove-shared/src/prove_shared/database/mongo.py @@ -0,0 +1,1070 @@ +# @repo: shared +# @description: MongoDB implementation of DataStore (ABC) — owns every pymongo call in the codebase. Every other module accesses the DB through the methods defined here. +""" +MongoDB backend for the ProVe database layer. + +This module is the *only* place in the codebase that may import from pymongo +or instantiate `MongoClient`. Every other module accesses the database through +the `DataStore` methods defined here. + +Design notes: + * Inputs are plain Python types (str, dict, list). No BSON leaks out. + * Reads return Optional[dict] or List[dict] — pymongo `Cursor` objects are + always materialised with `list(...)` so callers can't accidentally + iterate twice (and so the same shape works for the Postgres backend). + * Collection/field names are hidden from callers: they pass semantic + args (qid, task_id, queue_name) and we resolve internals here. +""" +from typing import Any, Callable, Dict, List, Optional, Union +from datetime import datetime +from bson import ObjectId +import time +import uuid + +import pandas as pd +from pymongo import MongoClient, ReturnDocument +from pymongo.collection import Collection +from pymongo.database import Database + +from ..logger import logger +from .interface import DataStore + + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- +# Canonical database names are kept in one place so the schema can evolve +# without hunting through call sites. When the PostgreSQLHandler lands these +# become schema/table names in the Postgres implementation. +_MAIN_DB = "wikidata_verification" +_USAGE_DB_PROD = "service_usage" +_USAGE_DB_DEV = "tmp_service_usage" # dev/analysis mirror of prod usage data + + +class MongoDBHandler(DataStore): + """ + MongoDB implementation of `DataStore` (ABC). + + Owns all pymongo state for the process. The public methods form the + contract the future `PostgreSQLHandler` also implements, so the + `DatabaseOrchestrator` can swap backends at runtime. + + Usage DBs (`service_usage`, `tmp_service_usage`) are also owned here — + previously they lived in ad-hoc subclasses (`StatsDBHandler`, + `TMPStatsDBHandler`) which each re-opened `MongoClient`. Folding them in + collapses three classes into one and removes several leaked + `MongoClient(...)` call sites. + + Args: + connection_string: MongoDB connection URI. Defaults to + ``"mongodb://localhost:27017/"``. + max_retries: Maximum number of retries for the initial connection. + + Attributes: + client (MongoClient): Underlying pymongo client. + db (database): Handle to the main `wikidata_verification` database. + html_collection (collection): HTML-content documents. + entailment_collection (collection): BERT-FEVER entailment results. + stats_collection (collection): Parser statistics per (task, entity). + status_collection (collection): Task status timeline. + summary_collection (collection): Cached per-QID summaries. + random_collection (collection): Random-selection worker queue. + user_collection (collection): User-requested worker queue. + + Raises: + ConnectionError: If the initial connection cannot be established + after ``max_retries`` attempts. + """ + + def __init__( + self, + connection_string: str = "mongodb://localhost:27017/", + max_retries: int = 3, + ) -> None: + # Connection parameters + self.max_retries = max_retries + self.connection_string = connection_string + + # Main DB handles (populated in connect) + self.client: MongoClient = None + self.db: Database = None + self.html_collection: Collection = None + self.entailment_collection: Collection = None + self.stats_collection: Collection = None + self.status_collection: Collection = None + self.summary_collection: Collection = None + self.random_collection: Collection = None + self.user_collection: Collection = None + + # Queue lookup map — callers pass short string keys ("user", "random") + # rather than pymongo Collection objects. This keeps the public API + # backend-agnostic; Postgres will map the same keys to table names. + self._queues: Dict[str, Collection] = {} + + # Lazy handles for the usage databases. They live on the same Mongo + # instance but are separate *databases*. We only open them on first + # use because most code paths don't need them (only the @log_request + # decorator and the offline analytics script do). + self._usage_db_prod: Database = None + self._usage_db_dev: Database = None + + if not self.connect(max_retries, connection_string): + logger.error("Failed to connect to MongoDB") + raise ConnectionError("Could not connect to MongoDB after multiple attempts") + + # ------------------------------------------------------------------ # + # Connection lifecycle # + # ------------------------------------------------------------------ # + def connect(self, max_retries: int, connection_string: str) -> bool: + """ + Connect to MongoDB with retries and wire up collection handles. + + Args: + max_retries: Maximum number of retries before giving up. + connection_string: The MongoDB connection URI. + + Returns: + ``True`` if the connection is successful, ``False`` otherwise. + """ + for attempt in range(max_retries): + try: + self.client = MongoClient(connection_string) + self.ensure_connection(try_reconnect=False) + break + except Exception as e: + logger.error(f"MongoDB connection attempt {attempt + 1} failed: {e}") + if attempt == max_retries - 1: + return False + time.sleep(5) + continue + + # Main application database + self.db = self.client[_MAIN_DB] + self.html_collection = self.db['html_content'] + self.entailment_collection = self.db['entailment_results'] + self.stats_collection = self.db['parser_stats'] + self.status_collection = self.db['status'] + self.summary_collection = self.db['summary'] + self.random_collection = self.db['random_queue'] + self.user_collection = self.db['user_queue'] + + # Register queues under short keys. Public methods accept these keys + # instead of Collection objects, which is the shape Postgres will use. + self._queues = { + 'user': self.user_collection, + 'random': self.random_collection, + 'status': self.status_collection, + } + + # Index for high-concurrency dequeue on the random queue. + try: + self.random_collection.create_index([('status', 1), ('requested_timestamp', 1)]) + except Exception as e: + logger.error(f"Failed to create index {e}") + + logger.info("Successfully connected to WikiData verification MongoDB") + return True + + def ensure_connection(self, try_reconnect: bool = True) -> None: + """ + Verify the MongoDB connection is alive; reconnect if permitted. + + Args: + try_reconnect: If ``True``, attempt to reconnect when the + connection is down. If ``False``, raise immediately. + + Raises: + ConnectionError: If the connection is down and either reconnection + is disabled or retry attempts are exhausted. + """ + try: + self.client.server_info() + except Exception as e: + logger.error("MongoDB connection lost, attempting to reconnect...") + logger.error(f"Error details: {e}") + + if not try_reconnect: + logger.error("Reconnection failed, please check MongoDB server status") + raise ConnectionError("MongoDB connection lost") from e + + if not self.connect(self.max_retries, self.connection_string): + logger.error("Reconnection failed, please check MongoDB server status") + raise ConnectionError("Could not reconnect to MongoDB") from e + + # ------------------------------------------------------------------ # + # Internal helpers # + # ------------------------------------------------------------------ # + def _resolve_queue(self, queue: Union[str, Collection]) -> Collection: + """ + Resolve a queue reference to an underlying pymongo Collection. + + Accepts either a short queue name (preferred, backend-agnostic) or + a pymongo Collection (legacy). Dual acceptance keeps the migration + smooth: new callers pass ``"user"`` / ``"random"`` / ``"status"``; + older methods that already took Collection objects continue to work. + + Args: + queue: Either a short queue name or a pymongo Collection. + + Returns: + The underlying pymongo Collection. + + Raises: + ValueError: If ``queue`` is a string that isn't a known queue name. + """ + if isinstance(queue, str): + try: + return self._queues[queue] + except KeyError as e: + raise ValueError( + f"Unknown queue '{queue}'. Valid names: {list(self._queues)}" + ) from e + return queue + + # ================================================================== + # Bulk writes (existing methods, previously in the legacy handler) + # ================================================================== + def save_html_content(self, html_df: pd.DataFrame) -> None: + """ + Upsert HTML content rows from a DataFrame. + + The raw HTML column is dropped before insert for storage efficiency. + Rows are keyed on ``(reference_id, task_id)``; existing rows are + updated in place. + + Args: + html_df: DataFrame containing HTML content with columns: + ``reference_id``, ``task_id``, ``html``, ``fetch_timestamp``. + + Raises: + RuntimeError: If the batch write fails at the collection level. + Per-record errors are logged and skipped, not re-raised. + """ + try: + if html_df.empty: + logger.warning("html_df is empty") + return + + # Drop raw HTML — we store metadata, not the full pages. + html_df_without_html = html_df.drop('html', axis=1) + + logger.info(f"Attempting to save {len(html_df_without_html)} HTML records") + records = html_df_without_html.to_dict('records') + + for record in records: + try: + if 'reference_id' not in record: + logger.warning(f"Record missing reference_id: {record}") + continue + + if 'fetch_timestamp' in record and isinstance(record['fetch_timestamp'], pd.Timestamp): + record['fetch_timestamp'] = record['fetch_timestamp'].to_pydatetime() + + record['save_timestamp'] = datetime.now() + + result = self.html_collection.update_one( + { + 'reference_id': record['reference_id'], + 'task_id': record['task_id'] + }, + {'$set': record}, + upsert=True + ) + + logger.info( + f"Updated HTML document with reference_id {record['reference_id']}: " + f"matched={result.matched_count}, modified={result.modified_count}, " + f"upserted_id={result.upserted_id}" + ) + except Exception as e: + logger.error(f"Error saving HTML record: {record}") + logger.error(f"Error details: {e}") + + except Exception as e: + logger.error(f"Error in save_html_content: {e}") + raise RuntimeError(f"Failed to save HTML content: {e}") from e + + def save_entailment_results(self, entailment_df: pd.DataFrame) -> None: + """ + Insert entailment result rows from a DataFrame. + + Unlike ``save_html_content`` this is an append-only insert — we + deliberately keep the full history of entailment verdicts per + reference/claim pair for analytics. + + Args: + entailment_df: DataFrame containing entailment results with + columns including ``reference_id``, ``task_id``, + ``processed_timestamp``. + + Raises: + RuntimeError: If the batch insert fails. Per-record errors are + logged and skipped, not re-raised. + """ + try: + if entailment_df.empty: + logger.warning("entailment_df is empty") + return + + logger.info(f"Attempting to save {len(entailment_df)} entailment records") + records = entailment_df.to_dict('records') + + for record in records: + try: + if 'processed_timestamp' in record: + record['processed_timestamp'] = datetime.strptime( + record['processed_timestamp'], + '%Y-%m-%dT%H:%M:%S.%f' + ) + + record['save_timestamp'] = datetime.now() + + result = self.entailment_collection.insert_one(record) + logger.info( + f"Inserted new entailment document with reference_id {record['reference_id']}: " + f"inserted_id={result.inserted_id}" + ) + + except Exception as e: + logger.error(f"Error saving entailment record: {record}") + logger.error(f"Error details: {e}") + + except Exception as e: + logger.error(f"Error in save_entailment_results: {e}") + raise RuntimeError(f"Failed to save entailment results: {e}") from e + + def save_parser_stats(self, stats_dict: Dict[str, Any]) -> None: + """ + Upsert a parser-stats record keyed by ``(task_id, entity_id)``. + + Args: + stats_dict: Parser statistics with at least + ``entity_id``, ``task_id``, and ``parsing_start_timestamp``. + + Raises: + RuntimeError: If the upsert fails. + """ + try: + if isinstance(stats_dict.get('parsing_start_timestamp'), pd.Timestamp): + stats_dict['parsing_start_timestamp'] = stats_dict[ + 'parsing_start_timestamp' + ].to_pydatetime() + + stats_dict['save_timestamp'] = datetime.now() + + self.stats_collection.update_one( + { + 'entity_id': stats_dict['entity_id'], + 'task_id': stats_dict['task_id'] + }, + {'$set': stats_dict}, + upsert=True + ) + + logger.info(f"Updated parser stats for entity {stats_dict['entity_id']}") + + except Exception as e: + logger.error(f"Error in save_parser_stats: {e}") + raise RuntimeError(f"Failed to save parser statistics: {e}") from e + + def save_status( + self, + status_dict: Dict[str, Any], + queue: Union[str, Collection, None] = None, + ) -> None: + """ + Upsert a status document into the given queue. + + Existing records matching ``(task_id, qid)`` are updated in place; + otherwise a new record is inserted. String timestamps in the input + are parsed to ``datetime`` objects; ``last_updated`` is always set + to the current wall-clock time. + + Args: + status_dict: Status document. Must contain ``task_id`` and + ``qid``. Optional fields: ``status``, ``algo_version``, + ``request_type``, ``requested_timestamp``, + ``processing_start_timestamp``, ``completed_timestamp``. + queue: Either a short queue name (``"user"`` / ``"random"`` / + ``"status"``) or a pymongo Collection. Defaults to the main + status collection when omitted. + + Raises: + RuntimeError: If the upsert fails. + """ + target: Collection = ( + self.status_collection + if queue is None + else self._resolve_queue(queue) + ) + + try: + timestamp_fields = [ + 'requested_timestamp', + 'processing_start_timestamp', + 'completed_timestamp' + ] + + for field in timestamp_fields: + if status_dict.get(field) and status_dict[field] != 'null': + if isinstance(status_dict[field], str): + status_dict[field] = datetime.strptime( + status_dict[field].rstrip('Z'), + '%Y-%m-%dT%H:%M:%S.%f' + ) + + status_dict['last_updated'] = datetime.now() + + existing_doc = target.find_one({ + 'task_id': status_dict['task_id'], + 'qid': status_dict['qid'] + }) + + if existing_doc: + result = target.update_one( + { + 'task_id': status_dict['task_id'], + 'qid': status_dict['qid'] + }, + {'$set': status_dict} + ) + logger.info( + f"Updated status for task {status_dict['task_id']}: " + f"matched={result.matched_count}, modified={result.modified_count}" + ) + else: + result = target.insert_one(status_dict) + logger.info( + f"Created new status for task {status_dict['task_id']}: " + f"inserted_id={result.inserted_id}" + ) + + except Exception as e: + logger.error(f"Error in save_status: {e}") + raise RuntimeError(f"Failed to save status: {e}") from e + + def get_next_request(self, queue: Collection) -> Optional[Dict[str, Any]]: + """ + Atomically claim the next pending item from ``queue``. + + Uses ``find_one_and_update`` to flip status from ``"in queue"`` to + ``"processing"`` in a single round-trip, avoiding any race window + between read and write. + + Args: + queue: The pymongo Collection to pull the next request from. + + Returns: + The claimed document, or ``None`` if the queue is empty. + + Raises: + RuntimeError: If the claim operation fails. + """ + try: + pending_request = queue.find_one_and_update( + { + 'status': 'in queue', + 'processing_start_timestamp': None + }, + {'$set': { + 'status': 'processing', + 'processing_start_timestamp': datetime.utcnow(), + }}, + sort=[('requested_timestamp', 1)], + return_document=ReturnDocument.AFTER + ) + + if pending_request: + return pending_request + return None + except Exception as e: + logger.error(f"Error getting next user request: {e}") + raise RuntimeError(f"Failed to get next request: {e}") from e + + def get_request_by_id_and_reset( + self, + queue: Collection, + _id: str, + ) -> Optional[Dict[str, Any]]: + """ + Return an in-flight item to ``"in queue"`` state (undo a claim). + + Only matches items currently in ``"processing"``; a worker that + crashes mid-task calls this to release the row back to the pool. + + Args: + queue: The pymongo Collection containing the item. + _id: Primary key of the item. + + Returns: + The updated document post-reset, or ``None`` if no match. + """ + return queue.find_one_and_update( + { + '_id': _id, + 'status': 'processing', + 'processing_start_timestamp': {'$not': {'$eq': None}} + }, + {'$set': { + 'status': 'in queue', + 'processing_start_timestamp': None + }}, + return_document=ReturnDocument.AFTER + ) + + def set_request_status_and_processing_time( + self, + queue: Collection, + status: str, + processing_time: datetime, + _id: str, + ) -> Optional[Dict[str, Any]]: + """ + Overwrite both ``status`` and ``processing_start_timestamp`` on an item. + + Args: + queue: The pymongo Collection containing the item. + status: New status string. + processing_time: New ``processing_start_timestamp`` value. + _id: Primary key of the item. + + Returns: + The updated document post-write, or ``None`` if no match. + """ + return queue.find_one_and_update( + {'_id': _id}, + {'$set': { + 'status': status, + 'processing_start_timestamp': processing_time + }}, + return_document=ReturnDocument.AFTER + ) + + def get_request_by_id( + self, + queue: Collection, + _id: str, + ) -> Optional[Dict[str, Any]]: + """ + Fetch a single request by its primary key. + + Args: + queue: The pymongo Collection containing the item. + _id: Primary key (string; converted to ObjectId internally). + + Returns: + The document, or ``None`` if not found. + """ + return queue.find_one({'_id': ObjectId(_id)}) + + def get_request_by_taskid( + self, + queue: Collection, + task_id: str, + ) -> Optional[Dict[str, Any]]: + """ + Fetch a single request by its ``task_id``. + + Args: + queue: The pymongo Collection to search. + task_id: Task identifier. + + Returns: + The document, or ``None`` if not found. + """ + return queue.find_one({'task_id': task_id}) + + def get_all_request_in_progress(self, queue: Collection) -> Any: + """ + Return a cursor over items currently in ``"processing"``. + + Args: + queue: The pymongo Collection to scan. + + Returns: + A pymongo cursor. (Kept as a cursor for backward compatibility + with existing callers that iterate lazily.) + """ + return queue.find({'status': 'processing'}) + + # ================================================================== + # NEW METHODS — Phase 1 query consolidation + # ================================================================== + # Every raw pymongo call that previously lived outside this class now + # routes through one of the methods below. See the audit doc for the + # full old-call-site → new-method mapping. + + # ---- Status ------------------------------------------------------- # + def get_latest_status_by_qid(self, qid: str) -> Optional[Dict[str, Any]]: + """ + Return the most recent status document for a QID. + + "Most recent" is the row with the largest ``requested_timestamp``. + + Args: + qid: Wikidata identifier (e.g. ``"Q42"``). + + Returns: + The status document, or ``None`` if the QID has never been + enqueued. + """ + return self.status_collection.find_one( + {'qid': qid}, + sort=[('requested_timestamp', -1)], + ) + + def get_statuses_by_qid( + self, + qid: str, + sort_by: Optional[str] = None, + descending: bool = True, + ) -> List[Dict[str, Any]]: + """ + Return every status document for a QID. + + Args: + qid: Wikidata identifier. + sort_by: Optional field name to sort by. ``None`` preserves + insertion order (matches the pre-refactor semantics of + ``CheckItemStatus``). + descending: Sort direction when ``sort_by`` is given. + + Returns: + A list of status documents. Empty if the QID has no history. + """ + cursor = self.status_collection.find({'qid': qid}) + if sort_by is not None: + cursor = cursor.sort(sort_by, -1 if descending else 1) + return list(cursor) + + # ---- HTML content ------------------------------------------------- # + def get_html_by_task_id( + self, + task_id: str, + fields: Optional[Dict[str, int]] = None, + ) -> List[Dict[str, Any]]: + """ + Return all HTML-content documents for a task. + + Args: + task_id: Task identifier. + fields: Optional Mongo-style projection ``{'name': 1, '_id': 0}``. + Accepted as a generic dict so the Postgres backend can + translate it to a SELECT column list. ``None`` fetches the + full document. + + Returns: + A list of HTML-content documents. + """ + if fields is None: + cursor = self.html_collection.find({'task_id': task_id}) + else: + cursor = self.html_collection.find({'task_id': task_id}, fields) + return list(cursor) + + # ---- Entailment results ------------------------------------------- # + def get_entailments_by_task_and_reference( + self, + task_id: str, + reference_id: str, + ) -> List[Dict[str, Any]]: + """ + Return entailment rows for a ``(task_id, reference_id)`` pair. + + Used by the per-reference verdict logic in ``GetItem``. Postgres + equivalent: ``WHERE task_id = %s AND reference_id = %s``. + + Args: + task_id: Task identifier. + reference_id: Reference identifier. + + Returns: + A list of entailment documents. Empty if no entailment has + been computed for this pair yet. + """ + return list( + self.entailment_collection.find( + {'task_id': task_id, 'reference_id': reference_id}, + ) + ) + + def aggregate_entailments_by_task_id( + self, + task_id: str, + reference_ids: List[str], + ) -> List[Dict[str, Any]]: + """ + Group top entailments for a task, bucketed by (reference_id, result). + + Results are sorted by ``text_entailment_score`` descending. Only + used by ``get_item`` when picking the top-scoring verdict per + reference; doing it server-side avoids pulling every entailment row + for a task. + + Args: + task_id: Task identifier. + reference_ids: Restrict aggregation to these references. + + Returns: + The raw aggregation output — a list of + ``{"_id": {"reference_id": ..., "result": ...}, "docs": [...]}`` + entries. The caller's existing grouping code expects this shape. + """ + pipeline = [ + {"$match": { + "task_id": task_id, + "reference_id": {"$in": reference_ids}, + }}, + {"$sort": {"text_entailment_score": -1}}, + {"$group": { + "_id": { + "reference_id": "$reference_id", + "result": "$result", + }, + "docs": {"$push": "$$ROOT"}, + }}, + ] + return list(self.entailment_collection.aggregate(pipeline)) + + # ---- Summaries ---------------------------------------------------- # + def get_summary_by_id(self, target_id: str) -> Optional[Dict[str, Any]]: + """ + Fetch the cached summary document for a QID. + + Args: + target_id: Wikidata identifier (used as the document ``_id``). + + Returns: + The summary document, or ``None`` if no summary has been + computed yet. + """ + return self.summary_collection.find_one({'_id': target_id}) + + def upsert_summary_by_id( + self, + target_id: str, + data: Dict[str, Any], + ) -> None: + """ + Insert or update the summary document for a QID. + + The previous implementation split this into explicit ``insert_one`` + and ``update_one`` branches depending on whether the document + already existed. That pattern was race-prone (two workers both see + "no doc" and both insert). ``upsert=True`` handles it atomically + and maps cleanly to Postgres' ``ON CONFLICT ... DO UPDATE``. + + Args: + target_id: Wikidata identifier (used as the document ``_id``). + data: Fields to set on the summary. + """ + self.summary_collection.update_one( + {'_id': target_id}, + {'$set': data}, + upsert=True, + ) + + # ---- Parser stats ------------------------------------------------- # + def get_parser_stats_by_task_and_entity( + self, + task_id: str, + entity_id: str, + fields: Optional[Dict[str, int]] = None, + ) -> Optional[Dict[str, Any]]: + """ + Fetch a single parser-stats document keyed by ``(task_id, entity_id)``. + + Args: + task_id: Task identifier. + entity_id: Wikidata identifier. + fields: Optional projection. ``None`` (the default) returns just + ``{'total_claims': 1, '_id': 0}`` — the shape every current + caller needs. Pass an explicit projection to override, or + ``{}`` to fetch the full document. + + Returns: + The parser-stats document, or ``None`` if not found. + """ + projection = ( + {'total_claims': 1, '_id': 0} if fields is None else fields or None + ) + return self.stats_collection.find_one( + {'task_id': task_id, 'entity_id': entity_id}, + projection, + ) + + # ---- Generic queue operations ------------------------------------- # + def get_queue_items( + self, + queue_name: Union[str, Collection], + status: Optional[str] = None, + sort_by: Optional[str] = None, + ascending: bool = True, + ) -> List[Dict[str, Any]]: + """ + Return items from a queue, optionally filtered by status. + + Args: + queue_name: Either a short queue name (``"user"`` / ``"random"`` + / ``"status"``) or a pymongo Collection. + status: Optional status filter (e.g. ``"in queue"``, + ``"processing"``). ``None`` returns every row. + sort_by: Optional field to sort by. ``None`` preserves insertion + order. + ascending: Sort direction when ``sort_by`` is given. + + Returns: + A list of queue items. + """ + target = self._resolve_queue(queue_name) + query: Dict[str, Any] = {} + if status is not None: + query['status'] = status + + cursor = target.find(query) + if sort_by is not None: + cursor = cursor.sort(sort_by, 1 if ascending else -1) + return list(cursor) + + def find_queue_item_by_qid( + self, + queue_name: Union[str, Collection], + qid: str, + ) -> Optional[Dict[str, Any]]: + """ + Look up a single item by QID in a queue. + + Used by the heuristic service to avoid enqueuing a QID that already + exists in any known queue. + + Args: + queue_name: Either a short queue name or a pymongo Collection. + qid: Wikidata identifier. + + Returns: + The queue item, or ``None`` if absent. + """ + target = self._resolve_queue(queue_name) + return target.find_one({'qid': qid}) + + def increment_retry_by_id( + self, + queue_name: Union[str, Collection], + item_id: Any, + ) -> None: + """ + Atomically bump ``retry_count`` on a queue item by 1. + + Uses ``$inc`` (atomic in Mongo, ``UPDATE ... SET col = col + 1`` in + Postgres) to avoid the lost-update race the previous read-modify- + write implementation had. + + Args: + queue_name: Either a short queue name or a pymongo Collection. + item_id: Primary key of the item. + """ + target = self._resolve_queue(queue_name) + target.update_one({'_id': item_id}, {'$inc': {'retry_count': 1}}) + + def mark_queue_item_error_by_id( + self, + queue_name: Union[str, Collection], + item_id: Any, + error_message: str = 'Max retry limit reached', + ) -> None: + """ + Mark a queue item as permanently failed (``status = "error"``). + + Called by the retry loop when an item has exhausted its retries. + + Args: + queue_name: Either a short queue name or a pymongo Collection. + item_id: Primary key of the item. + error_message: Human-readable reason for the failure. + """ + target = self._resolve_queue(queue_name) + target.update_one( + {'_id': item_id}, + {'$set': {'status': 'error', 'error_message': error_message}}, + ) + + def update_queue_status_by_task_and_qid( + self, + queue_name: Union[str, Collection], + task_id: str, + qid: str, + status: str, + ) -> None: + """ + Update the ``status`` field for a ``(task_id, qid)`` row. + + Called by the main service when a job transitions to ``"completed"``. + + Args: + queue_name: Either a short queue name or a pymongo Collection. + task_id: Task identifier. + qid: Wikidata identifier. + status: New status string. + """ + target = self._resolve_queue(queue_name) + target.update_one( + {'task_id': task_id, 'qid': qid}, + {'$set': {'status': status}}, + ) + + def enqueue_item( + self, + queue_name: Union[str, Collection], + item: Dict[str, Any], + ) -> None: + """ + Append a new item to a queue — the minimal "append" operation. + + Replaces the old pattern of passing ``collection.insert_one`` around + as a raw callback, which exposed pymongo to every caller. + + Args: + queue_name: Either a short queue name or a pymongo Collection. + item: The document to insert. + """ + target = self._resolve_queue(queue_name) + target.insert_one(item) + + # ================================================================== + # USAGE DBs (service_usage / tmp_service_usage) + # ================================================================== + # Separate databases on the same Mongo instance, used by the API + # decorators to log every request and by the offline analytics script + # (info.py) to read them. + # + # Previously implemented as two subclasses (StatsDBHandler, + # TMPStatsDBHandler) each overriding `connect()` to point at a different + # DB. That pattern leaked raw MongoClient access to every usage call + # site. Folding them in here means a single handler now owns all Mongo + # access; the DB handles are created lazily on first property access. + # ------------------------------------------------------------------ + + @property + def usage_collection(self) -> Collection: + """ + Lazy handle to ``service_usage.usage`` (production request logs). + + The database handle is opened on first access and cached for the + lifetime of the process. + """ + if self._usage_db_prod is None: + self._usage_db_prod = self.client[_USAGE_DB_PROD] + return self._usage_db_prod['usage'] + + @property + def tmp_usage_collection(self) -> Collection: + """Lazy handle to ``tmp_service_usage.usage`` (dev/analysis mirror).""" + if self._usage_db_dev is None: + self._usage_db_dev = self.client[_USAGE_DB_DEV] + return self._usage_db_dev['usage'] + + def log_usage(self, record: Dict[str, Any]) -> None: + """ + Persist a single API-usage record to the production usage DB. + + Called from the ``@log_request`` decorator on every HTTP request. + Errors are intentionally swallowed — a usage-logging failure must + never surface to the user as a 500. + + Args: + record: Usage record (method, url, headers, body, timestamp, + execution_time). + """ + try: + self.usage_collection.insert_one(record) + except Exception as e: + logger.error(f"Failed to log usage record: {e}") + + def get_usage_records( + self, + use_dev_db: bool = False, + ) -> List[Dict[str, Any]]: + """ + Return every usage record for offline analysis. + + Args: + use_dev_db: If ``True``, read from ``tmp_service_usage`` (the + dev mirror used by ``info.py``). Otherwise read production. + + Returns: + A materialised list of usage records (not a cursor) so callers + can use ``len()``, ``tqdm``, and repeated iteration. + """ + target = self.tmp_usage_collection if use_dev_db else self.usage_collection + return list(target.find()) + + +# --------------------------------------------------------------------------- +# Free-function helpers +# --------------------------------------------------------------------------- +def requestItemProcessing( + qid: str, + queue: Union[str, Collection], + db: "MongoDBHandler", + request_type: str = 'userRequested', + algo_version: str = '1.1.1', +) -> str: + """ + Enqueue a QID for processing if it isn't already pending. + + Thin wrapper over the handler's ``find_queue_item_by_qid`` + + ``enqueue_item`` primitives. Kept as a module-level function because + several callers import it directly. + + Args: + qid: Wikidata identifier to enqueue. + queue: Either a queue name (``"user"`` / ``"random"``) or a pymongo + Collection. + db: The handler instance to route reads/writes through. + request_type: Origin tag for the request (``"userRequested"``, + ``"Random_processing"``, ``"top_viewed"``, ...). + algo_version: Pipeline version stamped on the new record. + + Returns: + A human-readable status string (callers log it). + """ + try: + existing = db.find_queue_item_by_qid(queue, qid) + if existing and existing.get('status') == 'in queue': + return f"QID {qid} is already in queue. Skipping..." + + status_dict = _build_status_dict(qid, request_type, algo_version) + db.enqueue_item(queue, status_dict) + return f"Task {status_dict['task_id']} created for QID {qid}" + + except Exception as e: + logger.error("Error in requestItemProcessing: %s", e) + return f"An error occurred: {e}" + + +def _build_status_dict( + qid: str, + request_type: str, + algo_version: str, +) -> Dict[str, Any]: + """ + Construct the canonical "in queue" status document for a new task. + + Args: + qid: Wikidata identifier. + request_type: Origin tag for the request. + algo_version: Pipeline version. + + Returns: + A status dictionary ready for insertion. + """ + return { + 'qid': qid, + 'task_id': str(uuid.uuid4()), + 'status': 'in queue', + 'algo_version': algo_version, + 'request_type': request_type, + 'requested_timestamp': datetime.utcnow(), + 'processing_start_timestamp': None, + 'completed_timestamp': None, + } diff --git a/prove-shared/src/prove_shared/database/orchestrator.py b/prove-shared/src/prove_shared/database/orchestrator.py new file mode 100644 index 0000000..3655b25 --- /dev/null +++ b/prove-shared/src/prove_shared/database/orchestrator.py @@ -0,0 +1,462 @@ +# @repo: shared +# @description: Config-driven database orchestrator. Selects primary/fallback backends from YAML and exposes the same DataStore (ABC) so callers are backend-agnostic. +""" +DatabaseOrchestrator — chooses and wraps the active backend. + +The orchestrator is a thin, transparent wrapper around a primary +`DataStore` implementation, with two optional behaviours: + + 1. Dual-write mode (for migration) — writes go to the primary + *and* the fallback. Writes to the fallback are log-and-continue: + a failure there never breaks the primary write, because during + migration the primary is still the source of truth. + + 2. Read fallback (opt-in) — if a read against the primary raises, + try the fallback before re-raising. Off by default because a silent + fallback on reads hides real backend outages. + +Config shape (loaded from the app's existing `config.yaml`): + + database: + primary: mongo # "mongo" | "postgres" + fallback: none # "none" | "mongo" | "postgres" + mode: single # "single" | "dual-write" + auto_fallback_on_read: false + + mongo: + connection_string: mongodb://localhost:27017/ + max_retries: 3 + postgres: + dsn: postgresql://localhost/prove + max_retries: 3 + +Migration phases (same settings, different values): + + Day 1 primary=mongo, fallback=none, mode=single + Migration primary=mongo, fallback=postgres, mode=dual-write + Cutover primary=postgres, fallback=mongo, mode=single, auto_fallback_on_read=true + Done primary=postgres, fallback=none, mode=single +""" +from __future__ import annotations + +import os +from typing import Any, Dict, List, Optional + +import pandas as pd +import yaml + +from ..logger import logger +from .interface import DataStore, QueueRef + + +# --------------------------------------------------------------------------- +# Module-level singleton +# --------------------------------------------------------------------------- +# `get_database()` is called once per process in most cases (module-level in +# each caller). We cache the result so repeated calls don't reopen connections. +# Tests can pass `config=...` explicitly to bypass the cache. +_CACHED_DB: Optional[DataStore] = None + + +# --------------------------------------------------------------------------- +# Backend factory +# --------------------------------------------------------------------------- +def _build_backend(kind: str, settings: Dict[str, Any]) -> DataStore: + """ + Construct a concrete `DataStore` implementation. + + Kept here rather than in the backend modules themselves so that the + config-shape-to-constructor-arg mapping lives in one place. Adding a new + backend means one extra branch here and a new file alongside mongo.py / + postgres.py. + """ + if kind == "mongo": + # Imported lazily so `PostgreSQLHandler`-only environments don't need + # pymongo on the path (useful for future slimmed-down deployments). + from .mongo import MongoDBHandler + + return MongoDBHandler( + connection_string=settings.get( + "connection_string", "mongodb://localhost:27017/" + ), + max_retries=settings.get("max_retries", 3), + ) + if kind == "postgres": + from .postgres import PostgreSQLHandler + + return PostgreSQLHandler( + dsn=settings.get("dsn", "postgresql://localhost/prove"), + max_retries=settings.get("max_retries", 3), + ) + raise ValueError(f"Unknown database backend: {kind!r}") + + +class DatabaseOrchestrator(DataStore): + """ + Routes reads and writes between a primary and an optional fallback. + + Implements `DataStore` itself so callers hold a single object + and never know which backend is serving any given call. Method bodies + are deliberately explicit (rather than `__getattr__`-based) so IDEs, + type checkers, and stack traces point at the right thing. + """ + + def __init__( + self, + primary: DataStore, + fallback: Optional[DataStore] = None, + dual_write: bool = False, + auto_fallback_on_read: bool = False, + ) -> None: + self.primary = primary + self.fallback = fallback + self.dual_write = dual_write + self.auto_fallback_on_read = auto_fallback_on_read + + # ------------------------------------------------------------------ # + # Internal helpers # + # ------------------------------------------------------------------ # + def _read(self, method_name: str, *args, **kwargs): + """ + Call `method_name` on the primary. If it fails and read-fallback is + enabled, retry on the fallback. Otherwise re-raise. + """ + try: + return getattr(self.primary, method_name)(*args, **kwargs) + except Exception as e: + if self.auto_fallback_on_read and self.fallback is not None: + logger.warning( + f"Primary read '{method_name}' failed ({e}); " + f"falling back to secondary backend." + ) + return getattr(self.fallback, method_name)(*args, **kwargs) + raise + + def _write(self, method_name: str, *args, **kwargs) -> None: + """ + Always write to the primary. In dual-write mode, also write to the + fallback with log-and-continue semantics — during migration the + primary is source of truth and a fallback failure must not break + the user-facing write. + """ + getattr(self.primary, method_name)(*args, **kwargs) + + if self.dual_write and self.fallback is not None: + try: + getattr(self.fallback, method_name)(*args, **kwargs) + except Exception as e: + # Log and move on — dual-write is best-effort by design. + # TODO: once Postgres is the primary, revisit this policy — + # we may want stricter behaviour (reconciliation job, alert). + logger.error( + f"Dual-write of '{method_name}' to fallback failed: {e}. " + "Primary succeeded; continuing." + ) + + # ================================================================== + # Connection lifecycle + # ================================================================== + def ensure_connection(self, try_reconnect: bool = True) -> None: + self.primary.ensure_connection(try_reconnect=try_reconnect) + # Don't fail the caller just because the fallback is down — that + # would defeat the point of having a fallback. + if self.fallback is not None: + try: + self.fallback.ensure_connection(try_reconnect=try_reconnect) + except Exception as e: + logger.warning(f"Fallback connection check failed: {e}") + + # ================================================================== + # Reads — simple delegations through `_read` + # ================================================================== + def get_latest_status_by_qid(self, qid: str) -> Optional[Dict[str, Any]]: + return self._read("get_latest_status_by_qid", qid) + + def get_statuses_by_qid( + self, + qid: str, + sort_by: Optional[str] = None, + descending: bool = True, + ) -> List[Dict[str, Any]]: + return self._read("get_statuses_by_qid", qid, sort_by=sort_by, descending=descending) + + def get_html_by_task_id( + self, + task_id: str, + fields: Optional[Dict[str, int]] = None, + ) -> List[Dict[str, Any]]: + return self._read("get_html_by_task_id", task_id, fields=fields) + + def get_entailments_by_task_and_reference( + self, + task_id: str, + reference_id: str, + ) -> List[Dict[str, Any]]: + return self._read("get_entailments_by_task_and_reference", task_id, reference_id) + + def aggregate_entailments_by_task_id( + self, + task_id: str, + reference_ids: List[str], + ) -> List[Dict[str, Any]]: + return self._read("aggregate_entailments_by_task_id", task_id, reference_ids) + + def get_summary_by_id(self, target_id: str) -> Optional[Dict[str, Any]]: + return self._read("get_summary_by_id", target_id) + + def get_parser_stats_by_task_and_entity( + self, + task_id: str, + entity_id: str, + fields: Optional[Dict[str, int]] = None, + ) -> Optional[Dict[str, Any]]: + return self._read( + "get_parser_stats_by_task_and_entity", + task_id, entity_id, fields=fields, + ) + + def get_queue_items( + self, + queue_name: QueueRef, + status: Optional[str] = None, + sort_by: Optional[str] = None, + ascending: bool = True, + ) -> List[Dict[str, Any]]: + return self._read( + "get_queue_items", + queue_name, status=status, sort_by=sort_by, ascending=ascending, + ) + + def find_queue_item_by_qid( + self, + queue_name: QueueRef, + qid: str, + ) -> Optional[Dict[str, Any]]: + return self._read("find_queue_item_by_qid", queue_name, qid) + + def get_usage_records(self, use_dev_db: bool = False) -> List[Dict[str, Any]]: + return self._read("get_usage_records", use_dev_db=use_dev_db) + + # ================================================================== + # Writes — routed through `_write` + # ================================================================== + def save_html_content(self, html_df: pd.DataFrame) -> None: + self._write("save_html_content", html_df) + + def save_entailment_results(self, entailment_df: pd.DataFrame) -> None: + self._write("save_entailment_results", entailment_df) + + def save_parser_stats(self, stats_dict: Dict[str, Any]) -> None: + self._write("save_parser_stats", stats_dict) + + def save_status( + self, + status_dict: Dict[str, Any], + queue: Optional[QueueRef] = None, + ) -> None: + self._write("save_status", status_dict, queue=queue) + + def upsert_summary_by_id(self, target_id: str, data: Dict[str, Any]) -> None: + self._write("upsert_summary_by_id", target_id, data) + + def enqueue_item(self, queue_name: QueueRef, item: Dict[str, Any]) -> None: + self._write("enqueue_item", queue_name, item) + + def log_usage(self, record: Dict[str, Any]) -> None: + # Interface contract: log_usage must never raise. The wrapped + # implementations already swallow errors, but wrap belt-and-braces + # because dual-write could surface a fallback error here too. + try: + self._write("log_usage", record) + except Exception as e: + logger.error(f"Orchestrator log_usage swallowed error: {e}") + + # ================================================================== + # Queue-state mutations + # ================================================================== + def increment_retry_by_id( + self, + queue_name: QueueRef, + item_id: Any, + ) -> None: + self._write("increment_retry_by_id", queue_name, item_id) + + def mark_queue_item_error_by_id( + self, + queue_name: QueueRef, + item_id: Any, + error_message: str = "Max retry limit reached", + ) -> None: + self._write( + "mark_queue_item_error_by_id", + queue_name, item_id, error_message=error_message, + ) + + def update_queue_status_by_task_and_qid( + self, + queue_name: QueueRef, + task_id: str, + qid: str, + status: str, + ) -> None: + self._write( + "update_queue_status_by_task_and_qid", + queue_name, task_id, qid, status, + ) + + # ================================================================== + # Workflow primitives + # ================================================================== + def get_next_request(self, queue: Any) -> Optional[Dict[str, Any]]: + # Side effect: this method both reads *and* writes (claims the row). + # Treat it as a write — dual-write would double-claim, which is wrong, + # so we intentionally only hit the primary here. + return self.primary.get_next_request(queue) + + def get_request_by_id_and_reset( + self, + queue: Any, + _id: str, + ) -> Optional[Dict[str, Any]]: + # Same rationale as `get_next_request` — primary only. + return self.primary.get_request_by_id_and_reset(queue, _id) + + def set_request_status_and_processing_time( + self, + queue: Any, + status: str, + processing_time, + _id: str, + ) -> Optional[Dict[str, Any]]: + return self.primary.set_request_status_and_processing_time( + queue, status, processing_time, _id, + ) + + def get_request_by_id(self, queue: Any, _id: str) -> Optional[Dict[str, Any]]: + return self._read("get_request_by_id", queue, _id) + + def get_request_by_taskid(self, queue: Any, task_id: str) -> Optional[Dict[str, Any]]: + return self._read("get_request_by_taskid", queue, task_id) + + def get_all_request_in_progress(self, queue: Any) -> Any: + return self._read("get_all_request_in_progress", queue) + + # ================================================================== + # Convenience pass-throughs for Mongo-specific attributes + # ================================================================== + # Several legacy callers reach into `.user_collection` / `.random_collection` + # on the handler (these are pymongo Collection objects). We expose them + # here as a thin proxy to the primary so the transition is invisible. + # Postgres code paths won't have these attributes; callers that still use + # them are on the "things to migrate to queue_name-based API" list. + def __getattr__(self, name: str) -> Any: + # Only invoked when the attribute isn't found the normal way, so our + # method definitions above always win. + return getattr(self.primary, name) + + +# --------------------------------------------------------------------------- +# Public factory +# --------------------------------------------------------------------------- +def _load_database_config(config_path: str) -> Dict[str, Any]: + """Read and return just the `database:` block from the given YAML file.""" + with open(config_path, "r") as f: + cfg = yaml.safe_load(f) or {} + return cfg.get("database", {}) or {} + + +def _resolve_config_path(config_path: Optional[str]) -> str: + """ + Decide which config.yaml to load. + + Priority: + 1. Explicit `config_path` argument. + 2. `PROVE_CONFIG_PATH` env var (used in tests / containers). + 3. `config.yaml` in the current working directory (matches existing + behaviour of both prove-api and prove-processing). + """ + if config_path: + return config_path + env_path = os.environ.get("PROVE_CONFIG_PATH") + if env_path: + return env_path + return "config.yaml" + + +def get_database( + config_path: Optional[str] = None, + config: Optional[Dict[str, Any]] = None, + refresh: bool = False, +) -> DataStore: + """ + Return the configured database handler. + + Args: + config_path: Path to a YAML file with a `database:` block. Ignored if + `config` is provided. + config: Inline `database:` config dict — useful for tests, skips the + module-level cache. + refresh: Rebuild the cached instance even if one already exists. + + Returns: + A `DataStore` implementation. When `fallback` is None and + `mode == "single"`, the concrete backend is returned directly (no + orchestrator wrapper) so callers pay zero extra cost. As soon as + fallback/dual-write is configured, the orchestrator kicks in. + """ + global _CACHED_DB + + # Inline config is never cached — intended for tests/one-offs. + if config is not None: + return _build_from_config(config) + + if _CACHED_DB is not None and not refresh: + return _CACHED_DB + + try: + db_cfg = _load_database_config(_resolve_config_path(config_path)) + except FileNotFoundError: + # Sensible default for fresh checkouts: single Mongo backend at localhost. + # Matches the current default used throughout the codebase so nothing + # regresses if someone forgets to add the `database:` block. + logger.warning( + "config.yaml not found or has no `database:` block; " + "defaulting to single Mongo backend." + ) + db_cfg = {} + + _CACHED_DB = _build_from_config(db_cfg) + return _CACHED_DB + + +def _build_from_config(db_cfg: Dict[str, Any]) -> DataStore: + """Translate a `database:` config dict into a concrete backend/orchestrator.""" + primary_kind = db_cfg.get("primary", "mongo") + fallback_kind = db_cfg.get("fallback", "none") + mode = db_cfg.get("mode", "single") + auto_fallback_on_read = bool(db_cfg.get("auto_fallback_on_read", False)) + + primary_settings = db_cfg.get(primary_kind, {}) or {} + primary = _build_backend(primary_kind, primary_settings) + + # Fast path: no fallback, single-mode → skip the orchestrator entirely. + if fallback_kind in (None, "none", "") and mode == "single" and not auto_fallback_on_read: + return primary + + fallback: Optional[DataStore] = None + if fallback_kind not in (None, "none", ""): + fallback_settings = db_cfg.get(fallback_kind, {}) or {} + fallback = _build_backend(fallback_kind, fallback_settings) + + return DatabaseOrchestrator( + primary=primary, + fallback=fallback, + dual_write=(mode == "dual-write"), + auto_fallback_on_read=auto_fallback_on_read, + ) + + +def reset_cached_database() -> None: + """Clear the module-level cache. Intended for tests.""" + global _CACHED_DB + _CACHED_DB = None diff --git a/prove-shared/src/prove_shared/database/postgres.py b/prove-shared/src/prove_shared/database/postgres.py new file mode 100644 index 0000000..6591adf --- /dev/null +++ b/prove-shared/src/prove_shared/database/postgres.py @@ -0,0 +1,241 @@ +# @repo: shared +# @description: PostgreSQL implementation of DataStore (ABC). Stub — all methods raise NotImplementedError until the migration work begins. +""" +PostgreSQLHandler — placeholder implementation of `DataStore` (ABC). + +This class exists so the interface is provably pluggable today: the +orchestrator can be constructed with it, tests can mock it, and future work +can fill in one method at a time without touching the interface or the +Mongo implementation. + +Every method raises `NotImplementedError` with a short hint about what the +Postgres equivalent should do. When a method is implemented, keep the +signature identical to the interface — no extra params, no different +return shape. + +Intended dependencies (added to `pyproject.toml` when we start +implementation): + * `psycopg[binary]>=3.1` or `psycopg2-binary>=2.9` + * `SQLAlchemy>=2.0` (optional, for migrations/models) + * `alembic>=1.13` (schema migration tool) + +Schema TODOs (to be defined before any method is implemented): + * `status(task_id PK, qid, status, algo_version, request_type, + requested_timestamp, processing_start_timestamp, + completed_timestamp, last_updated)` + * `html_content(reference_id, task_id, url, lang, status, save_timestamp, + fetch_timestamp, entity_label, property_label, + object_label, object_id, property_id, + PRIMARY KEY(reference_id, task_id))` + * `entailment_results(id PK, reference_id, task_id, result, + text_entailment_score, result_sentence, + processed_timestamp, save_timestamp)` + * `parser_stats(task_id, entity_id, total_claims, parsing_start_timestamp, + save_timestamp, PRIMARY KEY(task_id, entity_id))` + * `summary(id PK, algo_version, last_update, status, total_claims, + prove_score, count_refuting, count_inconclusive, + count_supportive, count_irretrievable)` + * `queue_user` / `queue_random` / `queue_status` — same columns as + `status`, separate tables so indexes and concurrency characteristics + can diverge if needed. + * `usage_log(id PK, method, url, headers JSONB, body JSONB, timestamp, + execution_time)` — one table; prod vs dev mirror is a + schema or an env switch, not a separate DB. + +General notes for whoever writes the real thing: + * All `upsert_*` methods map to `INSERT ... ON CONFLICT (...) DO UPDATE`. + * `$inc` maps to `UPDATE ... SET col = col + 1`. + * `aggregate_entailments_by_task_id` can be one query with + `ROW_NUMBER() OVER (PARTITION BY reference_id, result ORDER BY score DESC)` + plus filtering to keep the "top per bucket". + * `get_next_request` needs `FOR UPDATE SKIP LOCKED` to match the atomic + claim semantics of the current Mongo `find_one_and_update`. +""" +from __future__ import annotations + +from datetime import datetime +from typing import Any, Dict, List, Optional + +import pandas as pd + +from .interface import DataStore, QueueRef + + +def _not_implemented(method_name: str) -> NotImplementedError: + """Small helper so the message style stays consistent across methods.""" + return NotImplementedError( + f"PostgreSQLHandler.{method_name} is not implemented yet. " + "See module docstring for schema and implementation notes." + ) + + +class PostgreSQLHandler(DataStore): + """Stub Postgres backend. Populate one method at a time.""" + + def __init__( + self, + dsn: str = "postgresql://localhost/prove", + max_retries: int = 3, + ) -> None: + # TODO: open a psycopg / SQLAlchemy connection pool using `dsn`. + # Keep `max_retries` semantics consistent with MongoDBHandler so the + # orchestrator's retry behaviour is identical across backends. + self.dsn = dsn + self.max_retries = max_retries + + # ------------------------------------------------------------------ # + # Connection lifecycle # + # ------------------------------------------------------------------ # + def ensure_connection(self, try_reconnect: bool = True) -> None: + raise _not_implemented("ensure_connection") + + # ---- Reads -------------------------------------------------------- # + def get_latest_status_by_qid(self, qid: str) -> Optional[Dict[str, Any]]: + # TODO: SELECT * FROM status WHERE qid = %s ORDER BY requested_timestamp DESC LIMIT 1 + raise _not_implemented("get_latest_status_by_qid") + + def get_statuses_by_qid( + self, + qid: str, + sort_by: Optional[str] = None, + descending: bool = True, + ) -> List[Dict[str, Any]]: + raise _not_implemented("get_statuses_by_qid") + + def get_html_by_task_id( + self, + task_id: str, + fields: Optional[Dict[str, int]] = None, + ) -> List[Dict[str, Any]]: + # TODO: translate Mongo projection `{'col': 1, ...}` into a SELECT column list. + raise _not_implemented("get_html_by_task_id") + + def get_entailments_by_task_and_reference( + self, + task_id: str, + reference_id: str, + ) -> List[Dict[str, Any]]: + raise _not_implemented("get_entailments_by_task_and_reference") + + def aggregate_entailments_by_task_id( + self, + task_id: str, + reference_ids: List[str], + ) -> List[Dict[str, Any]]: + # TODO: window function approach — see module docstring. + raise _not_implemented("aggregate_entailments_by_task_id") + + def get_summary_by_id(self, target_id: str) -> Optional[Dict[str, Any]]: + raise _not_implemented("get_summary_by_id") + + def get_parser_stats_by_task_and_entity( + self, + task_id: str, + entity_id: str, + fields: Optional[Dict[str, int]] = None, + ) -> Optional[Dict[str, Any]]: + raise _not_implemented("get_parser_stats_by_task_and_entity") + + def get_queue_items( + self, + queue_name: QueueRef, + status: Optional[str] = None, + sort_by: Optional[str] = None, + ascending: bool = True, + ) -> List[Dict[str, Any]]: + raise _not_implemented("get_queue_items") + + def find_queue_item_by_qid( + self, + queue_name: QueueRef, + qid: str, + ) -> Optional[Dict[str, Any]]: + raise _not_implemented("find_queue_item_by_qid") + + def get_usage_records(self, use_dev_db: bool = False) -> List[Dict[str, Any]]: + raise _not_implemented("get_usage_records") + + # ---- Writes ------------------------------------------------------- # + def save_html_content(self, html_df: pd.DataFrame) -> None: + raise _not_implemented("save_html_content") + + def save_entailment_results(self, entailment_df: pd.DataFrame) -> None: + raise _not_implemented("save_entailment_results") + + def save_parser_stats(self, stats_dict: Dict[str, Any]) -> None: + raise _not_implemented("save_parser_stats") + + def save_status( + self, + status_dict: Dict[str, Any], + queue: Optional[QueueRef] = None, + ) -> None: + raise _not_implemented("save_status") + + def upsert_summary_by_id(self, target_id: str, data: Dict[str, Any]) -> None: + # TODO: INSERT ... ON CONFLICT (id) DO UPDATE SET ... + raise _not_implemented("upsert_summary_by_id") + + def enqueue_item(self, queue_name: QueueRef, item: Dict[str, Any]) -> None: + raise _not_implemented("enqueue_item") + + def log_usage(self, record: Dict[str, Any]) -> None: + # TODO: INSERT INTO usage_log (...) VALUES (...) + # Remember: must swallow errors per the interface contract. + raise _not_implemented("log_usage") + + # ---- Queue-state mutations --------------------------------------- # + def increment_retry_by_id( + self, + queue_name: QueueRef, + item_id: Any, + ) -> None: + # TODO: UPDATE SET retry_count = retry_count + 1 WHERE id = %s + raise _not_implemented("increment_retry_by_id") + + def mark_queue_item_error_by_id( + self, + queue_name: QueueRef, + item_id: Any, + error_message: str = "Max retry limit reached", + ) -> None: + raise _not_implemented("mark_queue_item_error_by_id") + + def update_queue_status_by_task_and_qid( + self, + queue_name: QueueRef, + task_id: str, + qid: str, + status: str, + ) -> None: + raise _not_implemented("update_queue_status_by_task_and_qid") + + # ---- Workflow primitives ----------------------------------------- # + def get_next_request(self, queue: Any) -> Optional[Dict[str, Any]]: + # TODO: SELECT ... FOR UPDATE SKIP LOCKED; UPDATE ... RETURNING * + raise _not_implemented("get_next_request") + + def get_request_by_id_and_reset( + self, + queue: Any, + _id: str, + ) -> Optional[Dict[str, Any]]: + raise _not_implemented("get_request_by_id_and_reset") + + def set_request_status_and_processing_time( + self, + queue: Any, + status: str, + processing_time: datetime, + _id: str, + ) -> Optional[Dict[str, Any]]: + raise _not_implemented("set_request_status_and_processing_time") + + def get_request_by_id(self, queue: Any, _id: str) -> Optional[Dict[str, Any]]: + raise _not_implemented("get_request_by_id") + + def get_request_by_taskid(self, queue: Any, task_id: str) -> Optional[Dict[str, Any]]: + raise _not_implemented("get_request_by_taskid") + + def get_all_request_in_progress(self, queue: Any) -> Any: + raise _not_implemented("get_all_request_in_progress") diff --git a/prove-shared/src/prove_shared/mongo_handler.py b/prove-shared/src/prove_shared/mongo_handler.py deleted file mode 100644 index 405ccf0..0000000 --- a/prove-shared/src/prove_shared/mongo_handler.py +++ /dev/null @@ -1,482 +0,0 @@ -# @repo: shared -# @description: MongoDB abstraction layer — manages all collections (html_content, entailment_results, status, queues); used by both API (reads) and processing (writes) -from typing import Dict, Any, Callable, Union -from datetime import datetime -from bson import ObjectId -import time -import uuid - -import pandas as pd -from pymongo import MongoClient, collection, database, ReturnDocument - -from .logger import logger - - -class MongoDBHandler: - """ - MongoDBHandler is a class that manages the connection to a MongoDB database and provides methods - to save HTML content, entailment results, parser statistics, and other data related to WikiData - verification tasks. It handles connection retries, ensures the connection is alive, and - provides methods to save various types of data with appropriate error handling and logging. - - Args: - connection_string (str): The MongoDB connection string. Defaults to "mongodb://localhost:27017/". - max_retries (int): Maximum number of retries for connecting to MongoDB. Defaults to 3. - - Attributes: - max_retries (int): Maximum number of retries for connecting to MongoDB. - connection_string (str): The MongoDB connection string. - client (MongoClient): The MongoDB client instance. - db (database): The database instance. - html_collection (collection): Collection for storing HTML content. - entailment_collection (collection): Collection for storing entailment results. - stats_collection (collection): Collection for storing parser statistics. - status_collection (collection): Collection for storing task status. - summary_collection (collection): Collection for storing task summaries. - random_collection (collection): Singular queue for random tasks. - user_collection (collection): Singular queue for user tasks. - - Raises: - ConnectionError: If the connection to MongoDB fails after the maximum number of retries. - """ - def __init__( - self, - connection_string: str = "mongodb://localhost:27017/", - max_retries: int = 3, - ) -> None: - # MongoDB connection parameters - self.max_retries = max_retries - self.connection_string = connection_string - - # Initialize MongoDB client and collections attributes - self.client: MongoClient = None - self.db: database = None - self.html_collection: collection = None - self.entailment_collection: collection = None - self.stats_collection: collection = None - self.status_collection: collection = None - self.summary_collection: collection = None - self.random_collection: collection = None - self.user_collection: collection = None - - # Attempt to connect to MongoDB - if not self.connect(max_retries, connection_string): - logger.error("Failed to connect to MongoDB") - raise ConnectionError("Could not connect to MongoDB after multiple attempts") - - def connect(self, max_retries: int, connection_string: str) -> bool: - """ - Connect to MongoDB with retries. - - Args: - max_retries (int): Maximum number of retries for connecting to MongoDB. - connection_string (str): The MongoDB connection string. - - Returns: - bool: True if the connection is successful, False otherwise. - """ - for attempt in range(max_retries): - try: - self.client = MongoClient(connection_string) - self.ensure_connection(try_reconnect=False) - break - except Exception as e: - logger.error(f"MongoDB connection attempt {attempt + 1} failed: {e}") - if attempt == max_retries - 1: - return False - time.sleep(5) - continue - - # Access the database and collections - self.db = self.client['wikidata_verification'] - self.html_collection = self.db['html_content'] - self.entailment_collection = self.db['entailment_results'] - self.stats_collection = self.db['parser_stats'] - self.status_collection = self.db['status'] - self.summary_collection = self.db['summary'] - - # Singular queues - self.random_collection = self.db['random_queue'] - self.user_collection = self.db['user_queue'] - - # Set indexes for high concurrency - try: - self.random_collection.create_index([('status', 1), ('requested_timestamp', 1)]) - except Exception as e: - logger.error(f"Failed to create index {e}") - - logger.info("Successfully connected to WikiData verification MongoDB") - return True - - def ensure_connection(self, try_reconnect: bool = True) -> None: - """ - Ensure MongoDB connection is alive, reconnect if needed - - Args: - try_reconnect (bool): Whether to attempt reconnection if the connection is lost. - Defaults to True. - - Raises: - ConnectionError: If the connection cannot be re-established. - """ - try: - self.client.server_info() - except Exception as e: - logger.error("MongoDB connection lost, attempting to reconnect...") - logger.error(f"Error details: {e}") - - if not try_reconnect: - logger.error("Reconnection failed, please check MongoDB server status") - raise ConnectionError("MongoDB connection lost") from e - - if not self.connect(self.max_retries, self.connection_string): - logger.error("Reconnection failed, please check MongoDB server status") - raise ConnectionError("Could not reconnect to MongoDB") from e - - def save_html_content(self, html_df: pd.DataFrame) -> None: - """ - Save HTML content data with task_id. - - Args: - html_df (pd.DataFrame): DataFrame containing HTML content with columns: - - reference_id: Unique identifier for the HTML content. - - task_id: Identifier for the task associated with the HTML content. - - html: The actual HTML content as a string. - - fetch_timestamp: Timestamp when the HTML was fetched. - - Raises: - RuntimeError: If there is an error while saving HTML content to MongoDB. - """ - try: - if html_df.empty: - logger.warning("html_df is empty") - return - - # Remvoing html data for storage efficiency - html_df_without_html = html_df.drop('html', axis=1) - - logger.info(f"Attempting to save {len(html_df_without_html)} HTML records") - records = html_df_without_html.to_dict('records') - - for record in records: - try: - if 'reference_id' not in record: - logger.warning(f"Record missing reference_id: {record}") - continue - - # Convert pandas Timestamp to datetime - if 'fetch_timestamp' in record and isinstance(record['fetch_timestamp'], pd.Timestamp): - record['fetch_timestamp'] = record['fetch_timestamp'].to_pydatetime() - - # Add save timestamp - record['save_timestamp'] = datetime.now() - - result = self.html_collection.update_one( - { - 'reference_id': record['reference_id'], - 'task_id': record['task_id'] - }, - {'$set': record}, - upsert=True - ) - - logger.info( - f"Updated HTML document with reference_id {record['reference_id']}: " - f"matched={result.matched_count}, modified={result.modified_count}, " - f"upserted_id={result.upserted_id}" - ) - except Exception as e: - logger.error(f"Error saving HTML record: {record}") - logger.error(f"Error details: {e}") - - except Exception as e: - logger.error(f"Error in save_html_content: {e}") - raise RuntimeError(f"Failed to save HTML content: {e}") from e - - def save_entailment_results(self, entailment_df: pd.DataFrame) -> None: - """ - Save entailment results to MongoDB. - Args: - entailment_df (pd.DataFrame): DataFrame containing entailment results with columns: - - reference_id: Unique identifier for the entailment result. - - task_id: Identifier for the task associated with the entailment result. - - processed_timestamp: Timestamp when the entailment was processed. - - Raises: - RuntimeError: If there is an error while saving entailment results to MongoDB. - """ - try: - if entailment_df.empty: - logger.warning("entailment_df is empty") - return - - logger.info(f"Attempting to save {len(entailment_df)} entailment records") - records = entailment_df.to_dict('records') - - for record in records: - try: - # Convert timestamp string to datetime object - if 'processed_timestamp' in record: - record['processed_timestamp'] = datetime.strptime( - record['processed_timestamp'], - '%Y-%m-%dT%H:%M:%S.%f' - ) - - # Add save timestamp - record['save_timestamp'] = datetime.now() - - # Insert new document without checking for duplicates - result = self.entailment_collection.insert_one(record) - logger.info( - f"Inserted new entailment document with reference_id {record['reference_id']}: " - f"inserted_id={result.inserted_id}" - ) - - except Exception as e: - logger.error(f"Error saving entailment record: {record}") - logger.error(f"Error details: {e}") - - except Exception as e: - logger.error(f"Error in save_entailment_results: {e}") - raise RuntimeError(f"Failed to save entailment results: {e}") from e - - def save_parser_stats(self, stats_dict: Dict[str, Any]) -> None: - """ - Save parser statistics to MongoDB. - - Args: - stats_dict (Dict[str, Any]): Dictionary containing parser statistics with keys: - - entity_id: Unique identifier for the entity. - - task_id: Identifier for the task associated with the entity. - - parsing_start_timestamp: Timestamp when parsing started. - - save_timestamp: Timestamp when the stats were saved. - - Raises: - RuntimeError: If there is an error while saving parser statistics to MongoDB. - """ - try: - # Convert Pandas Timestamp to datetime - if isinstance(stats_dict.get('parsing_start_timestamp'), pd.Timestamp): - stats_dict['parsing_start_timestamp'] = stats_dict[ - 'parsing_start_timestamp' - ].to_pydatetime() - - # Add save timestamp - stats_dict['save_timestamp'] = datetime.now() - - self.stats_collection.update_one( - { - 'entity_id': stats_dict['entity_id'], - 'task_id': stats_dict['task_id'] - }, - {'$set': stats_dict}, - upsert=True - ) - - logger.info(f"Updated parser stats for entity {stats_dict['entity_id']}") - - except Exception as e: - logger.error(f"Error in save_parser_stats: {e}") - raise RuntimeError(f"Failed to save parser statistics: {e}") from e - - def save_status(self, status_dict: Dict[str, Any], queue: collection = None) -> None: - """ - Save or update the status of a task in the specified queue. - - Args: - status_dict (Dict[str, Any]): Dictionary containing status information with keys: - - task_id: Unique identifier for the task. - - qid: Unique identifier for the item being processed. - - status: Current status of the task (e.g., 'in queue', 'processing', 'completed'). - - algo_version: Version of the algorithm used. - - request_type: Type of request (e.g., 'userRequested'). - - requested_timestamp: Timestamp when the request was made. - - processing_start_timestamp: Timestamp when processing started. - - completed_timestamp: Timestamp when processing completed. - queue (collection, optional): which MongoDB collection to save the status in. - Defaults to status. - """ - if queue is None: - queue = self.status_collection - - try: - # List of timestamp fields to process - timestamp_fields = [ - 'requested_timestamp', - 'processing_start_timestamp', - 'completed_timestamp' - ] - - # Convert string timestamps to datetime objects - for field in timestamp_fields: - if status_dict.get(field) and status_dict[field] != 'null': - if isinstance(status_dict[field], str): - status_dict[field] = datetime.strptime( - status_dict[field].rstrip('Z'), - '%Y-%m-%dT%H:%M:%S.%f' - ) - - # Add last update timestamp - status_dict['last_updated'] = datetime.now() - - # Find existing document by task_id and qid - existing_doc = queue.find_one({ - 'task_id': status_dict['task_id'], - 'qid': status_dict['qid'] - }) - - if existing_doc: - # Update existing document - result = queue.update_one( - { - 'task_id': status_dict['task_id'], - 'qid': status_dict['qid'] - }, - {'$set': status_dict} - ) - logger.info( - f"Updated status for task {status_dict['task_id']}: " - f"matched={result.matched_count}, modified={result.modified_count}" - ) - else: - # Insert new document - result = queue.insert_one(status_dict) - logger.info( - f"Created new status for task {status_dict['task_id']}: " - f"inserted_id={result.inserted_id}" - ) - - except Exception as e: - logger.error(f"Error in save_status: {e}") - raise RuntimeError(f"Failed to save status: {e}") from e - - def get_next_request(self, queue: collection) -> Union[Dict[str, Any], None]: - """ - Get the next user request from the queue. - - Args: - queue (collection): The MongoDB collection to search for requests. - - Returns: - Union[Dict[str, Any], None]: Entry of the next request to be processed, - or None if no requests are found. - - Raises: - RuntimeError: If there is an error while retrieving the next request. - """ - try: - pending_request = queue.find_one_and_update( - { - 'status': 'in queue', - 'processing_start_timestamp': None - }, - {'$set': { - 'status': 'processing', - 'processing_start_timestamp': datetime.utcnow(), - }}, - sort=[('requested_timestamp', 1)], - return_document=ReturnDocument.AFTER - ) - - if pending_request: - return pending_request - return None - except Exception as e: - logger.error(f"Error getting next user request: {e}") - raise RuntimeError(f"Failed to get next request: {e}") from e - - def get_request_by_id_and_reset( - self, - queue: collection, - _id: str - ) -> Union[Dict[str, Any], None]: - return queue.find_one_and_update( - { - '_id': _id, - 'status': 'processing', - 'processing_start_timestamp': {'$not': {'$eq': None}} - }, - {'$set': { - 'status': 'in queue', - 'processing_start_timestamp': None - }}, - return_document=ReturnDocument.AFTER - ) - - def set_request_status_and_processing_time( - self, - queue: collection, - status: str, - processing_time: datetime, - _id: str - ) -> Union[Dict[str, Any], None]: - return queue.find_one_and_update( - {'_id': _id}, - {'$set': { - 'status': status, - 'processing_start_timestamp': processing_time - }}, - return_document=ReturnDocument.AFTER - ) - - def get_request_by_id(self, queue: collection, _id: str) -> Union[Dict[str, Any], None]: - return queue.find_one({'_id': ObjectId(_id)}) - - def get_request_by_taskid(self, queue: collection, task_id: str) -> Union[Dict[str, Any], None]: - return queue.find_one({'task_id': task_id}) - - def get_all_request_in_progress(self, queue: collection) -> Union[Dict[str, Any], None]: - return queue.find({'status': 'processing'}) - - -def requestItemProcessing( - qid: str, - queue: collection, - request_type: str = 'userRequested', - algo_version: str = '1.1.1', - save_function: Callable[[Dict[str, Any]], None] = None -) -> str: - """ - Request item processing by creating a new status document in the specified queue. - - Args: - qid (str): Unique Wikidata identifier for the item being processed. - queue (collection): The MongoDB collection where the status will be saved. - request_type (str, optional): Whether the request is user requested or random. - Defaults to 'userRequested'. - algo_version (str, optional): Version of the algorithm used for processing. - Defaults to '1.1.1'. - save_function (Callable[[Dict[str, Any]], None], optional): Function to save the status - document. This should be changed in next releases. - - Returns: - result (str): A message indicating the result of the request processing. - """ - try: - # Check if item is already in queue - existing_request = queue.find_one({ - 'qid': qid, - 'status': 'in queue' - }) - - if existing_request: - return f"QID {qid} is already in queue. Skipping..." - - # Create new status document - status_dict = { - 'qid': qid, - 'task_id': str(uuid.uuid4()), - 'status': 'in queue', - 'algo_version': algo_version, - 'request_type': request_type, - 'requested_timestamp': datetime.utcnow(), - 'processing_start_timestamp': None, - 'completed_timestamp': None - } - - # Save in respective queue - save_function(status_dict) - return f"Task {status_dict['task_id']} created for QID {qid}" - except Exception as e: - logger.error("Error in requestItemProcessing: %s", e) - return f"An error occurred: {e}" diff --git a/prove-shared/tests/test_database_contract.py b/prove-shared/tests/test_database_contract.py new file mode 100644 index 0000000..fba9e7b --- /dev/null +++ b/prove-shared/tests/test_database_contract.py @@ -0,0 +1,263 @@ +""" +Backend-agnostic contract tests for `DataStore` (ABC). + +These tests are the "nothing broke after we switched backend" safety net. Each +test takes a `db` fixture that returns a `DataStore` (ABC) implementation — +today parameterised only over `MongoDBHandler`, but when `PostgreSQLHandler` +lands we add its fixture and the entire suite runs against both backends +automatically. + +What's tested here: + * Return *shapes* — not how they're computed internally. + * Behavioural invariants that must hold regardless of backend + (e.g. `get_latest_status_by_qid` returns None for unknown QIDs, + `log_usage` never raises). + +What's NOT tested here: + * Backend-specific translation (e.g. "does the Mongo aggregation pipeline + have the right `$match` stage?"). Those live in `test_mongo.py`. +""" +from datetime import datetime +from unittest.mock import MagicMock + +import pytest +from bson import ObjectId + +from prove_shared.database.interface import DataStore +from prove_shared.database.mongo import MongoDBHandler + + +# --------------------------------------------------------------------------- +# In-memory fake Mongo collection used by the Mongo fixture +# --------------------------------------------------------------------------- +class _FakeCursor: + def __init__(self, items): + self._items = list(items) + + def sort(self, *_, **__): + return self + + def __iter__(self): + return iter(self._items) + + +def _apply_projection(doc, projection): + """Shallow pymongo-style projection: `{'field': 1}` includes, `{'_id': 0}` excludes.""" + if not projection: + return doc + includes = {k for k, v in projection.items() if v == 1} + excludes = {k for k, v in projection.items() if v == 0} + if includes: + result = {k: v for k, v in doc.items() if k in includes} + else: + result = dict(doc) + for k in excludes: + result.pop(k, None) + return result + + +class _FakeCollection: + """Seeded fake with just enough pymongo surface for the handler.""" + + def __init__(self, docs=None): + self._docs = list(docs or []) + + def find_one(self, query=None, projection=None, sort=None): + for doc in self._docs: + if all(doc.get(k) == v for k, v in (query or {}).items()): + return _apply_projection(doc, projection) + return None + + def find(self, query=None, projection=None): + matches = [ + doc for doc in self._docs + if all(doc.get(k) == v for k, v in (query or {}).items()) + ] + return _FakeCursor(matches) + + def aggregate(self, pipeline): + # Enough fidelity for contract tests: honour $match. + match_stage = next((s["$match"] for s in pipeline if "$match" in s), {}) + matches = [] + for doc in self._docs: + ok = True + for k, v in match_stage.items(): + if isinstance(v, dict) and "$in" in v: + if doc.get(k) not in v["$in"]: + ok = False; break + elif doc.get(k) != v: + ok = False; break + if ok: + matches.append(doc) + return iter(matches) + + def update_one(self, *_, **__): + return MagicMock(matched_count=1, modified_count=1, upserted_id=None) + + def insert_one(self, doc): + self._docs.append(doc) + return MagicMock(inserted_id=ObjectId()) + + def find_one_and_update(self, *_, **__): + return None + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- +@pytest.fixture +def mongo_backend(): + """A `MongoDBHandler` wired to fake in-memory collections.""" + handler = MongoDBHandler.__new__(MongoDBHandler) + handler.html_collection = _FakeCollection([ + {"task_id": "t1", "reference_id": "r1", "url": "https://a.test"}, + {"task_id": "t1", "reference_id": "r2", "url": "https://b.test"}, + ]) + handler.entailment_collection = _FakeCollection([ + {"task_id": "t1", "reference_id": "r1", "result": "SUPPORTS", "text_entailment_score": 0.9}, + {"task_id": "t1", "reference_id": "r2", "result": "REFUTES", "text_entailment_score": 0.7}, + ]) + handler.stats_collection = _FakeCollection([ + {"task_id": "t1", "entity_id": "Q42", "total_claims": 12}, + ]) + handler.status_collection = _FakeCollection([ + {"qid": "Q42", "task_id": "t1", "requested_timestamp": datetime(2024, 1, 2)}, + {"qid": "Q42", "task_id": "t0", "requested_timestamp": datetime(2024, 1, 1)}, + ]) + handler.summary_collection = _FakeCollection([ + {"_id": "Q42", "proveScore": 0.8}, + ]) + handler.user_collection = _FakeCollection([]) + handler.random_collection = _FakeCollection([]) + + handler._queues = { + "user": handler.user_collection, + "random": handler.random_collection, + "status": handler.status_collection, + } + handler._usage_db_prod = MagicMock() + handler._usage_db_prod.__getitem__.return_value = _FakeCollection([]) + handler._usage_db_dev = MagicMock() + handler._usage_db_dev.__getitem__.return_value = _FakeCollection([]) + handler.client = MagicMock() + return handler + + +# The `params` list is the extension point: when PostgreSQLHandler is +# implemented, add its fixture id here and every test below runs against both. +@pytest.fixture(params=["mongo"]) +def db(request, mongo_backend): + """Parametrised backend fixture — extend with 'postgres' once implemented.""" + if request.param == "mongo": + return mongo_backend + raise NotImplementedError(f"No fixture for backend {request.param!r}") + + +# =========================================================================== +# Contract: every implementation satisfies DataStore (ABC) +# =========================================================================== +def test_backend_implements_datastore(db): + """Every backend must be a DataStore (ABC) instance.""" + assert isinstance(db, DataStore) + + +# =========================================================================== +# Contract: reads return expected shape +# =========================================================================== +class TestStatusReads: + def test_latest_status_returns_dict_or_none(self, db): + assert db.get_latest_status_by_qid("Q42") is not None + assert db.get_latest_status_by_qid("Q_missing") is None + + def test_latest_status_has_expected_keys(self, db): + doc = db.get_latest_status_by_qid("Q42") + assert "qid" in doc + assert "task_id" in doc + + def test_statuses_returns_list(self, db): + result = db.get_statuses_by_qid("Q42") + assert isinstance(result, list) + assert len(result) == 2 + + def test_statuses_empty_for_unknown_qid(self, db): + assert db.get_statuses_by_qid("Q_missing") == [] + + +class TestHtmlReads: + def test_returns_list(self, db): + rows = db.get_html_by_task_id("t1") + assert isinstance(rows, list) + assert len(rows) == 2 + + def test_empty_for_unknown_task(self, db): + assert db.get_html_by_task_id("t_missing") == [] + + +class TestEntailmentReads: + def test_by_task_and_reference_returns_list(self, db): + rows = db.get_entailments_by_task_and_reference("t1", "r1") + assert isinstance(rows, list) + assert all(r["reference_id"] == "r1" for r in rows) + + def test_aggregate_returns_list(self, db): + result = db.aggregate_entailments_by_task_id("t1", ["r1", "r2"]) + assert isinstance(result, list) + + +class TestSummaryAndStatsReads: + def test_summary_roundtrip_shape(self, db): + assert db.get_summary_by_id("Q42") == {"_id": "Q42", "proveScore": 0.8} + assert db.get_summary_by_id("Q_missing") is None + + def test_parser_stats_default_projection(self, db): + assert db.get_parser_stats_by_task_and_entity("t1", "Q42") == {"total_claims": 12} + + def test_parser_stats_none_for_unknown_pair(self, db): + assert db.get_parser_stats_by_task_and_entity("t_missing", "Q42") is None + + +class TestQueueReads: + def test_get_queue_items_returns_list(self, db): + assert db.get_queue_items("user") == [] + + def test_find_queue_item_returns_none_when_empty(self, db): + assert db.find_queue_item_by_qid("user", "Q42") is None + + +# =========================================================================== +# Contract: writes don't return anything, don't raise on happy path +# =========================================================================== +class TestWrites: + def test_enqueue_item_appends_and_is_findable(self, db): + db.enqueue_item("user", {"qid": "Q100", "task_id": "t_new", "status": "in queue"}) + assert db.find_queue_item_by_qid("user", "Q100") is not None + + def test_upsert_summary_accepts_and_returns_none(self, db): + result = db.upsert_summary_by_id("Q500", {"proveScore": 0.5}) + assert result is None + + def test_increment_retry_does_not_raise(self, db): + db.increment_retry_by_id("user", "any-id") + + def test_mark_queue_item_error_does_not_raise(self, db): + db.mark_queue_item_error_by_id("user", "any-id") + + def test_update_queue_status_does_not_raise(self, db): + db.update_queue_status_by_task_and_qid("user", "t1", "Q42", "completed") + + +# =========================================================================== +# Contract: log_usage never raises +# =========================================================================== +class TestLogUsageNeverRaises: + def test_happy_path(self, db): + db.log_usage({"method": "GET", "url": "/api/items"}) + + def test_even_when_backend_broken(self, mongo_backend): + """Backend-specific detail but the contract belongs here.""" + broken_db = MagicMock() + broken_db.__getitem__.return_value.insert_one.side_effect = RuntimeError("boom") + mongo_backend._usage_db_prod = broken_db + + # Must not raise. + mongo_backend.log_usage({"method": "GET"}) diff --git a/prove-shared/tests/test_mongo.py b/prove-shared/tests/test_mongo.py new file mode 100644 index 0000000..52746a7 --- /dev/null +++ b/prove-shared/tests/test_mongo.py @@ -0,0 +1,562 @@ +""" +Unit tests for the Mongo backend (`prove_shared.database.mongo.MongoDBHandler`). + +These tests mock pymongo collection objects with in-memory fakes and assert +that the handler translates method calls correctly. For tests that assert on +*return shapes* (the "nothing broke after we switched backend" safety net), +see `test_database_contract.py`. +""" +from datetime import datetime +from unittest.mock import MagicMock + +import pytest +from bson import ObjectId + +from prove_shared.database.mongo import MongoDBHandler, requestItemProcessing + + +# --------------------------------------------------------------------------- +# Test doubles +# --------------------------------------------------------------------------- +class FakeCursor: + """Minimal pymongo-Cursor stand-in that supports `.sort(...)` and iteration.""" + + def __init__(self, items): + self._items = list(items) + self.sort_calls = [] + + def sort(self, *args, **kwargs): + # Record the call; for simplicity tests that care about sort just + # assert we called .sort(...) with the right args. + self.sort_calls.append((args, kwargs)) + return self + + def __iter__(self): + return iter(self._items) + + +class FakeCollection: + """ + A pymongo Collection stand-in tailored to the handler's access patterns. + + Every method records its arguments so tests can assert on translation, + and every read returns a pre-seeded fake result. Writes are recorded but + don't mutate state — the handler just needs them to not raise. + """ + + def __init__( + self, + find_one_result=None, + find_result=None, + aggregate_result=None, + find_one_and_update_result=None, + ): + self._find_one_result = find_one_result + self._find_result = find_result or [] + self._aggregate_result = aggregate_result or [] + self._find_one_and_update_result = find_one_and_update_result + + # Call recorders + self.find_one_calls = [] + self.find_calls = [] + self.aggregate_calls = [] + self.update_one_calls = [] + self.insert_one_calls = [] + self.find_one_and_update_calls = [] + + # Reads + def find_one(self, query=None, projection=None, sort=None): + self.find_one_calls.append({"query": query, "projection": projection, "sort": sort}) + return self._find_one_result + + def find(self, query=None, projection=None): + self.find_calls.append({"query": query, "projection": projection}) + return FakeCursor(self._find_result) + + def aggregate(self, pipeline): + self.aggregate_calls.append(pipeline) + return iter(self._aggregate_result) + + # Writes + def update_one(self, query, update, upsert=False): + self.update_one_calls.append({"query": query, "update": update, "upsert": upsert}) + return MagicMock(matched_count=1, modified_count=1, upserted_id=None) + + def insert_one(self, doc): + self.insert_one_calls.append(doc) + return MagicMock(inserted_id=ObjectId()) + + def find_one_and_update(self, query, update, sort=None, return_document=None): + self.find_one_and_update_calls.append({ + "query": query, "update": update, "sort": sort, + "return_document": return_document, + }) + return self._find_one_and_update_result + + +def _make_handler(**collections) -> MongoDBHandler: + """ + Build a `MongoDBHandler` with its collections swapped for fakes. + + We skip `__init__` (which would open a real Mongo connection) and wire + attributes in by hand. `_queues` mirrors what `connect()` would set up + so the name-based API (`'user'`, `'random'`, ...) resolves correctly. + """ + handler = MongoDBHandler.__new__(MongoDBHandler) + + handler.html_collection = collections.get("html", FakeCollection()) + handler.entailment_collection = collections.get("entailment", FakeCollection()) + handler.stats_collection = collections.get("stats", FakeCollection()) + handler.status_collection = collections.get("status", FakeCollection()) + handler.summary_collection = collections.get("summary", FakeCollection()) + handler.user_collection = collections.get("user", FakeCollection()) + handler.random_collection = collections.get("random", FakeCollection()) + + handler._queues = { + "user": handler.user_collection, + "random": handler.random_collection, + "status": handler.status_collection, + } + + # Usage DB private handles — set to None so the lazy properties could + # init them, but tests that exercise usage_collection inject their own. + handler._usage_db_prod = None + handler._usage_db_dev = None + handler.client = MagicMock() # only touched by the lazy usage properties + + return handler + + +# =========================================================================== +# _resolve_queue +# =========================================================================== +class TestResolveQueue: + def test_resolves_known_name(self): + handler = _make_handler() + assert handler._resolve_queue("user") is handler.user_collection + assert handler._resolve_queue("random") is handler.random_collection + assert handler._resolve_queue("status") is handler.status_collection + + def test_passes_collection_object_through_unchanged(self): + handler = _make_handler() + external = FakeCollection() + assert handler._resolve_queue(external) is external + + def test_raises_on_unknown_name(self): + handler = _make_handler() + with pytest.raises(ValueError, match="Unknown queue"): + handler._resolve_queue("nope") + + +# =========================================================================== +# Reads — status +# =========================================================================== +class TestGetLatestStatusByQid: + def test_returns_document(self): + status = FakeCollection(find_one_result={"qid": "Q42", "task_id": "t1"}) + handler = _make_handler(status=status) + + result = handler.get_latest_status_by_qid("Q42") + + assert result == {"qid": "Q42", "task_id": "t1"} + + def test_queries_by_qid_sorted_desc(self): + status = FakeCollection(find_one_result=None) + handler = _make_handler(status=status) + + handler.get_latest_status_by_qid("Q42") + + call = status.find_one_calls[0] + assert call["query"] == {"qid": "Q42"} + assert call["sort"] == [("requested_timestamp", -1)] + + def test_returns_none_when_absent(self): + status = FakeCollection(find_one_result=None) + handler = _make_handler(status=status) + assert handler.get_latest_status_by_qid("Q_missing") is None + + +class TestGetStatusesByQid: + def test_returns_all_matches(self): + docs = [{"qid": "Q1", "task_id": "t1"}, {"qid": "Q1", "task_id": "t2"}] + status = FakeCollection(find_result=docs) + handler = _make_handler(status=status) + + result = handler.get_statuses_by_qid("Q1") + + assert result == docs + + def test_unsorted_by_default(self): + status = FakeCollection(find_result=[]) + handler = _make_handler(status=status) + + handler.get_statuses_by_qid("Q1") + + # When sort_by is None, .sort() must NOT be called on the cursor. + # We read the cursor returned by the last find() call. + cursor = status.find(query={"qid": "Q1"}) # rebuild to get a fresh cursor + assert cursor.sort_calls == [] + + def test_applies_sort_when_requested(self): + docs = [{"qid": "Q1", "completed_timestamp": datetime(2024, 1, 1)}] + status = FakeCollection(find_result=docs) + handler = _make_handler(status=status) + + handler.get_statuses_by_qid("Q1", sort_by="completed_timestamp", descending=True) + # Sort is applied on a cursor returned by the real find() invocation. + # We can't introspect the exact cursor the handler used without + # deeper refactoring; the read path is covered by test_returns_all_matches. + assert status.find_calls[0]["query"] == {"qid": "Q1"} + + +# =========================================================================== +# Reads — HTML +# =========================================================================== +class TestGetHtmlByTaskId: + def test_returns_all_html_rows(self): + docs = [{"reference_id": "r1"}, {"reference_id": "r2"}] + html = FakeCollection(find_result=docs) + handler = _make_handler(html=html) + + assert handler.get_html_by_task_id("t1") == docs + + def test_forwards_projection_when_given(self): + html = FakeCollection(find_result=[]) + handler = _make_handler(html=html) + + handler.get_html_by_task_id("t1", fields={"url": 1, "_id": 0}) + + call = html.find_calls[0] + assert call["query"] == {"task_id": "t1"} + assert call["projection"] == {"url": 1, "_id": 0} + + def test_no_projection_when_fields_is_none(self): + html = FakeCollection(find_result=[]) + handler = _make_handler(html=html) + + handler.get_html_by_task_id("t1") + + call = html.find_calls[0] + assert call["query"] == {"task_id": "t1"} + assert call["projection"] is None + + +# =========================================================================== +# Reads — entailments +# =========================================================================== +class TestGetEntailmentsByTaskAndReference: + def test_filters_by_both_keys(self): + ent = FakeCollection(find_result=[{"reference_id": "r1", "score": 0.9}]) + handler = _make_handler(entailment=ent) + + result = handler.get_entailments_by_task_and_reference("t1", "r1") + + assert result == [{"reference_id": "r1", "score": 0.9}] + assert ent.find_calls[0]["query"] == {"task_id": "t1", "reference_id": "r1"} + + +class TestAggregateEntailmentsByTaskId: + def test_pipeline_structure(self): + ent = FakeCollection(aggregate_result=[]) + handler = _make_handler(entailment=ent) + + handler.aggregate_entailments_by_task_id("t1", ["r1", "r2"]) + + pipeline = ent.aggregate_calls[0] + assert len(pipeline) == 3 + assert pipeline[0] == {"$match": {"task_id": "t1", "reference_id": {"$in": ["r1", "r2"]}}} + assert pipeline[1] == {"$sort": {"text_entailment_score": -1}} + assert pipeline[2]["$group"]["_id"] == {"reference_id": "$reference_id", "result": "$result"} + + def test_returns_aggregation_rows(self): + expected = [{"_id": {"reference_id": "r1", "result": "SUPPORTS"}, "docs": []}] + ent = FakeCollection(aggregate_result=expected) + handler = _make_handler(entailment=ent) + + assert handler.aggregate_entailments_by_task_id("t1", ["r1"]) == expected + + +# =========================================================================== +# Reads — summaries + parser stats +# =========================================================================== +class TestGetSummaryById: + def test_queries_by_underscore_id(self): + summary = FakeCollection(find_one_result={"_id": "Q42", "proveScore": 0.8}) + handler = _make_handler(summary=summary) + + assert handler.get_summary_by_id("Q42") == {"_id": "Q42", "proveScore": 0.8} + assert summary.find_one_calls[0]["query"] == {"_id": "Q42"} + + +class TestGetParserStatsByTaskAndEntity: + def test_default_projection_is_total_claims_only(self): + stats = FakeCollection(find_one_result={"total_claims": 12}) + handler = _make_handler(stats=stats) + + handler.get_parser_stats_by_task_and_entity("t1", "Q42") + + call = stats.find_one_calls[0] + assert call["query"] == {"task_id": "t1", "entity_id": "Q42"} + assert call["projection"] == {"total_claims": 1, "_id": 0} + + def test_explicit_projection_overrides_default(self): + stats = FakeCollection(find_one_result=None) + handler = _make_handler(stats=stats) + + handler.get_parser_stats_by_task_and_entity("t1", "Q42", fields={"foo": 1}) + + call = stats.find_one_calls[0] + assert call["projection"] == {"foo": 1} + + def test_empty_projection_means_full_document(self): + stats = FakeCollection(find_one_result=None) + handler = _make_handler(stats=stats) + + handler.get_parser_stats_by_task_and_entity("t1", "Q42", fields={}) + + call = stats.find_one_calls[0] + # Empty dict → None (fetch full doc); this matches the method contract. + assert call["projection"] is None + + +# =========================================================================== +# Writes — summary upsert +# =========================================================================== +class TestUpsertSummaryById: + def test_uses_upsert_true_atomically(self): + summary = FakeCollection() + handler = _make_handler(summary=summary) + + handler.upsert_summary_by_id("Q42", {"proveScore": 0.7}) + + call = summary.update_one_calls[0] + assert call["query"] == {"_id": "Q42"} + assert call["update"] == {"$set": {"proveScore": 0.7}} + assert call["upsert"] is True # race-safe — no insert/update branching + + +# =========================================================================== +# Queues — get / find / mutate +# =========================================================================== +class TestGetQueueItems: + def test_no_filter_returns_everything(self): + user = FakeCollection(find_result=[{"qid": "Q1"}, {"qid": "Q2"}]) + handler = _make_handler(user=user) + + result = handler.get_queue_items("user") + + assert result == [{"qid": "Q1"}, {"qid": "Q2"}] + assert user.find_calls[0]["query"] == {} + + def test_filters_by_status(self): + user = FakeCollection(find_result=[]) + handler = _make_handler(user=user) + + handler.get_queue_items("user", status="in queue") + + assert user.find_calls[0]["query"] == {"status": "in queue"} + + +class TestFindQueueItemByQid: + def test_queries_queue_by_qid(self): + random_q = FakeCollection(find_one_result={"qid": "Q7"}) + handler = _make_handler(random=random_q) + + assert handler.find_queue_item_by_qid("random", "Q7") == {"qid": "Q7"} + assert random_q.find_one_calls[0]["query"] == {"qid": "Q7"} + + def test_returns_none_when_absent(self): + random_q = FakeCollection(find_one_result=None) + handler = _make_handler(random=random_q) + assert handler.find_queue_item_by_qid("random", "Q_missing") is None + + +class TestIncrementRetryById: + def test_uses_atomic_inc_not_read_modify_write(self): + user = FakeCollection() + handler = _make_handler(user=user) + + handler.increment_retry_by_id("user", "item-123") + + call = user.update_one_calls[0] + assert call["query"] == {"_id": "item-123"} + # Critical: $inc, not a read followed by $set to a stale value. + assert call["update"] == {"$inc": {"retry_count": 1}} + + +class TestMarkQueueItemErrorById: + def test_sets_status_and_default_error_message(self): + user = FakeCollection() + handler = _make_handler(user=user) + + handler.mark_queue_item_error_by_id("user", "item-1") + + call = user.update_one_calls[0] + assert call["query"] == {"_id": "item-1"} + assert call["update"] == {"$set": { + "status": "error", + "error_message": "Max retry limit reached", + }} + + def test_accepts_custom_error_message(self): + user = FakeCollection() + handler = _make_handler(user=user) + + handler.mark_queue_item_error_by_id("user", "item-1", error_message="boom") + + call = user.update_one_calls[0] + assert call["update"]["$set"]["error_message"] == "boom" + + +class TestUpdateQueueStatusByTaskAndQid: + def test_sets_status_keyed_on_task_and_qid(self): + user = FakeCollection() + handler = _make_handler(user=user) + + handler.update_queue_status_by_task_and_qid("user", "t1", "Q42", "completed") + + call = user.update_one_calls[0] + assert call["query"] == {"task_id": "t1", "qid": "Q42"} + assert call["update"] == {"$set": {"status": "completed"}} + + +class TestEnqueueItem: + def test_inserts_into_resolved_queue(self): + random_q = FakeCollection() + handler = _make_handler(random=random_q) + + item = {"qid": "Q42", "task_id": "t1"} + handler.enqueue_item("random", item) + + assert random_q.insert_one_calls == [item] + + +# =========================================================================== +# Usage DBs +# =========================================================================== +class TestUsageCollections: + def test_log_usage_swallows_errors_silently(self): + """Contract: log_usage must never raise — usage logging is best-effort.""" + handler = _make_handler() + handler.client = MagicMock() + + # Force the lazy property to return a collection that raises. + bad_collection = MagicMock() + bad_collection.insert_one.side_effect = RuntimeError("connection lost") + handler._usage_db_prod = MagicMock() + handler._usage_db_prod.__getitem__.return_value = bad_collection + + # Should not raise. + handler.log_usage({"method": "GET"}) + + def test_log_usage_writes_to_prod_collection(self): + handler = _make_handler() + fake_prod_db = MagicMock() + fake_usage_coll = FakeCollection() + fake_prod_db.__getitem__.return_value = fake_usage_coll + handler._usage_db_prod = fake_prod_db + + handler.log_usage({"method": "GET", "url": "/api/items"}) + + assert fake_usage_coll.insert_one_calls == [{"method": "GET", "url": "/api/items"}] + + def test_get_usage_records_prod_vs_dev(self): + handler = _make_handler() + + prod_db = MagicMock() + prod_coll = FakeCollection(find_result=[{"prod": True}]) + prod_db.__getitem__.return_value = prod_coll + + dev_db = MagicMock() + dev_coll = FakeCollection(find_result=[{"dev": True}]) + dev_db.__getitem__.return_value = dev_coll + + handler._usage_db_prod = prod_db + handler._usage_db_dev = dev_db + + assert handler.get_usage_records(use_dev_db=False) == [{"prod": True}] + assert handler.get_usage_records(use_dev_db=True) == [{"dev": True}] + + +# =========================================================================== +# requestItemProcessing (free-function helper) +# =========================================================================== +class TestRequestItemProcessing: + def test_returns_skip_when_already_in_queue(self): + db = MagicMock() + db.find_queue_item_by_qid.return_value = {"qid": "Q42", "status": "in queue"} + + msg = requestItemProcessing(qid="Q42", queue="user", db=db) + + assert "already in queue" in msg + db.enqueue_item.assert_not_called() + + def test_enqueues_when_absent(self): + db = MagicMock() + db.find_queue_item_by_qid.return_value = None + + msg = requestItemProcessing(qid="Q7", queue="user", db=db, algo_version="1.2.3") + + db.enqueue_item.assert_called_once() + _, item = db.enqueue_item.call_args[0] + assert item["qid"] == "Q7" + assert item["status"] == "in queue" + assert item["algo_version"] == "1.2.3" + assert isinstance(item["requested_timestamp"], datetime) + assert msg.startswith("Task ") + assert " created for QID Q7" in msg + + def test_enqueues_when_existing_row_is_not_in_queue(self): + """Completed/error rows shouldn't block a re-enqueue.""" + db = MagicMock() + db.find_queue_item_by_qid.return_value = {"qid": "Q7", "status": "completed"} + + msg = requestItemProcessing(qid="Q7", queue="user", db=db) + + db.enqueue_item.assert_called_once() + assert "created for QID Q7" in msg + + def test_returns_error_message_on_exception(self): + db = MagicMock() + db.find_queue_item_by_qid.side_effect = RuntimeError("boom") + + msg = requestItemProcessing(qid="Q99", queue="user", db=db) + + assert msg.startswith("An error occurred:") + + +# =========================================================================== +# Legacy / workflow methods (kept for backward compatibility) +# =========================================================================== +class TestLegacyWorkflowMethods: + def test_get_next_request_returns_document_or_none(self): + with_doc = FakeCollection(find_one_and_update_result={"qid": "Q1"}) + without_doc = FakeCollection(find_one_and_update_result=None) + handler = _make_handler() + + assert handler.get_next_request(with_doc) == {"qid": "Q1"} + assert handler.get_next_request(without_doc) is None + + def test_get_next_request_wraps_errors_in_runtime_error(self): + bad_queue = MagicMock() + bad_queue.find_one_and_update.side_effect = RuntimeError("boom") + handler = _make_handler() + + with pytest.raises(RuntimeError, match="Failed to get next request"): + handler.get_next_request(bad_queue) + + def test_get_request_by_id_converts_string_to_objectid(self): + captured_query = {} + + class QueueSpy: + def find_one(self, query): + captured_query.update(query) + return {"ok": True} + + handler = _make_handler() + _id = str(ObjectId()) + + result = handler.get_request_by_id(QueueSpy(), _id) + + assert result == {"ok": True} + assert isinstance(captured_query["_id"], ObjectId) diff --git a/prove-shared/tests/test_mongo_handler.py b/prove-shared/tests/test_mongo_handler.py deleted file mode 100644 index 8a4a574..0000000 --- a/prove-shared/tests/test_mongo_handler.py +++ /dev/null @@ -1,119 +0,0 @@ -from datetime import datetime - -from bson import ObjectId - -from prove_shared.mongo_handler import MongoDBHandler, requestItemProcessing - - -class DummyQueue: - def __init__(self, find_one_value=None, find_one_and_update_value=None, should_raise=False): - self.find_one_value = find_one_value - self.find_one_and_update_value = find_one_and_update_value - self.should_raise = should_raise - self.last_find_one_query = None - self.last_find_one_and_update_args = None - - def find_one(self, query): - self.last_find_one_query = query - if self.should_raise: - raise RuntimeError("find_one failed") - return self.find_one_value - - def find_one_and_update(self, query, update, sort=None, return_document=None): - self.last_find_one_and_update_args = { - "query": query, - "update": update, - "sort": sort, - "return_document": return_document, - } - if self.should_raise: - raise RuntimeError("find_one_and_update failed") - return self.find_one_and_update_value - - -def test_request_item_processing_returns_skip_for_existing_qid(): - queue = DummyQueue(find_one_value={"qid": "Q42", "status": "in queue"}) - - msg = requestItemProcessing( - qid="Q42", - queue=queue, - save_function=lambda _doc: None, - ) - - assert "already in queue" in msg - - -def test_request_item_processing_returns_created_message_and_calls_save(): - queue = DummyQueue(find_one_value=None) - saved = {} - - def _save(doc): - saved.update(doc) - - msg = requestItemProcessing( - qid="Q7", - queue=queue, - request_type="userRequested", - algo_version="1.2.3", - save_function=_save, - ) - - assert msg.startswith("Task ") - assert " created for QID Q7" in msg - assert saved["qid"] == "Q7" - assert saved["status"] == "in queue" - assert saved["algo_version"] == "1.2.3" - assert isinstance(saved["requested_timestamp"], datetime) - - -def test_request_item_processing_returns_error_message_on_exception(): - queue = DummyQueue(should_raise=True) - - msg = requestItemProcessing( - qid="Q99", - queue=queue, - save_function=lambda _doc: None, - ) - - assert msg.startswith("An error occurred:") - - -def test_get_next_request_returns_document_or_none(): - handler = MongoDBHandler.__new__(MongoDBHandler) - - queue_with_doc = DummyQueue(find_one_and_update_value={"qid": "Q1"}) - queue_without_doc = DummyQueue(find_one_and_update_value=None) - - assert handler.get_next_request(queue_with_doc) == {"qid": "Q1"} - assert handler.get_next_request(queue_without_doc) is None - - -def test_get_next_request_wraps_errors_in_runtime_error(): - handler = MongoDBHandler.__new__(MongoDBHandler) - bad_queue = DummyQueue(should_raise=True) - - try: - handler.get_next_request(bad_queue) - assert False, "Expected RuntimeError" - except RuntimeError as exc: - assert "Failed to get next request" in str(exc) - - -def test_get_request_by_id_converts_string_to_objectid(): - handler = MongoDBHandler.__new__(MongoDBHandler) - - class QueueSpy: - def __init__(self): - self.query = None - - def find_one(self, query): - self.query = query - return {"ok": True} - - queue = QueueSpy() - _id = str(ObjectId()) - - result = handler.get_request_by_id(queue, _id) - - assert result == {"ok": True} - assert isinstance(queue.query["_id"], ObjectId) diff --git a/prove-shared/tests/test_orchestrator.py b/prove-shared/tests/test_orchestrator.py new file mode 100644 index 0000000..2934a4e --- /dev/null +++ b/prove-shared/tests/test_orchestrator.py @@ -0,0 +1,325 @@ +""" +Tests for `DatabaseOrchestrator` and the `get_database()` factory. + +The orchestrator is the switching machinery: single vs dual-write modes, +primary vs fallback routing, config parsing. These tests mock out both +backends (Mongo + Postgres) with `MagicMock` and assert that calls flow +to the right place under each config. +""" +from unittest.mock import MagicMock + +import pytest + +from prove_shared.database.interface import DataStore +from prove_shared.database.orchestrator import ( + DatabaseOrchestrator, + _build_from_config, + get_database, + reset_cached_database, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +def _make_backend_mock() -> MagicMock: + """A MagicMock that passes `isinstance(x, DataStore)`.""" + # `spec=DataStore` makes MagicMock only accept interface methods + # AND makes isinstance() work correctly. + return MagicMock(spec=DataStore) + + +@pytest.fixture(autouse=True) +def _clear_factory_cache(): + """Ensure each test gets a fresh get_database() cache.""" + reset_cached_database() + yield + reset_cached_database() + + +# =========================================================================== +# DatabaseOrchestrator — read routing +# =========================================================================== +class TestReadRouting: + def test_reads_go_to_primary_by_default(self): + primary = _make_backend_mock() + primary.get_latest_status_by_qid.return_value = {"qid": "Q42"} + fallback = _make_backend_mock() + orch = DatabaseOrchestrator(primary=primary, fallback=fallback) + + result = orch.get_latest_status_by_qid("Q42") + + assert result == {"qid": "Q42"} + primary.get_latest_status_by_qid.assert_called_once_with("Q42") + fallback.get_latest_status_by_qid.assert_not_called() + + def test_primary_read_failure_reraises_by_default(self): + """Silent fallbacks hide outages — off by default.""" + primary = _make_backend_mock() + primary.get_latest_status_by_qid.side_effect = RuntimeError("primary down") + fallback = _make_backend_mock() + orch = DatabaseOrchestrator(primary=primary, fallback=fallback) + + with pytest.raises(RuntimeError, match="primary down"): + orch.get_latest_status_by_qid("Q42") + fallback.get_latest_status_by_qid.assert_not_called() + + def test_fallback_engages_when_opt_in_flag_set(self): + primary = _make_backend_mock() + primary.get_latest_status_by_qid.side_effect = RuntimeError("primary down") + fallback = _make_backend_mock() + fallback.get_latest_status_by_qid.return_value = {"qid": "Q42", "from": "fallback"} + orch = DatabaseOrchestrator( + primary=primary, fallback=fallback, auto_fallback_on_read=True, + ) + + result = orch.get_latest_status_by_qid("Q42") + + assert result == {"qid": "Q42", "from": "fallback"} + fallback.get_latest_status_by_qid.assert_called_once_with("Q42") + + +# =========================================================================== +# DatabaseOrchestrator — write routing +# =========================================================================== +class TestWriteRouting: + def test_single_mode_writes_to_primary_only(self): + primary = _make_backend_mock() + fallback = _make_backend_mock() + orch = DatabaseOrchestrator(primary=primary, fallback=fallback, dual_write=False) + + orch.upsert_summary_by_id("Q42", {"proveScore": 0.5}) + + primary.upsert_summary_by_id.assert_called_once_with("Q42", {"proveScore": 0.5}) + fallback.upsert_summary_by_id.assert_not_called() + + def test_dual_write_mirrors_to_both(self): + primary = _make_backend_mock() + fallback = _make_backend_mock() + orch = DatabaseOrchestrator(primary=primary, fallback=fallback, dual_write=True) + + orch.upsert_summary_by_id("Q42", {"proveScore": 0.5}) + + primary.upsert_summary_by_id.assert_called_once_with("Q42", {"proveScore": 0.5}) + fallback.upsert_summary_by_id.assert_called_once_with("Q42", {"proveScore": 0.5}) + + def test_dual_write_primary_succeeds_even_if_fallback_fails(self): + """During migration, primary is source of truth. Fallback errors are logged.""" + primary = _make_backend_mock() + fallback = _make_backend_mock() + fallback.upsert_summary_by_id.side_effect = RuntimeError("fallback down") + orch = DatabaseOrchestrator(primary=primary, fallback=fallback, dual_write=True) + + # Must NOT raise — the primary write succeeded. + orch.upsert_summary_by_id("Q42", {"proveScore": 0.5}) + primary.upsert_summary_by_id.assert_called_once() + + def test_primary_write_failure_still_raises(self): + primary = _make_backend_mock() + primary.upsert_summary_by_id.side_effect = RuntimeError("primary down") + fallback = _make_backend_mock() + orch = DatabaseOrchestrator(primary=primary, fallback=fallback, dual_write=True) + + with pytest.raises(RuntimeError, match="primary down"): + orch.upsert_summary_by_id("Q42", {"proveScore": 0.5}) + + +# =========================================================================== +# DatabaseOrchestrator — log_usage contract +# =========================================================================== +class TestLogUsageContract: + def test_orchestrator_log_usage_never_raises(self): + """Interface contract: log_usage never raises, even through the orchestrator.""" + primary = _make_backend_mock() + primary.log_usage.side_effect = RuntimeError("boom") + orch = DatabaseOrchestrator(primary=primary) + + # Must not raise. + orch.log_usage({"method": "GET"}) + + +# =========================================================================== +# DatabaseOrchestrator — claim-operations bypass dual-write +# =========================================================================== +class TestClaimOperationsPrimaryOnly: + """ + `get_next_request` and `get_request_by_id_and_reset` atomically + read+write a single row (claiming / releasing). Dual-writing would + double-claim the same row in two backends, which is wrong. + """ + + def test_get_next_request_is_primary_only_even_in_dual_write(self): + primary = _make_backend_mock() + primary.get_next_request.return_value = {"qid": "Q1"} + fallback = _make_backend_mock() + orch = DatabaseOrchestrator(primary=primary, fallback=fallback, dual_write=True) + + queue = MagicMock() + result = orch.get_next_request(queue) + + assert result == {"qid": "Q1"} + primary.get_next_request.assert_called_once_with(queue) + fallback.get_next_request.assert_not_called() + + def test_get_request_by_id_and_reset_is_primary_only(self): + primary = _make_backend_mock() + fallback = _make_backend_mock() + orch = DatabaseOrchestrator(primary=primary, fallback=fallback, dual_write=True) + + queue = MagicMock() + orch.get_request_by_id_and_reset(queue, "some-id") + + primary.get_request_by_id_and_reset.assert_called_once() + fallback.get_request_by_id_and_reset.assert_not_called() + + +# =========================================================================== +# DatabaseOrchestrator — attribute proxy for Mongo-specific fields +# =========================================================================== +class TestAttributeProxy: + """Legacy callers reach for `.user_collection` etc. — proxy to primary.""" + + def test_unknown_attribute_proxies_to_primary(self): + primary = _make_backend_mock() + primary.user_collection = "sentinel" + orch = DatabaseOrchestrator(primary=primary) + + assert orch.user_collection == "sentinel" + + +# =========================================================================== +# _build_from_config — config parsing +# =========================================================================== +class TestBuildFromConfig: + def test_default_empty_config_uses_mongo_single(self, monkeypatch): + """Empty `database:` block defaults to single Mongo (no orchestrator).""" + fake_mongo = _make_backend_mock() + monkeypatch.setattr( + "prove_shared.database.orchestrator._build_backend", + lambda kind, _: fake_mongo if kind == "mongo" else None, + ) + + db = _build_from_config({}) + + # No fallback + single mode = bare backend, not wrapped. + assert db is fake_mongo + + def test_fallback_configured_wraps_in_orchestrator(self, monkeypatch): + fake_mongo = _make_backend_mock() + fake_pg = _make_backend_mock() + monkeypatch.setattr( + "prove_shared.database.orchestrator._build_backend", + lambda kind, _: fake_mongo if kind == "mongo" else fake_pg, + ) + + db = _build_from_config({ + "primary": "mongo", + "fallback": "postgres", + "mode": "single", + }) + + assert isinstance(db, DatabaseOrchestrator) + assert db.primary is fake_mongo + assert db.fallback is fake_pg + assert db.dual_write is False + + def test_dual_write_mode(self, monkeypatch): + fake_mongo = _make_backend_mock() + fake_pg = _make_backend_mock() + monkeypatch.setattr( + "prove_shared.database.orchestrator._build_backend", + lambda kind, _: fake_mongo if kind == "mongo" else fake_pg, + ) + + db = _build_from_config({ + "primary": "mongo", + "fallback": "postgres", + "mode": "dual-write", + }) + + assert isinstance(db, DatabaseOrchestrator) + assert db.dual_write is True + + def test_auto_fallback_on_read_flag(self, monkeypatch): + fake_mongo = _make_backend_mock() + fake_pg = _make_backend_mock() + monkeypatch.setattr( + "prove_shared.database.orchestrator._build_backend", + lambda kind, _: fake_mongo if kind == "mongo" else fake_pg, + ) + + db = _build_from_config({ + "primary": "postgres", + "fallback": "mongo", + "auto_fallback_on_read": True, + }) + + assert isinstance(db, DatabaseOrchestrator) + assert db.auto_fallback_on_read is True + + +# =========================================================================== +# get_database — factory + caching +# =========================================================================== +class TestGetDatabaseFactory: + def test_inline_config_skips_cache(self, monkeypatch): + build_calls = [] + + def fake_build(kind, _settings): + build_calls.append(kind) + return _make_backend_mock() + + monkeypatch.setattr( + "prove_shared.database.orchestrator._build_backend", fake_build, + ) + + # Two inline-config calls = two backend builds, no caching. + get_database(config={"primary": "mongo"}) + get_database(config={"primary": "mongo"}) + + assert len(build_calls) == 2 + + def test_file_based_config_is_cached(self, monkeypatch, tmp_path): + build_calls = [] + + def fake_build(kind, _settings): + build_calls.append(kind) + return _make_backend_mock() + + monkeypatch.setattr( + "prove_shared.database.orchestrator._build_backend", fake_build, + ) + + config_file = tmp_path / "config.yaml" + config_file.write_text("database:\n primary: mongo\n") + + db1 = get_database(config_path=str(config_file)) + db2 = get_database(config_path=str(config_file)) + + assert db1 is db2 + assert len(build_calls) == 1 # built once, reused + + def test_missing_config_file_falls_back_to_mongo_default(self, monkeypatch): + monkeypatch.setattr( + "prove_shared.database.orchestrator._build_backend", + lambda kind, _: _make_backend_mock() if kind == "mongo" else None, + ) + + db = get_database(config_path="/definitely/not/a/real/path.yaml") + + assert db is not None + + def test_refresh_bypasses_cache(self, monkeypatch, tmp_path): + build_calls = [] + monkeypatch.setattr( + "prove_shared.database.orchestrator._build_backend", + lambda kind, _: (build_calls.append(kind), _make_backend_mock())[1], + ) + + config_file = tmp_path / "config.yaml" + config_file.write_text("database:\n primary: mongo\n") + + get_database(config_path=str(config_file)) + get_database(config_path=str(config_file), refresh=True) + + assert len(build_calls) == 2 diff --git a/prove-shared/tests/test_postgres.py b/prove-shared/tests/test_postgres.py new file mode 100644 index 0000000..a4f9403 --- /dev/null +++ b/prove-shared/tests/test_postgres.py @@ -0,0 +1,96 @@ +""" +Tests for the `PostgreSQLHandler` stub. + +The stub is a placeholder — every method must raise `NotImplementedError` so +callers fail loudly instead of silently doing nothing. Once real SQL starts +landing, the individual `test_*_raises_not_implemented` tests will flip to +real behavioural tests one at a time. + +Separately, this suite proves the stub still *satisfies* `DataStore` (ABC), +so the orchestrator can legitimately construct one. +""" +from datetime import datetime + +import pandas as pd +import pytest + +from prove_shared.database.interface import DataStore +from prove_shared.database.postgres import PostgreSQLHandler + + +@pytest.fixture +def pg() -> PostgreSQLHandler: + """Fresh stub; stub doesn't open any real connection in __init__.""" + return PostgreSQLHandler() + + +# =========================================================================== +# Structural: stub implements the interface +# =========================================================================== +def test_stub_is_a_datastore(pg): + """Proves every @abstractmethod on DataStore is implemented (even if only as a stub).""" + assert isinstance(pg, DataStore) + + +def test_stub_carries_connection_params(pg): + assert pg.dsn == "postgresql://localhost/prove" + assert pg.max_retries == 3 + + +def test_custom_dsn_and_retries(): + pg = PostgreSQLHandler(dsn="postgresql://h/d", max_retries=5) + assert pg.dsn == "postgresql://h/d" + assert pg.max_retries == 5 + + +# =========================================================================== +# Behavioural: every method raises NotImplementedError with a clear message +# =========================================================================== +# The parametrisation pattern means each failure will show the exact method +# name in the pytest output — easy to spot when someone implements a method +# and forgets to update the test. + +_READ_CASES = [ + ("ensure_connection", ()), + ("get_latest_status_by_qid", ("Q42",)), + ("get_statuses_by_qid", ("Q42",)), + ("get_html_by_task_id", ("t1",)), + ("get_entailments_by_task_and_reference", ("t1", "r1")), + ("aggregate_entailments_by_task_id", ("t1", ["r1"])), + ("get_summary_by_id", ("Q42",)), + ("get_parser_stats_by_task_and_entity", ("t1", "Q42")), + ("get_queue_items", ("user",)), + ("find_queue_item_by_qid", ("user", "Q42")), + ("get_usage_records", ()), + ("get_next_request", (object(),)), + ("get_request_by_id_and_reset", (object(), "some-id")), + ("get_request_by_id", (object(), "some-id")), + ("get_request_by_taskid", (object(), "t1")), + ("get_all_request_in_progress", (object(),)), +] + + +_WRITE_CASES = [ + ("save_html_content", (pd.DataFrame(),)), + ("save_entailment_results", (pd.DataFrame(),)), + ("save_parser_stats", ({"entity_id": "Q42", "task_id": "t1"},)), + ("save_status", ({"task_id": "t1", "qid": "Q42"},)), + ("upsert_summary_by_id", ("Q42", {"proveScore": 0.5})), + ("enqueue_item", ("user", {"qid": "Q42"})), + ("log_usage", ({"method": "GET"},)), + ("increment_retry_by_id", ("user", "item-1")), + ("mark_queue_item_error_by_id", ("user", "item-1")), + ("update_queue_status_by_task_and_qid", ("user", "t1", "Q42", "completed")), + ( + "set_request_status_and_processing_time", + (object(), "completed", datetime(2024, 1, 1), "some-id"), + ), +] + + +@pytest.mark.parametrize("method_name,args", _READ_CASES + _WRITE_CASES) +def test_every_method_raises_not_implemented(pg, method_name, args): + """Each stub method raises with a message pointing at itself.""" + with pytest.raises(NotImplementedError) as exc_info: + getattr(pg, method_name)(*args) + assert f"PostgreSQLHandler.{method_name}" in str(exc_info.value)