diff --git a/app/api/api_v1/routers/admin.py b/app/api/api_v1/routers/admin.py index e6d3a066..c6e91eb5 100644 --- a/app/api/api_v1/routers/admin.py +++ b/app/api/api_v1/routers/admin.py @@ -12,9 +12,9 @@ from sqlalchemy import Column, update from app.api.api_v1.schemas.document import DocumentUpdateRequest +from app.clients.db.session import get_db from app.core.auth import get_superuser_details from app.core.lookups import get_family_document_by_import_id_or_slug -from app.db.session import get_db _LOGGER = logging.getLogger(__name__) diff --git a/app/api/api_v1/routers/auth.py b/app/api/api_v1/routers/auth.py index 16801d69..3a89c987 100644 --- a/app/api/api_v1/routers/auth.py +++ b/app/api/api_v1/routers/auth.py @@ -4,10 +4,10 @@ from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordRequestForm +from app.clients.db.session import get_db from app.core.auth import authenticate_user from app.core.security import create_access_token from app.db.crud.user import get_app_user_authorisation -from app.db.session import get_db auth_router = r = APIRouter() diff --git a/app/api/api_v1/routers/documents.py b/app/api/api_v1/routers/documents.py index c8c8d78a..e985d762 100644 --- a/app/api/api_v1/routers/documents.py +++ b/app/api/api_v1/routers/documents.py @@ -8,12 +8,12 @@ FamilyAndDocumentsResponse, FamilyDocumentWithContextResponse, ) +from app.clients.db.session import get_db from app.db.crud.document import ( get_family_and_documents, get_family_document_and_context, get_slugged_objects, ) -from app.db.session import get_db _LOGGER = logging.getLogger(__file__) diff --git a/app/api/api_v1/routers/geographies.py b/app/api/api_v1/routers/geographies.py index be21dff9..a6d2ac7a 100644 --- a/app/api/api_v1/routers/geographies.py +++ b/app/api/api_v1/routers/geographies.py @@ -3,8 +3,8 @@ from fastapi import APIRouter, Depends, HTTPException, status from app.api.api_v1.schemas.geography import GeographyStatsDTO +from app.clients.db.session import get_db from app.db.crud.geography import get_world_map_stats -from app.db.session import get_db from app.errors import RepositoryError _LOGGER = logging.getLogger(__file__) diff --git a/app/api/api_v1/routers/lookups/config.py b/app/api/api_v1/routers/lookups/config.py index 42493ca9..fcad6d91 100644 --- a/app/api/api_v1/routers/lookups/config.py +++ b/app/api/api_v1/routers/lookups/config.py @@ -2,8 +2,8 @@ from app.api.api_v1.routers.lookups.router import lookups_router from app.api.api_v1.schemas.metadata import ApplicationConfig +from app.clients.db.session import get_db from app.core.lookups import get_config -from app.db.session import get_db @lookups_router.get("/config", response_model=ApplicationConfig) diff --git a/app/api/api_v1/routers/lookups/geo_stats.py b/app/api/api_v1/routers/lookups/geo_stats.py index 00f60488..397bb1b0 100644 --- a/app/api/api_v1/routers/lookups/geo_stats.py +++ b/app/api/api_v1/routers/lookups/geo_stats.py @@ -8,7 +8,7 @@ from sqlalchemy import exc from app.api.api_v1.routers.lookups.router import lookups_router -from app.db.session import get_db +from app.clients.db.session import get_db _LOGGER = logging.getLogger(__name__) diff --git a/app/api/api_v1/routers/pipeline_trigger.py b/app/api/api_v1/routers/pipeline_trigger.py index 5f34cd91..8bf04cca 100644 --- a/app/api/api_v1/routers/pipeline_trigger.py +++ b/app/api/api_v1/routers/pipeline_trigger.py @@ -4,11 +4,11 @@ from sqlalchemy.orm import Session from app.api.api_v1.schemas.document import BulkIngestResult +from app.clients.aws.client import S3Client, get_s3_client +from app.clients.db.session import get_db from app.core.auth import get_superuser_details -from app.core.aws import S3Client, get_s3_client from app.core.ingestion.pipeline import generate_pipeline_ingest_input from app.core.validation.util import get_new_s3_prefix, write_documents_to_s3 -from app.db.session import get_db _LOGGER = logging.getLogger(__name__) diff --git a/app/api/api_v1/routers/search.py b/app/api/api_v1/routers/search.py index 34f6e8b1..8652cd22 100644 --- a/app/api/api_v1/routers/search.py +++ b/app/api/api_v1/routers/search.py @@ -18,8 +18,10 @@ from starlette.responses import RedirectResponse from app.api.api_v1.schemas.search import SearchRequestBody, SearchResponse -from app.core.aws import S3Client, S3Document, get_s3_client -from app.core.config import ( +from app.clients.aws.client import S3Client, get_s3_client +from app.clients.aws.s3_document import S3Document +from app.clients.db.session import get_db +from app.config import ( AWS_REGION, CDN_DOMAIN, DOCUMENT_CACHE_BUCKET, @@ -36,7 +38,6 @@ process_result_into_csv, process_vespa_search_response, ) -from app.db.session import get_db _LOGGER = logging.getLogger(__name__) diff --git a/app/api/api_v1/routers/summaries.py b/app/api/api_v1/routers/summaries.py index 56009c7b..3d9481a1 100644 --- a/app/api/api_v1/routers/summaries.py +++ b/app/api/api_v1/routers/summaries.py @@ -10,9 +10,9 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status from app.api.api_v1.schemas.search import GeographySummaryFamilyResponse +from app.clients.db.session import get_db from app.core.browse import BrowseArgs, browse_rds_families from app.core.lookups import get_country_slug_from_country_code, is_country_code -from app.db.session import get_db _LOGGER = logging.getLogger(__name__) diff --git a/app/core/aws.py b/app/clients/aws/client.py similarity index 92% rename from app/core/aws.py rename to app/clients/aws/client.py index cdfbc5c7..e484bf73 100644 --- a/app/core/aws.py +++ b/app/clients/aws/client.py @@ -2,7 +2,6 @@ import logging import os -import re import typing as t from datetime import datetime @@ -11,39 +10,12 @@ from botocore.exceptions import ClientError, UnauthorizedSSOTokenError from botocore.response import StreamingBody -from app.core.config import AWS_REGION, DEVELOPMENT_MODE +from app.clients.aws.s3_document import S3Document +from app.config import AWS_REGION, DEVELOPMENT_MODE logger = logging.getLogger(__name__) -class S3Document: - """A class representing an S3 document.""" - - def __init__(self, bucket_name: str, region: str, key: str): # noqa: D107 - self.bucket_name = bucket_name - self.region = region - self.key = key - - @property - def url(self): - """Return the URL for this S3 document.""" - return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{self.key}" - - @classmethod - def from_url(cls, url: str) -> "S3Document": - """ - Create an S3 document from a URL. - - :param [str] url: The URL of the document to create - :return [S3Document]: document representing given URL - """ - bucket_name, region, key = re.findall( - r"https:\/\/([\w-]+).s3.([\w-]+).amazonaws.com\/([\w.-]+)", url - )[0] - - return S3Document(bucket_name=bucket_name, region=region, key=key) - - class S3Client: """Helper class to connect to S3 and perform actions on buckets and documents.""" diff --git a/app/clients/aws/s3_document.py b/app/clients/aws/s3_document.py new file mode 100644 index 00000000..f4ad4bdd --- /dev/null +++ b/app/clients/aws/s3_document.py @@ -0,0 +1,34 @@ +"""AWS Helper classes.""" + +import logging +import re + +logger = logging.getLogger(__name__) + + +class S3Document: + """A class representing an S3 document.""" + + def __init__(self, bucket_name: str, region: str, key: str): # noqa: D107 + self.bucket_name = bucket_name + self.region = region + self.key = key + + @property + def url(self): + """Return the URL for this S3 document.""" + return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{self.key}" + + @classmethod + def from_url(cls, url: str) -> "S3Document": + """ + Create an S3 document from a URL. + + :param [str] url: The URL of the document to create + :return [S3Document]: document representing given URL + """ + bucket_name, region, key = re.findall( + r"https:\/\/([\w-]+).s3.([\w-]+).amazonaws.com\/([\w.-]+)", url + )[0] + + return S3Document(bucket_name=bucket_name, region=region, key=key) diff --git a/app/db/session.py b/app/clients/db/session.py similarity index 96% rename from app/db/session.py rename to app/clients/db/session.py index 2898413a..de54c2ee 100644 --- a/app/db/session.py +++ b/app/clients/db/session.py @@ -1,7 +1,7 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker -from app.core import config +from app import config engine = create_engine( config.SQLALCHEMY_DATABASE_URI, diff --git a/app/core/config.py b/app/config.py similarity index 100% rename from app/core/config.py rename to app/config.py diff --git a/app/core/download.py b/app/core/download.py index 485f38ee..f41caf07 100644 --- a/app/core/download.py +++ b/app/core/download.py @@ -9,7 +9,7 @@ import pandas as pd from fastapi import Depends -from app.db.session import get_db +from app.clients.db.session import get_db _LOGGER = getLogger(__name__) diff --git a/app/core/health.py b/app/core/health.py index 91d2f5a9..ec309a45 100644 --- a/app/core/health.py +++ b/app/core/health.py @@ -4,8 +4,8 @@ from sqlalchemy.orm import Session from app.api.api_v1.routers.search import _VESPA_CONNECTION -from app.core.config import DEVELOPMENT_MODE -from app.db.session import get_db +from app.clients.db.session import get_db +from app.config import DEVELOPMENT_MODE _LOGGER = logging.getLogger(__file__) diff --git a/app/core/organisation.py b/app/core/organisation.py index 4ee9e132..0e915fc8 100644 --- a/app/core/organisation.py +++ b/app/core/organisation.py @@ -5,8 +5,8 @@ from sqlalchemy import func from sqlalchemy.orm import Session +from app import config from app.api.api_v1.schemas.metadata import CorpusData, OrganisationConfig -from app.core import config def _to_corpus_data(row) -> CorpusData: diff --git a/app/core/search.py b/app/core/search.py index afd343f3..ca1b9741 100644 --- a/app/core/search.py +++ b/app/core/search.py @@ -37,7 +37,7 @@ SearchResponseFamily, SearchResponseFamilyDocument, ) -from app.core.config import PUBLIC_APP_URL +from app.config import PUBLIC_APP_URL from app.core.lookups import ( doc_type_from_family_document_metadata, get_countries_for_region, diff --git a/app/core/validation/util.py b/app/core/validation/util.py index 5cb7055d..52e6ae3b 100644 --- a/app/core/validation/util.py +++ b/app/core/validation/util.py @@ -7,8 +7,9 @@ from typing import Any, Collection, Mapping, Optional, Sequence, Union from app.api.api_v1.schemas.document import DocumentParserInput -from app.core.aws import S3Client, S3Document -from app.core.config import INGEST_TRIGGER_ROOT, PIPELINE_BUCKET +from app.clients.aws.client import S3Client +from app.clients.aws.s3_document import S3Document +from app.config import INGEST_TRIGGER_ROOT, PIPELINE_BUCKET _LOGGER = logging.getLogger(__file__) diff --git a/app/db/__init__.py b/app/db/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/main.py b/app/main.py index 4b770054..104e5493 100644 --- a/app/main.py +++ b/app/main.py @@ -12,6 +12,7 @@ from fastapi_pagination import add_pagination from starlette.requests import Request +from app import config from app.api.api_v1.routers.admin import admin_document_router from app.api.api_v1.routers.auth import auth_router from app.api.api_v1.routers.documents import documents_router @@ -20,10 +21,9 @@ from app.api.api_v1.routers.pipeline_trigger import pipeline_trigger_router from app.api.api_v1.routers.search import search_router from app.api.api_v1.routers.summaries import summary_router -from app.core import config +from app.clients.db.session import SessionLocal, engine from app.core.auth import get_superuser_details from app.core.health import is_database_online -from app.db.session import SessionLocal, engine os.environ["SKIP_ALEMBIC_LOGGING"] = "1" diff --git a/tests/conftest.py b/tests/conftest.py index c205b18b..ad2d15e2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,10 +15,10 @@ from sqlalchemy.orm import sessionmaker from sqlalchemy_utils import create_database, database_exists, drop_database +from app.clients.aws.client import S3Client, get_s3_client +from app.clients.db.session import get_db from app.core import custom_app, security -from app.core.aws import S3Client, get_s3_client from app.core.custom_app import AppTokenFactory -from app.db.session import get_db from app.main import app diff --git a/tests/non_search/app/lookups/test_config.py b/tests/non_search/app/lookups/test_config.py index 62052c3f..ca9211a4 100644 --- a/tests/non_search/app/lookups/test_config.py +++ b/tests/non_search/app/lookups/test_config.py @@ -11,8 +11,8 @@ ) from db_client.models.organisation import Corpus, Organisation +from app.clients.db.session import SessionLocal from app.core.util import tree_table_to_json -from app.db.session import SessionLocal LEN_ORG_CONFIG = 3 EXPECTED_CCLW_TAXONOMY = { diff --git a/tests/non_search/core/validation/test_util.py b/tests/non_search/core/validation/test_util.py index 744934fa..1a40ac15 100644 --- a/tests/non_search/core/validation/test_util.py +++ b/tests/non_search/core/validation/test_util.py @@ -6,7 +6,7 @@ import pytest from app.api.api_v1.schemas.document import DocumentParserInput -from app.core.config import PIPELINE_BUCKET +from app.config import PIPELINE_BUCKET from app.core.validation import IMPORT_ID_MATCHER from app.core.validation.util import _flatten_maybe_tree, write_documents_to_s3 diff --git a/tests/search/vespa/test_whole_database_download.py b/tests/search/vespa/test_whole_database_download.py index 024d4d50..52346223 100644 --- a/tests/search/vespa/test_whole_database_download.py +++ b/tests/search/vespa/test_whole_database_download.py @@ -38,7 +38,7 @@ def test_whole_database_download_fails_when_decoding_token_raises_PyJWTError( ), patch( "app.api.api_v1.routers.search.DOCUMENT_CACHE_BUCKET", "test_cdn_bucket" ), patch( - "app.core.aws.S3Client.is_connected", return_value=True + "app.clients.aws.client.S3Client.is_connected", return_value=True ): response = data_client.get( ALL_DATA_DOWNLOAD_ENDPOINT, @@ -68,7 +68,7 @@ def test_whole_database_download_fails_when_corpus_ids_in_token_not_in_db( "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", return_value=False, ), patch( - "app.core.aws.S3Client.is_connected", return_value=True + "app.clients.aws.client.S3Client.is_connected", return_value=True ), patch( "app.api.api_v1.routers.search.PIPELINE_BUCKET", "test_pipeline_bucket" ), patch( @@ -96,7 +96,7 @@ def test_all_data_download(mock_corpora_exist_in_db, data_db, data_client, valid with ( patch("app.api.api_v1.routers.search.PIPELINE_BUCKET", "test_pipeline_bucket"), patch("app.api.api_v1.routers.search.DOCUMENT_CACHE_BUCKET", "test_cdn_bucket"), - patch("app.core.aws.S3Client.is_connected", return_value=True), + patch("app.clients.aws.client.S3Client.is_connected", return_value=True), ): data_client.follow_redirects = False download_response = data_client.get(ALL_DATA_DOWNLOAD_ENDPOINT, headers=headers) @@ -126,7 +126,7 @@ def test_all_data_download_fails_when_s3_upload_failed( with ( patch("app.api.api_v1.routers.search.PIPELINE_BUCKET", "test_pipeline_bucket"), patch("app.api.api_v1.routers.search.DOCUMENT_CACHE_BUCKET", "test_cdn_bucket"), - patch("app.core.aws.S3Client.is_connected", return_value=True), + patch("app.clients.aws.client.S3Client.is_connected", return_value=True), patch( "app.api.api_v1.routers.search._get_s3_doc_url_from_cdn", return_value=None ),