diff --git a/src/pds/registrysweepers/utils/db/__init__.py b/src/pds/registrysweepers/utils/db/__init__.py
index 0ec5abd..8bb0d7a 100644
--- a/src/pds/registrysweepers/utils/db/__init__.py
+++ b/src/pds/registrysweepers/utils/db/__init__.py
@@ -1,7 +1,9 @@
+import functools
 import json
 import logging
 import sys
 import urllib.parse
+from datetime import timedelta
 from typing import Callable
 from typing import Dict
 from typing import Iterable
@@ -10,9 +12,9 @@
 from typing import Optional
 
 import requests
+from opensearchpy import OpenSearch
 from pds.registrysweepers.utils.db.host import Host
 from pds.registrysweepers.utils.db.update import Update
-from pds.registrysweepers.utils.misc import auto_raise_for_status
 from pds.registrysweepers.utils.misc import get_random_hex_id
 from requests import HTTPError
 from retry import retry
@@ -22,7 +24,7 @@
 def query_registry_db(
-    host: Host,
+    client: OpenSearch,
     query: Dict,
     _source: Dict,
     index_name: str = "registry",
@@ -30,22 +32,15 @@
     scroll_keepalive_minutes: int = 10,
 ) -> Iterable[Dict]:
     """
-    Given an OpenSearch host and query/_source, return an iterable collection of hits
+    Given an OpenSearch client and query/_source, return an iterable collection of hits
 
-    Example query: {"bool": {"must": [{"terms": {"ops:Tracking_Meta/ops:archive_status": ["archived", "certified"]}}]}}
+    Example query: {"query": {"bool": {"must": [{"terms": {"ops:Tracking_Meta/ops:archive_status": ["archived", "certified"]}}]}}}
     Example _source: {"includes": ["lidvid"]}
     """
-    req_content = {
-        "query": query,
-        "_source": _source,
-        "size": page_size,
-    }
-
+    scroll_keepalive = f"{scroll_keepalive_minutes}m"
     query_id = get_random_hex_id()  # This is just used to differentiate queries during logging
-    log.info(f"Initiating query with id {query_id}: {req_content}")
-
-    path = f"{index_name}/_search?scroll={scroll_keepalive_minutes}m"
+    log.info(f"Initiating query with id {query_id}: {query}")
 
     served_hits = 0
 
@@ -53,28 +48,35 @@
     log.info(f"Query {query_id} progress: 0%")
 
     more_data_exists = True
+    scroll_id = None
     while more_data_exists:
-        resp = retry_call(
-            auto_raise_for_status(requests.get),
-            fargs=[urllib.parse.urljoin(host.url, path)],
-            fkwargs={"auth": (host.username, host.password), "verify": host.verify, "json": req_content},
-            exceptions=(HTTPError, RuntimeError),
+        if scroll_id is None:
+            fetch_func = lambda: client.search(
+                index=index_name,
+                body=query,
+                scroll=scroll_keepalive,
+                size=page_size,
+                _source_includes=_source.get("includes", []),  # TODO: Break out from the enclosing _source object
+                _source_excludes=_source.get("excludes", []),  # TODO: Break out from the enclosing _source object
+            )
+        else:
+            fetch_func = lambda: client.scroll(scroll_id=scroll_id, scroll=scroll_keepalive)
+
+        results = retry_call(
+            fetch_func,
             tries=6,
             delay=2,
             backoff=2,
             logger=log,
         )
+        scroll_id = results.get("_scroll_id")
 
-        data = resp.json()
-        path = "_search/scroll"
-        req_content = {"scroll": f"{scroll_keepalive_minutes}m", "scroll_id": data["_scroll_id"]}
-
-        total_hits = data["hits"]["total"]["value"]
+        total_hits = results["hits"]["total"]["value"]
         log.debug(
             f" paging query {query_id} ({served_hits} to {min(served_hits + page_size, total_hits)} of {total_hits})"
         )
 
-        response_hits = data["hits"]["hits"]
+        response_hits = results["hits"]["hits"]
         for hit in response_hits:
             served_hits += 1
@@ -92,24 +94,20 @@
         hits_data_present_in_response = len(response_hits) > 0
         if not hits_data_present_in_response:
             log.error(
-                f"Response for query {query_id} contained no hits when hits were expected. Returned data is incomplete (got {served_hits} of {total_hits} total hits). Response was: {data}"
+                f"Response for query {query_id} contained no hits when hits were expected. Returned data is incomplete (got {served_hits} of {total_hits} total hits). Response was: {results}"
             )
             break
 
-        more_data_exists = served_hits < data["hits"]["total"]["value"]
+        more_data_exists = served_hits < results["hits"]["total"]["value"]
 
-    # TODO: Determine if the following block is actually necessary
-    if "scroll_id" in req_content:
-        path = f'_search/scroll/{req_content["scroll_id"]}'
-        retry_call(
-            auto_raise_for_status(requests.delete),
-            fargs=[urllib.parse.urljoin(host.url, path)],
-            fkwargs={"auth": (host.username, host.password), "verify": host.verify},
-            tries=6,
-            delay=2,
-            backoff=2,
-            logger=log,
-        )
+    retry_call(
+        client.clear_scroll,
+        fkwargs={"scroll_id": scroll_id},
+        tries=6,
+        delay=2,
+        backoff=2,
+        logger=log,
+    )
 
     log.info(f"Query {query_id} complete!")
@@ -118,7 +116,7 @@ def query_registry_db_or_mock(mock_f: Optional[Callable[[str], Iterable[Dict]]],
     if mock_f is not None:
 
         def mock_wrapper(
-            host: Host,
+            client: OpenSearch,
             query: Dict,
             _source: Dict,
             index_name: str = "registry",
@@ -132,7 +130,7 @@ def mock_wrapper(
         return query_registry_db
 
 
-def write_updated_docs(host: Host, updates: Iterable[Update], index_name: str = "registry"):
+def write_updated_docs(client: OpenSearch, updates: Iterable[Update], index_name: str = "registry"):
     log.info("Updating a lazily-generated collection of product documents...")
 
     updated_doc_count = 0
@@ -145,7 +143,7 @@ def write_updated_docs(host: Host, updates: Iterable[Update], index_name: str =
             log.info(
                 f"Bulk update buffer has reached {bulk_buffer_max_size_mb}MB threshold - writing {pending_product_count} document updates to db..."
             )
-            _write_bulk_updates_chunk(host, index_name, bulk_updates_buffer)
+            _write_bulk_updates_chunk(client, index_name, bulk_updates_buffer)
             bulk_updates_buffer = []
             bulk_buffer_size_mb = 0.0
@@ -158,11 +156,10 @@ def write_updated_docs(host: Host, updates: Iterable[Update], index_name: str =
         updated_doc_count += 1
 
     remaining_products_to_write_count = int(len(bulk_updates_buffer) / 2)
-    updated_doc_count += remaining_products_to_write_count
 
     if len(bulk_updates_buffer) > 0:
         log.info(f"Writing documents updates for {remaining_products_to_write_count} remaining products to db...")
-        _write_bulk_updates_chunk(host, index_name, bulk_updates_buffer)
+        _write_bulk_updates_chunk(client, index_name, bulk_updates_buffer)
 
     log.info(f"Updated documents for {updated_doc_count} total products!")
@@ -174,25 +171,12 @@ def update_as_statements(update: Update) -> Iterable[str]:
     return updates_strs
 
 
-@retry(exceptions=(HTTPError, RuntimeError), tries=6, delay=2, backoff=2, logger=log)
-def _write_bulk_updates_chunk(host: Host, index_name: str, bulk_updates: Iterable[str]):
-    headers = {"Content-Type": "application/x-ndjson"}
-    path = f"{index_name}/_bulk"
-
+@retry(tries=6, delay=2, backoff=2, logger=log)
+def _write_bulk_updates_chunk(client: OpenSearch, index_name: str, bulk_updates: Iterable[str]):
    bulk_data = "\n".join(bulk_updates) + "\n"
 
-    response = requests.put(
-        urllib.parse.urljoin(host.url, path),
-        auth=(host.username, host.password),
-        data=bulk_data,
-        headers=headers,
-        verify=host.verify,
-    )
+    response_content = client.bulk(index=index_name, body=bulk_data)
 
-    # N.B. HTTP status 200 is insufficient as a success check for _bulk API.
-    # See: https://github.com/elastic/elasticsearch/issues/41434
-    response.raise_for_status()
-    response_content = response.json()
     if response_content.get("errors"):
         warn_types = {"document_missing_exception"}  # these types represent bad data, not bad sweepers behaviour
         items_with_problems = [item for item in response_content["items"] if "error" in item["update"]]
@@ -239,16 +223,18 @@ def aggregate_update_error_types(items: Iterable[Dict]) -> Mapping[str, Dict[str
     return agg
 
 
-def get_extant_lidvids(host: Host) -> Iterable[str]:
+def get_extant_lidvids(client: OpenSearch) -> Iterable[str]:
     """
     Given an OpenSearch host, return all extant LIDVIDs
     """
 
     log.info("Retrieving extant LIDVIDs")
 
-    query = {"bool": {"must": [{"terms": {"ops:Tracking_Meta/ops:archive_status": ["archived", "certified"]}}]}}
+    query = {
+        "query": {"bool": {"must": [{"terms": {"ops:Tracking_Meta/ops:archive_status": ["archived", "certified"]}}]}}
+    }
     _source = {"includes": ["lidvid"]}
 
-    results = query_registry_db(host, query, _source, scroll_keepalive_minutes=1)
+    results = query_registry_db(client, query, _source, scroll_keepalive_minutes=1)
 
     return map(lambda doc: doc["_source"]["lidvid"], results)
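
After this change, callers pass an opensearchpy.OpenSearch client rather than a Host, and query is now a full request body (wrapped in a top-level "query" key) rather than a bare query clause. A minimal usage sketch follows; the endpoint, credentials, and client options are illustrative assumptions, not values taken from this patch, and only the query_registry_db / get_extant_lidvids signatures come from the code above.

from opensearchpy import OpenSearch

from pds.registrysweepers.utils.db import get_extant_lidvids, query_registry_db

# Assumed connection details, for illustration only; real deployments will differ.
client = OpenSearch(
    hosts=[{"host": "localhost", "port": 9200}],
    http_auth=("admin", "admin"),
    use_ssl=True,
    verify_certs=True,
)

# Full request body, including the top-level "query" key passed straight to client.search()
query = {
    "query": {"bool": {"must": [{"terms": {"ops:Tracking_Meta/ops:archive_status": ["archived", "certified"]}}]}}
}

# Stream hits from the scrolled search; _source limits the fields returned per hit
for hit in query_registry_db(client, query, _source={"includes": ["lidvid"]}, page_size=500):
    print(hit["_source"]["lidvid"])

# Or use the existing convenience helper for this common case
lidvids = list(get_extant_lidvids(client))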