Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: Move AWS and DB clients to new app/clients folder #379

Merged
merged 9 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion app/api/api_v1/routers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from sqlalchemy import Column, update

from app.api.api_v1.schemas.document import DocumentUpdateRequest
from app.clients.db.session import get_db
from app.core.auth import get_superuser_details
from app.core.lookups import get_family_document_by_import_id_or_slug
from app.db.session import get_db

_LOGGER = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion app/api/api_v1/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm

from app.clients.db.session import get_db
from app.core.auth import authenticate_user
from app.core.security import create_access_token
from app.db.crud.user import get_app_user_authorisation
from app.db.session import get_db

auth_router = r = APIRouter()

Expand Down
2 changes: 1 addition & 1 deletion app/api/api_v1/routers/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
FamilyAndDocumentsResponse,
FamilyDocumentWithContextResponse,
)
from app.clients.db.session import get_db
from app.db.crud.document import (
get_family_and_documents,
get_family_document_and_context,
get_slugged_objects,
)
from app.db.session import get_db

_LOGGER = logging.getLogger(__file__)

Expand Down
2 changes: 1 addition & 1 deletion app/api/api_v1/routers/geographies.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from fastapi import APIRouter, Depends, HTTPException, status

from app.api.api_v1.schemas.geography import GeographyStatsDTO
from app.clients.db.session import get_db
from app.db.crud.geography import get_world_map_stats
from app.db.session import get_db
from app.errors import RepositoryError

_LOGGER = logging.getLogger(__file__)
Expand Down
2 changes: 1 addition & 1 deletion app/api/api_v1/routers/lookups/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from app.api.api_v1.routers.lookups.router import lookups_router
from app.api.api_v1.schemas.metadata import ApplicationConfig
from app.clients.db.session import get_db
from app.core.lookups import get_config
from app.db.session import get_db


@lookups_router.get("/config", response_model=ApplicationConfig)
Expand Down
2 changes: 1 addition & 1 deletion app/api/api_v1/routers/lookups/geo_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sqlalchemy import exc

from app.api.api_v1.routers.lookups.router import lookups_router
from app.db.session import get_db
from app.clients.db.session import get_db

_LOGGER = logging.getLogger(__name__)

Expand Down
4 changes: 2 additions & 2 deletions app/api/api_v1/routers/pipeline_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from sqlalchemy.orm import Session

from app.api.api_v1.schemas.document import BulkIngestResult
from app.clients.aws.client import S3Client, get_s3_client
from app.clients.db.session import get_db
from app.core.auth import get_superuser_details
from app.core.aws import S3Client, get_s3_client
from app.core.ingestion.pipeline import generate_pipeline_ingest_input
from app.core.validation.util import get_new_s3_prefix, write_documents_to_s3
from app.db.session import get_db

_LOGGER = logging.getLogger(__name__)

Expand Down
7 changes: 4 additions & 3 deletions app/api/api_v1/routers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
from starlette.responses import RedirectResponse

from app.api.api_v1.schemas.search import SearchRequestBody, SearchResponse
from app.core.aws import S3Client, S3Document, get_s3_client
from app.core.config import (
from app.clients.aws.client import S3Client, get_s3_client
from app.clients.aws.s3_document import S3Document
from app.clients.db.session import get_db
from app.config import (
AWS_REGION,
CDN_DOMAIN,
DOCUMENT_CACHE_BUCKET,
Expand All @@ -36,7 +38,6 @@
process_result_into_csv,
process_vespa_search_response,
)
from app.db.session import get_db

_LOGGER = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion app/api/api_v1/routers/summaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from fastapi import APIRouter, Depends, HTTPException, Request, status

from app.api.api_v1.schemas.search import GeographySummaryFamilyResponse
from app.clients.db.session import get_db
from app.core.browse import BrowseArgs, browse_rds_families
from app.core.lookups import get_country_slug_from_country_code, is_country_code
from app.db.session import get_db

_LOGGER = logging.getLogger(__name__)

Expand Down
32 changes: 2 additions & 30 deletions app/core/aws.py → app/clients/aws/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import logging
import os
import re
import typing as t
from datetime import datetime

Expand All @@ -11,39 +10,12 @@
from botocore.exceptions import ClientError, UnauthorizedSSOTokenError
from botocore.response import StreamingBody

from app.core.config import AWS_REGION, DEVELOPMENT_MODE
from app.clients.aws.s3_document import S3Document
from app.config import AWS_REGION, DEVELOPMENT_MODE

logger = logging.getLogger(__name__)


class S3Document:
"""A class representing an S3 document."""

def __init__(self, bucket_name: str, region: str, key: str): # noqa: D107
self.bucket_name = bucket_name
self.region = region
self.key = key

@property
def url(self):
"""Return the URL for this S3 document."""
return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{self.key}"

@classmethod
def from_url(cls, url: str) -> "S3Document":
"""
Create an S3 document from a URL.

:param [str] url: The URL of the document to create
:return [S3Document]: document representing given URL
"""
bucket_name, region, key = re.findall(
r"https:\/\/([\w-]+).s3.([\w-]+).amazonaws.com\/([\w.-]+)", url
)[0]

return S3Document(bucket_name=bucket_name, region=region, key=key)


class S3Client:
"""Helper class to connect to S3 and perform actions on buckets and documents."""

Expand Down
34 changes: 34 additions & 0 deletions app/clients/aws/s3_document.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""AWS Helper classes."""

import logging
import re

logger = logging.getLogger(__name__)


class S3Document:
"""A class representing an S3 document."""

def __init__(self, bucket_name: str, region: str, key: str): # noqa: D107
self.bucket_name = bucket_name
self.region = region
self.key = key

@property
def url(self):
"""Return the URL for this S3 document."""
return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{self.key}"

@classmethod
def from_url(cls, url: str) -> "S3Document":
"""
Create an S3 document from a URL.

:param [str] url: The URL of the document to create
:return [S3Document]: document representing given URL
"""
bucket_name, region, key = re.findall(
r"https:\/\/([\w-]+).s3.([\w-]+).amazonaws.com\/([\w.-]+)", url
)[0]

return S3Document(bucket_name=bucket_name, region=region, key=key)
2 changes: 1 addition & 1 deletion app/db/session.py → app/clients/db/session.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from app.core import config
from app import config

engine = create_engine(
config.SQLALCHEMY_DATABASE_URI,
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion app/core/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pandas as pd
from fastapi import Depends

from app.db.session import get_db
from app.clients.db.session import get_db

_LOGGER = getLogger(__name__)

Expand Down
4 changes: 2 additions & 2 deletions app/core/health.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from sqlalchemy.orm import Session

from app.api.api_v1.routers.search import _VESPA_CONNECTION
from app.core.config import DEVELOPMENT_MODE
from app.db.session import get_db
from app.clients.db.session import get_db
from app.config import DEVELOPMENT_MODE

_LOGGER = logging.getLogger(__file__)

Expand Down
2 changes: 1 addition & 1 deletion app/core/organisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from sqlalchemy import func
from sqlalchemy.orm import Session

from app import config
from app.api.api_v1.schemas.metadata import CorpusData, OrganisationConfig
from app.core import config


def _to_corpus_data(row) -> CorpusData:
Expand Down
2 changes: 1 addition & 1 deletion app/core/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
SearchResponseFamily,
SearchResponseFamilyDocument,
)
from app.core.config import PUBLIC_APP_URL
from app.config import PUBLIC_APP_URL
from app.core.lookups import (
doc_type_from_family_document_metadata,
get_countries_for_region,
Expand Down
5 changes: 3 additions & 2 deletions app/core/validation/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from typing import Any, Collection, Mapping, Optional, Sequence, Union

from app.api.api_v1.schemas.document import DocumentParserInput
from app.core.aws import S3Client, S3Document
from app.core.config import INGEST_TRIGGER_ROOT, PIPELINE_BUCKET
from app.clients.aws.client import S3Client
from app.clients.aws.s3_document import S3Document
from app.config import INGEST_TRIGGER_ROOT, PIPELINE_BUCKET

_LOGGER = logging.getLogger(__file__)

Expand Down
Empty file removed app/db/__init__.py
Empty file.
4 changes: 2 additions & 2 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from fastapi_pagination import add_pagination
from starlette.requests import Request

from app import config
from app.api.api_v1.routers.admin import admin_document_router
from app.api.api_v1.routers.auth import auth_router
from app.api.api_v1.routers.documents import documents_router
Expand All @@ -20,10 +21,9 @@
from app.api.api_v1.routers.pipeline_trigger import pipeline_trigger_router
from app.api.api_v1.routers.search import search_router
from app.api.api_v1.routers.summaries import summary_router
from app.core import config
from app.clients.db.session import SessionLocal, engine
from app.core.auth import get_superuser_details
from app.core.health import is_database_online
from app.db.session import SessionLocal, engine

os.environ["SKIP_ALEMBIC_LOGGING"] = "1"

Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
from sqlalchemy.orm import sessionmaker
from sqlalchemy_utils import create_database, database_exists, drop_database

from app.clients.aws.client import S3Client, get_s3_client
from app.clients.db.session import get_db
from app.core import custom_app, security
from app.core.aws import S3Client, get_s3_client
from app.core.custom_app import AppTokenFactory
from app.db.session import get_db
from app.main import app


Expand Down
2 changes: 1 addition & 1 deletion tests/non_search/app/lookups/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
)
from db_client.models.organisation import Corpus, Organisation

from app.clients.db.session import SessionLocal
from app.core.util import tree_table_to_json
from app.db.session import SessionLocal

LEN_ORG_CONFIG = 3
EXPECTED_CCLW_TAXONOMY = {
Expand Down
2 changes: 1 addition & 1 deletion tests/non_search/core/validation/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

from app.api.api_v1.schemas.document import DocumentParserInput
from app.core.config import PIPELINE_BUCKET
from app.config import PIPELINE_BUCKET
from app.core.validation import IMPORT_ID_MATCHER
from app.core.validation.util import _flatten_maybe_tree, write_documents_to_s3

Expand Down
8 changes: 4 additions & 4 deletions tests/search/vespa/test_whole_database_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_whole_database_download_fails_when_decoding_token_raises_PyJWTError(
), patch(
"app.api.api_v1.routers.search.DOCUMENT_CACHE_BUCKET", "test_cdn_bucket"
), patch(
"app.core.aws.S3Client.is_connected", return_value=True
"app.clients.aws.client.S3Client.is_connected", return_value=True
):
response = data_client.get(
ALL_DATA_DOWNLOAD_ENDPOINT,
Expand Down Expand Up @@ -68,7 +68,7 @@ def test_whole_database_download_fails_when_corpus_ids_in_token_not_in_db(
"app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db",
return_value=False,
), patch(
"app.core.aws.S3Client.is_connected", return_value=True
"app.clients.aws.client.S3Client.is_connected", return_value=True
), patch(
"app.api.api_v1.routers.search.PIPELINE_BUCKET", "test_pipeline_bucket"
), patch(
Expand Down Expand Up @@ -96,7 +96,7 @@ def test_all_data_download(mock_corpora_exist_in_db, data_db, data_client, valid
with (
patch("app.api.api_v1.routers.search.PIPELINE_BUCKET", "test_pipeline_bucket"),
patch("app.api.api_v1.routers.search.DOCUMENT_CACHE_BUCKET", "test_cdn_bucket"),
patch("app.core.aws.S3Client.is_connected", return_value=True),
patch("app.clients.aws.client.S3Client.is_connected", return_value=True),
):
data_client.follow_redirects = False
download_response = data_client.get(ALL_DATA_DOWNLOAD_ENDPOINT, headers=headers)
Expand Down Expand Up @@ -126,7 +126,7 @@ def test_all_data_download_fails_when_s3_upload_failed(
with (
patch("app.api.api_v1.routers.search.PIPELINE_BUCKET", "test_pipeline_bucket"),
patch("app.api.api_v1.routers.search.DOCUMENT_CACHE_BUCKET", "test_cdn_bucket"),
patch("app.core.aws.S3Client.is_connected", return_value=True),
patch("app.clients.aws.client.S3Client.is_connected", return_value=True),
patch(
"app.api.api_v1.routers.search._get_s3_doc_url_from_cdn", return_value=None
),
Expand Down
Loading