diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 000000000..356602581 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1 @@ +# Add pyupgrade diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 69c62bf93..770cb4bc2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,10 +1,12 @@ -exclude: Pipfile\.lock|migrations|\.idea +exclude: Pipfile\.lock|migrations|\.idea|node_modules|archive repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.0.1 hooks: - id: trailing-whitespace + args: [--markdown-linebreak-ext=md] + - id: check-docstring-first - id: check-executables-have-shebangs - id: check-json - id: check-case-conflict @@ -19,6 +21,7 @@ repos: args: - --remove - id: pretty-format-json + exclude: package(-lock)?\.json args: - --autofix - id: requirements-txt-fixer @@ -29,6 +32,26 @@ repos: - id: isort files: \.py$ exclude: ^build/.*$|^.tox/.*$|^venv/.*$ + args: + - --lines-after-imports=2 + - --multi-line=3 + - --trailing-comma + - --force-grid-wrap=0 + - --use-parentheses + - --ensure-newline-before-comments + - --line-length=88 + + - repo: https://github.com/asottile/pyupgrade + rev: v3.2.2 + hooks: + - id: pyupgrade + args: + - --py310-plus + + - repo: https://github.com/PyCQA/flake8 + rev: 3.9.2 + hooks: + - id: flake8 - repo: https://github.com/ambv/black rev: 22.3.0 @@ -37,16 +60,23 @@ repos: args: - --safe - - repo: https://github.com/PyCQA/flake8 - rev: 3.9.2 + - repo: https://github.com/pre-commit/mirrors-eslint + rev: v8.3.0 hooks: - - id: flake8 + - id: eslint + files: ^js/.*$ + additional_dependencies: + - eslint@8.3.0 + - eslint-config-prettier@8.3.0 - repo: https://github.com/pre-commit/mirrors-prettier rev: v2.5.0 hooks: + - id: prettier + files: ^js/.*$ - id: prettier types: [yaml] + - repo: https://github.com/koalaman/shellcheck-precommit rev: v0.8.0 hooks: diff --git a/api/catalog/api/constants/field_order.py b/api/catalog/api/constants/field_order.py index a67f4b17f..5b37e634c 100644 --- a/api/catalog/api/constants/field_order.py +++ b/api/catalog/api/constants/field_order.py @@ -45,4 +45,4 @@ field_position_map: dict[str, int] = { field: idx for idx, field in enumerate(json_fields) } -"""mapping of JSON fields to their sort positions""" +#: mapping of JSON fields to their sort positions diff --git a/api/catalog/api/controllers/search_controller.py b/api/catalog/api/controllers/search_controller.py index bacae6b7b..99fb71aaf 100644 --- a/api/catalog/api/controllers/search_controller.py +++ b/api/catalog/api/controllers/search_controller.py @@ -5,7 +5,7 @@ import pprint from itertools import accumulate from math import ceil -from typing import Any, List, Literal, Optional, Tuple +from typing import Any, Literal from django.conf import settings from django.core.cache import cache @@ -52,7 +52,7 @@ def _unmasked_query_end(page_size, page): def _paginate_with_dead_link_mask( s: Search, page_size: int, page: int -) -> Tuple[int, int]: +) -> tuple[int, int]: """ Given a query, a page and page_size, return the start and end of the slice of results. @@ -117,8 +117,8 @@ def _paginate_with_dead_link_mask( def _get_query_slice( - s: Search, page_size: int, page: int, filter_dead: Optional[bool] = False -) -> Tuple[int, int]: + s: Search, page_size: int, page: int, filter_dead: bool | None = False +) -> tuple[int, int]: """ Select the start and end of the search results for this query. """ @@ -147,7 +147,7 @@ def _quote_escape(query_string): def _post_process_results( s, start, end, page_size, search_results, request, filter_dead -) -> Optional[List[Hit]]: +) -> list[Hit] | None: """ After fetching the search results from the back end, iterate through the results, perform image validation, and route certain thumbnails through our @@ -222,7 +222,7 @@ def _apply_filter( # Any is used here to avoid a circular import search_params: Any, # MediaSearchRequestSerializer serializer_field: str, - es_field: Optional[str] = None, + es_field: str | None = None, behaviour: Literal["filter", "exclude"] = "filter", ): """ @@ -285,7 +285,7 @@ def search( request: Request, filter_dead: bool, page: int = 1, -) -> Tuple[List[Hit], int, int]: +) -> tuple[list[Hit], int, int]: """ Given a set of keywords and an optional set of filters, perform a ranked paginated search. @@ -489,8 +489,8 @@ def get_sources(index): def _get_result_and_page_count( - response_obj: Response, results: Optional[List[Hit]], page_size: int -) -> Tuple[int, int]: + response_obj: Response, results: list[Hit] | None, page_size: int +) -> tuple[int, int]: """ Elasticsearch does not allow deep pagination of ranked queries. Adjust returned page count to reflect this. diff --git a/api/catalog/api/models/audio.py b/api/catalog/api/models/audio.py index 8d9380c67..a90c07278 100644 --- a/api/catalog/api/models/audio.py +++ b/api/catalog/api/models/audio.py @@ -22,7 +22,7 @@ class AltAudioFile(AbstractAltFile): def __init__(self, attrs): self.bit_rate = attrs.get("bit_rate") self.sample_rate = attrs.get("sample_rate") - super(AltAudioFile, self).__init__(attrs) + super().__init__(attrs) @property def sample_rate_in_khz(self): @@ -263,7 +263,7 @@ class Meta: @property def audio_url(self): - return super(AudioReport, self).url("audio") + return super().url("audio") class AudioList(AbstractMediaList): @@ -278,4 +278,4 @@ class Meta: def save(self, *args, **kwargs): self.slug = uuslug(self.title, instance=self) - super(AudioList, self).save(*args, **kwargs) + super().save(*args, **kwargs) diff --git a/api/catalog/api/models/image.py b/api/catalog/api/models/image.py index 3502c51f7..3e91f2369 100644 --- a/api/catalog/api/models/image.py +++ b/api/catalog/api/models/image.py @@ -84,7 +84,7 @@ class Meta: @property def image_url(self): - return super(ImageReport, self).url("photos") + return super().url("photos") class ImageList(AbstractMediaList): @@ -99,4 +99,4 @@ class Meta: def save(self, *args, **kwargs): self.slug = uuslug(self.title, instance=self) - super(ImageList, self).save(*args, **kwargs) + super().save(*args, **kwargs) diff --git a/api/catalog/api/models/media.py b/api/catalog/api/models/media.py index 2ed74f058..15eea5c31 100644 --- a/api/catalog/api/models/media.py +++ b/api/catalog/api/models/media.py @@ -216,7 +216,7 @@ def save(self, *args, **kwargs): same_reports = same_reports.filter(reason=self.reason) same_reports.update(status=self.status) - super(AbstractMediaReport, self).save(*args, **kwargs) + super().save(*args, **kwargs) class AbstractDeletedMedia(OpenLedgerModel): diff --git a/api/catalog/api/utils/attribution.py b/api/catalog/api/utils/attribution.py index 8aa97f3c0..d8dd4b1c3 100644 --- a/api/catalog/api/utils/attribution.py +++ b/api/catalog/api/utils/attribution.py @@ -4,17 +4,15 @@ frontend, or open an issue to track it. """ -from typing import Optional - from catalog.api.utils.licenses import get_full_license_name, is_public_domain def get_attribution_text( - title: Optional[str], - creator: Optional[str], + title: str | None, + creator: str | None, _license: str, - license_version: Optional[str], - license_url: Optional[str], + license_version: str | None, + license_url: str | None, ) -> str: """ Get the attribution text to properly and legally attribute a creative work to its diff --git a/api/catalog/api/utils/ccrel.py b/api/catalog/api/utils/ccrel.py index 13392c581..58cc59bb7 100644 --- a/api/catalog/api/utils/ccrel.py +++ b/api/catalog/api/utils/ccrel.py @@ -1,11 +1,3 @@ -import io -import os -import uuid - -import libxmp -from libxmp.consts import XMP_NS_CC, XMP_NS_XMP, XMP_NS_XMP_Rights - - """ Tools for embedding Creative Commons Rights Expression Language (ccREL) data into files using Extensible Metadata Platform (XMP). @@ -17,6 +9,13 @@ [0] https://www.w3.org/Submission/ccREL/ """ +import io +import os +import uuid + +import libxmp +from libxmp.consts import XMP_NS_CC, XMP_NS_XMP, XMP_NS_XMP_Rights + def embed_xmp_bytes(image: io.BytesIO, work_properties): """ diff --git a/api/catalog/api/utils/dead_link_mask.py b/api/catalog/api/utils/dead_link_mask.py index 59176966a..8980439fa 100644 --- a/api/catalog/api/utils/dead_link_mask.py +++ b/api/catalog/api/utils/dead_link_mask.py @@ -1,5 +1,3 @@ -from typing import List - from deepdiff import DeepHash from django_redis import get_redis_connection from elasticsearch_dsl import Search @@ -25,7 +23,7 @@ def get_query_hash(s: Search) -> str: return deep_hash -def get_query_mask(query_hash: str) -> List[int]: +def get_query_mask(query_hash: str) -> list[int]: """ Fetches an existing query mask for a given query hash or returns an empty one. @@ -38,7 +36,7 @@ def get_query_mask(query_hash: str) -> List[int]: return list(map(int, redis.lrange(key, 0, -1))) -def save_query_mask(query_hash: str, mask: List): +def save_query_mask(query_hash: str, mask: list): """ Saves a query mask to redis. diff --git a/api/catalog/api/utils/help_text.py b/api/catalog/api/utils/help_text.py index a17e5c834..d041e765f 100644 --- a/api/catalog/api/utils/help_text.py +++ b/api/catalog/api/utils/help_text.py @@ -1,4 +1,4 @@ -from typing import Iterable +from collections.abc import Iterable def make_comma_separated_help_text(items: Iterable[str], name: str) -> str: diff --git a/api/catalog/api/utils/licenses.py b/api/catalog/api/utils/licenses.py index 9c6a89dcf..6d658322e 100644 --- a/api/catalog/api/utils/licenses.py +++ b/api/catalog/api/utils/licenses.py @@ -4,7 +4,6 @@ frontend, or open an issue to track it. """ -from typing import Optional from catalog.api.constants.licenses import ( ALL_CC_LICENSES, @@ -13,7 +12,7 @@ ) -def get_license_url(_license: str, license_version: Optional[str]) -> str: +def get_license_url(_license: str, license_version: str | None) -> str: """ Get the URL to the deed of the license. @@ -33,7 +32,7 @@ def get_license_url(_license: str, license_version: Optional[str]) -> str: return f"https://creativecommons.org/{fragment}/" -def get_full_license_name(_license: str, license_version: Optional[str]) -> str: +def get_full_license_name(_license: str, license_version: str | None) -> str: """ Get the full name of the license in a displayable format from the license slug and version. diff --git a/api/catalog/api/utils/waveform.py b/api/catalog/api/utils/waveform.py index 2bb00214a..77bac9954 100644 --- a/api/catalog/api/utils/waveform.py +++ b/api/catalog/api/utils/waveform.py @@ -6,7 +6,6 @@ import pathlib import shutil import subprocess -from typing import List from django.conf import settings @@ -141,7 +140,7 @@ def cleanup(file_name): logger.debug("file not found, nothing deleted") -def generate_peaks(audio) -> List[float]: +def generate_peaks(audio) -> list[float]: file_name = None try: file_name = download_audio(audio.url, audio.identifier) diff --git a/api/catalog/configuration/elasticsearch.py b/api/catalog/configuration/elasticsearch.py index f6495f8ce..d32364287 100644 --- a/api/catalog/configuration/elasticsearch.py +++ b/api/catalog/configuration/elasticsearch.py @@ -1,3 +1,7 @@ +""" +This file contains configuration pertaining to Elasticsearch. +""" + from django.conf import settings from aws_requests_auth.aws_auth import AWSRequestsAuth @@ -42,11 +46,12 @@ def _elasticsearch_connect(): ES = _elasticsearch_connect() -"""Elasticsearch client, also aliased to connection 'default'""" +#: Elasticsearch client, also aliased to connection 'default' + connections.add_connection("default", ES) MEDIA_INDEX_MAPPING = { media_type: config(f"{media_type.upper()}_INDEX_NAME", default=media_type) for media_type in MEDIA_TYPES } -"""mapping of media types to Elasticsearch index names""" +#: mapping of media types to Elasticsearch index names diff --git a/api/catalog/configuration/link_validation_cache.py b/api/catalog/configuration/link_validation_cache.py index 582f0d8d9..b0cdf9e6e 100644 --- a/api/catalog/configuration/link_validation_cache.py +++ b/api/catalog/configuration/link_validation_cache.py @@ -3,7 +3,6 @@ import os from collections import defaultdict from datetime import timedelta -from typing import Optional from django.core.exceptions import ImproperlyConfigured @@ -53,7 +52,7 @@ def __init__(self): self[status] = value - def _config(self, key: str | int, default: Optional[dict] = None) -> Optional[int]: + def _config(self, key: str | int, default: dict | None = None) -> int | None: try: v = config( f"{self.SETTING_PREFIX}{str(key)}", diff --git a/api/catalog/urls/swagger.py b/api/catalog/urls/swagger.py index a241c4e79..49cac6fdf 100644 --- a/api/catalog/urls/swagger.py +++ b/api/catalog/urls/swagger.py @@ -12,7 +12,7 @@ "docs", "README.md", ) -with open(description_path, "r") as description_file: +with open(description_path) as description_file: description = description_file.read() tos_url = "https://wordpress.github.io/openverse-api/terms_of_service.html" diff --git a/api/test/api_live_integration.py b/api/test/api_live_integration.py index 70c31140b..9495d9916 100644 --- a/api/test/api_live_integration.py +++ b/api/test/api_live_integration.py @@ -1,3 +1,11 @@ +""" +**These are the LEGACY API integration tests; do not add further tests here.** +New tests should be added in v1_integration_test. + +End-to-end API tests. Can be used to verify a live deployment is functioning as +designed. Run with the `pytest -s` command from this directory. +""" + import json import os import uuid @@ -10,15 +18,6 @@ from catalog.api.utils.watermark import watermark -""" -**These are the LEGACY API integration tests; do not add further tests here.** -New tests should be added in v1_integration_test. - -End-to-end API tests. Can be used to verify a live deployment is functioning as -designed. Run with the `pytest -s` command from this directory. -""" - - API_URL = os.getenv("INTEGRATION_TEST_URL", "http://localhost:8000") known_apis = { "http://localhost:8000": "LOCAL", @@ -60,10 +59,10 @@ def test_search_consistency(): appear in the first few pages of a search query. """ n_pages = 5 - searches = set( + searches = { requests.get(f"{API_URL}/image/search?q=honey;page={page}", verify=False) for page in range(1, n_pages) - ) + } images = set() for response in searches: diff --git a/api/test/api_live_search_qa.py b/api/test/api_live_search_qa.py index 096d61414..d5ff9bdf6 100644 --- a/api/test/api_live_search_qa.py +++ b/api/test/api_live_search_qa.py @@ -1,8 +1,3 @@ -import json - -import requests - - """ Tests to run against a live instance of Openverse with a significant (10M+) number of records. Quality of search rankings can be affected by the number of @@ -10,6 +5,11 @@ do not accurately model relevance at scale. """ +import json + +import requests + + API_URL = "https://api-dev.openverse.engineering" diff --git a/api/test/dead_link_filter_test.py b/api/test/dead_link_filter_test.py index 91331075d..ba03924fb 100644 --- a/api/test/dead_link_filter_test.py +++ b/api/test/dead_link_filter_test.py @@ -112,8 +112,8 @@ def test_dead_link_filtering(mocked_map, client): data_with_dead_links = res_with_dead_links.json() data_without_dead_links = res_without_dead_links.json() - res_1_ids = set(result["id"] for result in data_with_dead_links["results"]) - res_2_ids = set(result["id"] for result in data_without_dead_links["results"]) + res_1_ids = {result["id"] for result in data_with_dead_links["results"]} + res_2_ids = {result["id"] for result in data_without_dead_links["results"]} # In this case, both have 20 results as the dead link filter has "back filled" the # pages of dead links. See the subsequent test for the case when this does not # occur (i.e., when the entire first page of links is dead). diff --git a/api/test/media_integration.py b/api/test/media_integration.py index aabb581e2..99d30e474 100644 --- a/api/test/media_integration.py +++ b/api/test/media_integration.py @@ -64,10 +64,10 @@ def search_consistency( appear in the first few pages of a search query. """ - searches = set( + searches = { requests.get(f"{API_URL}/v1/{media_path}?page={page}", verify=False) for page in range(1, n_pages) - ) + } results = set() for response in searches: diff --git a/api/test/search_qa.py b/api/test/search_qa.py index 89aae31ed..4c21a57f9 100644 --- a/api/test/search_qa.py +++ b/api/test/search_qa.py @@ -1,3 +1,7 @@ +""" +Perform some basic tests to ensure that search rankings work as anticipated. +""" + import json import pprint from enum import Enum @@ -8,11 +12,6 @@ from .api_live_integration import API_URL -""" -Perform some basic tests to ensure that search rankings work as anticipated. -""" - - class QAScores(Enum): TARGET = 1 LESS_RELEVANT = 2 diff --git a/api/test/unit/controllers/test_search_controller.py b/api/test/unit/controllers/test_search_controller.py index f2c29d025..1121575cb 100644 --- a/api/test/unit/controllers/test_search_controller.py +++ b/api/test/unit/controllers/test_search_controller.py @@ -1,6 +1,6 @@ import random +from collections.abc import Callable from enum import Enum, auto -from typing import Callable from unittest import mock from uuid import uuid4 @@ -137,7 +137,7 @@ class CreateMaskConfig(Enum): @pytest.fixture(name="create_mask") -def create_mask_fixture() -> Callable[(Search, int, int), None]: +def create_mask_fixture() -> Callable[[Search, int, int], None]: created_masks = [] def create_mask( diff --git a/api/test/unit/utils/validate_images_test.py b/api/test/unit/utils/validate_images_test.py index afbaeb954..2a7889505 100644 --- a/api/test/unit/utils/validate_images_test.py +++ b/api/test/unit/utils/validate_images_test.py @@ -1,5 +1,5 @@ +from collections.abc import Callable from dataclasses import dataclass -from typing import Callable import pytest from fakeredis import FakeRedis @@ -26,7 +26,7 @@ def get_redis_connection(*args, **kwargs): class GRequestsFixture: requests: list[AsyncRequest] response_factory: Callable[ - (AsyncRequest,), Response + [AsyncRequest], Response ] = lambda x: GRequestsFixture._default_response_factory(x) @staticmethod diff --git a/api/test/unit/utils/watermark_test.py b/api/test/unit/utils/watermark_test.py index 10ca2ff8c..c649140dd 100644 --- a/api/test/unit/utils/watermark_test.py +++ b/api/test/unit/utils/watermark_test.py @@ -1,9 +1,9 @@ import json import struct +from collections.abc import Callable from dataclasses import dataclass from io import BytesIO from pathlib import Path -from typing import Callable from unittest import mock import pytest @@ -22,7 +22,7 @@ class RequestsFixture: requests: list[Request] response_factory: Callable[ - (Request,), Response + [Request], Response ] = lambda x: RequestsFixture._default_response_factory(x) @staticmethod diff --git a/api/test/unit/utils/waveform_test.py b/api/test/unit/utils/waveform_test.py index 1f6326688..d0b70f152 100644 --- a/api/test/unit/utils/waveform_test.py +++ b/api/test/unit/utils/waveform_test.py @@ -1,8 +1,8 @@ import json +from collections.abc import Callable from dataclasses import dataclass from io import BytesIO from pathlib import Path -from typing import Callable import pytest from requests import Request, Response @@ -20,7 +20,7 @@ class RequestsFixture: requests: list[Request] response_factory: Callable[ - (Request,), Response + [Request], Response ] = lambda x: RequestsFixture._default_response_factory(x) @staticmethod diff --git a/api/test/unit/views/image_views_test.py b/api/test/unit/views/image_views_test.py index f7336f413..a74339dbb 100644 --- a/api/test/unit/views/image_views_test.py +++ b/api/test/unit/views/image_views_test.py @@ -1,8 +1,8 @@ import json +from collections.abc import Callable from dataclasses import dataclass from pathlib import Path from test.factory.models.image import ImageFactory -from typing import Callable from rest_framework.test import APIClient @@ -26,7 +26,7 @@ def api_client(): class RequestsFixture: requests: list[Request] response_factory: Callable[ - (Request,), Response + [Request], Response ] = lambda x: RequestsFixture._default_response_factory(x) @staticmethod diff --git a/api/test/unit/views/media_views_test.py b/api/test/unit/views/media_views_test.py index 87d83066d..123de114f 100644 --- a/api/test/unit/views/media_views_test.py +++ b/api/test/unit/views/media_views_test.py @@ -1,9 +1,9 @@ import json +from collections.abc import Callable from dataclasses import dataclass, field from pathlib import Path from test.factory.models.audio import AudioFactory from test.factory.models.image import ImageFactory -from typing import Callable from unittest import mock from unittest.mock import patch @@ -37,7 +37,7 @@ class SentRequest: class RequestsFixture: sent_requests: list[SentRequest] send_handler: Callable[ - (Request,), Response + [Request], Response ] = lambda *args, **kwargs: RequestsFixture._default_send_handler(*args, **kwargs) response_queue: list[Response] = field(default_factory=list) diff --git a/ingestion_server/ingestion_server/es_helpers.py b/ingestion_server/ingestion_server/es_helpers.py index d7463e926..ebd9dc627 100644 --- a/ingestion_server/ingestion_server/es_helpers.py +++ b/ingestion_server/ingestion_server/es_helpers.py @@ -1,6 +1,6 @@ import logging as log import time -from typing import NamedTuple, Optional, Union +from typing import NamedTuple from aws_requests_auth.aws_auth import AWSRequestsAuth from decouple import config @@ -14,8 +14,8 @@ class Stat(NamedTuple): """ exists: bool - is_alias: Optional[bool] - alt_names: Optional[Union[str, list[str]]] + is_alias: bool | None + alt_names: str | list[str] | None def elasticsearch_connect(timeout: int = 300) -> Elasticsearch: diff --git a/ingestion_server/ingestion_server/indexer.py b/ingestion_server/ingestion_server/indexer.py index ea6ddd3fd..53e4d6cb3 100644 --- a/ingestion_server/ingestion_server/indexer.py +++ b/ingestion_server/ingestion_server/indexer.py @@ -19,8 +19,7 @@ import time import uuid from collections import deque -from multiprocessing import Value -from typing import Optional +from typing import Any import elasticsearch import psycopg2 @@ -113,11 +112,13 @@ class TableIndexer: def __init__( self, es_instance: Elasticsearch, - task_id: Optional[str] = None, - callback_url: Optional[str] = None, - progress: Optional[Value] = None, - active_workers: Optional[Value] = None, - is_bad_request: Optional[Value] = None, + task_id: str | None = None, + callback_url: str | None = None, + # The following arguments should be typed as ``Synchronized | None``. + # https://github.com/python/typeshed/issues/8799 + progress: Any = None, + active_workers: Any = None, + is_bad_request: Any = None, ): self.es = es_instance connections.connections.add_connection("default", self.es) @@ -420,8 +421,8 @@ def point_alias(self, model_name: str, index_suffix: str, alias: str, **_): def delete_index( self, model_name: str, - index_suffix: Optional[str] = None, - alias: Optional[str] = None, + index_suffix: str | None = None, + alias: str | None = None, force_delete: bool = False, **_, ): diff --git a/ingestion_server/ingestion_server/ingest.py b/ingestion_server/ingestion_server/ingest.py index 27747b732..887fc8b88 100644 --- a/ingestion_server/ingestion_server/ingest.py +++ b/ingestion_server/ingestion_server/ingest.py @@ -45,14 +45,14 @@ "RELATIVE_UPSTREAM_DB_HOST", default=UPSTREAM_DB_HOST, ) -"""The hostname of the upstream DB from the POV of the downstream DB""" +#: the hostname of the upstream DB from the POV of the downstream DB RELATIVE_UPSTREAM_DB_PORT = config( "RELATIVE_UPSTREAM_DB_PORT", default=UPSTREAM_DB_PORT, cast=int, ) -"""The port of the upstream DB from the POV of the downstream DB""" +#: the port of the upstream DB from the POV of the downstream DB def _get_shared_cols(downstream, upstream, table: str): @@ -69,9 +69,9 @@ def _get_shared_cols(downstream, upstream, table: str): with downstream.cursor() as cur1, upstream.cursor() as cur2: get_tables = SQL("SELECT * FROM {table} LIMIT 0;") cur1.execute(get_tables.format(table=Identifier(table))) - conn1_cols = set([desc[0] for desc in cur1.description]) + conn1_cols = {desc[0] for desc in cur1.description} cur2.execute(get_tables.format(table=Identifier(f"{table}_view"))) - conn2_cols = set([desc[0] for desc in cur2.description]) + conn2_cols = {desc[0] for desc in cur2.description} shared = list(conn1_cols.intersection(conn2_cols)) log.info(f"Shared columns: {shared}") diff --git a/ingestion_server/ingestion_server/queries.py b/ingestion_server/ingestion_server/queries.py index eb9f65a5c..57fa865ec 100644 --- a/ingestion_server/ingestion_server/queries.py +++ b/ingestion_server/ingestion_server/queries.py @@ -1,5 +1,4 @@ from textwrap import dedent as d -from typing import Optional from psycopg2.sql import SQL, Identifier from psycopg2.sql import Literal as PgLiteral @@ -93,7 +92,7 @@ def get_copy_data_query( table: str, columns: list[str], approach: ApproachType, - limit: Optional[int] = 100_000, + limit: int | None = 100_000, ): """ Get the query for copying data from the upstream table to a temporary table diff --git a/ingestion_server/ingestion_server/tasks.py b/ingestion_server/ingestion_server/tasks.py index 76379bc5a..f628f6525 100644 --- a/ingestion_server/ingestion_server/tasks.py +++ b/ingestion_server/ingestion_server/tasks.py @@ -6,7 +6,6 @@ import logging from enum import Enum, auto from multiprocessing import Value -from typing import Optional from ingestion_server import slack from ingestion_server.constants.media_types import MediaType @@ -96,7 +95,7 @@ def serialize_task_info(task_info: dict) -> dict: :return: the details of the task to show to the user """ - def _time_fmt(timestamp: int) -> Optional[str]: + def _time_fmt(timestamp: int) -> str | None: """ Format the timestamp into a human-readable date and time notation. :param timestamp: the timestamp to format @@ -154,7 +153,7 @@ def perform_task( task_id: str, model: MediaType, action: TaskTypes, - callback_url: Optional[str], + callback_url: str | None, progress: Value, finish_time: Value, active_workers: Value, diff --git a/ingestion_server/test/gen_integration_compose.py b/ingestion_server/test/gen_integration_compose.py index 85acb6800..74f1282b4 100644 --- a/ingestion_server/test/gen_integration_compose.py +++ b/ingestion_server/test/gen_integration_compose.py @@ -128,7 +128,7 @@ def _rename_services(conf: dict): def gen_integration_compose(): print("Generating Docker Compose configuration for integration tests...") - with open(src_dc_path, "r") as src_dc: + with open(src_dc_path) as src_dc: conf = yaml.safe_load(src_dc) print("│ Pruning unwanted services... ", end="") diff --git a/ingestion_server/test/integration_test.py b/ingestion_server/test/integration_test.py index 338394a56..9e0105b02 100644 --- a/ingestion_server/test/integration_test.py +++ b/ingestion_server/test/integration_test.py @@ -145,7 +145,7 @@ def _load_schemas(cls, conn, schema_names): cur = conn.cursor() for schema_name in schema_names: schema_path = this_dir.joinpath("mock_schemas", f"{schema_name}.sql") - with open(schema_path, "r") as schema: + with open(schema_path) as schema: cur.execute(schema.read()) conn.commit() cur.close() @@ -155,7 +155,7 @@ def _load_data(cls, conn, table_names): cur = conn.cursor() for table_name in table_names: data_path = this_dir.joinpath("mock_data", f"mocked_{table_name}.csv") - with open(data_path, "r") as data: + with open(data_path) as data: cur.copy_expert( f"COPY {table_name} FROM STDIN WITH (FORMAT csv, HEADER true)", data, @@ -451,7 +451,7 @@ def test_upstream_indexed_images(self): es.indices.refresh(index="image-integration") count = es.count(index="image-integration")["count"] msg = "There should be 5000 images in Elasticsearch after ingestion." - self.assertEquals(count, 5000, msg) + self.assertEqual(count, 5000, msg) @pytest.mark.order(9) def test_upstream_indexed_audio(self): @@ -465,7 +465,7 @@ def test_upstream_indexed_audio(self): es.indices.refresh(index="audio-integration") count = es.count(index="audio-integration")["count"] msg = "There should be 5000 audio tracks in Elasticsearch after ingestion." - self.assertEquals(count, 5000, msg) + self.assertEqual(count, 5000, msg) @pytest.mark.order(10) def test_update_index_images(self):