Skip to content

Commit

Permalink
Refactor: Move AWS and DB clients to new app/clients folder (#379)
Browse files Browse the repository at this point in the history
* Move session.py under clients/db

* Move config to app root

* Move config to app root

* Move AWS under clients/aws

* Separate AWS clients into client and s3doc

* Rename from s3 to s3_document

* Attempt to fix patch path for AWS client

* Fix patch path for AWS client
  • Loading branch information
katybaulch authored Oct 14, 2024
1 parent f9b9cee commit 56e0275
Show file tree
Hide file tree
Showing 24 changed files with 68 additions and 60 deletions.
2 changes: 1 addition & 1 deletion app/api/api_v1/routers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from sqlalchemy import Column, update

from app.api.api_v1.schemas.document import DocumentUpdateRequest
from app.clients.db.session import get_db
from app.core.auth import get_superuser_details
from app.core.lookups import get_family_document_by_import_id_or_slug
from app.db.session import get_db

_LOGGER = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion app/api/api_v1/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm

from app.clients.db.session import get_db
from app.core.auth import authenticate_user
from app.core.security import create_access_token
from app.db.crud.user import get_app_user_authorisation
from app.db.session import get_db

auth_router = r = APIRouter()

Expand Down
2 changes: 1 addition & 1 deletion app/api/api_v1/routers/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
FamilyAndDocumentsResponse,
FamilyDocumentWithContextResponse,
)
from app.clients.db.session import get_db
from app.db.crud.document import (
get_family_and_documents,
get_family_document_and_context,
get_slugged_objects,
)
from app.db.session import get_db

_LOGGER = logging.getLogger(__file__)

Expand Down
2 changes: 1 addition & 1 deletion app/api/api_v1/routers/geographies.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from fastapi import APIRouter, Depends, HTTPException, status

from app.api.api_v1.schemas.geography import GeographyStatsDTO
from app.clients.db.session import get_db
from app.db.crud.geography import get_world_map_stats
from app.db.session import get_db
from app.errors import RepositoryError

_LOGGER = logging.getLogger(__file__)
Expand Down
2 changes: 1 addition & 1 deletion app/api/api_v1/routers/lookups/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from app.api.api_v1.routers.lookups.router import lookups_router
from app.api.api_v1.schemas.metadata import ApplicationConfig
from app.clients.db.session import get_db
from app.core.lookups import get_config
from app.db.session import get_db


@lookups_router.get("/config", response_model=ApplicationConfig)
Expand Down
2 changes: 1 addition & 1 deletion app/api/api_v1/routers/lookups/geo_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sqlalchemy import exc

from app.api.api_v1.routers.lookups.router import lookups_router
from app.db.session import get_db
from app.clients.db.session import get_db

_LOGGER = logging.getLogger(__name__)

Expand Down
4 changes: 2 additions & 2 deletions app/api/api_v1/routers/pipeline_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from sqlalchemy.orm import Session

from app.api.api_v1.schemas.document import BulkIngestResult
from app.clients.aws.client import S3Client, get_s3_client
from app.clients.db.session import get_db
from app.core.auth import get_superuser_details
from app.core.aws import S3Client, get_s3_client
from app.core.ingestion.pipeline import generate_pipeline_ingest_input
from app.core.validation.util import get_new_s3_prefix, write_documents_to_s3
from app.db.session import get_db

_LOGGER = logging.getLogger(__name__)

Expand Down
7 changes: 4 additions & 3 deletions app/api/api_v1/routers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
from starlette.responses import RedirectResponse

from app.api.api_v1.schemas.search import SearchRequestBody, SearchResponse
from app.core.aws import S3Client, S3Document, get_s3_client
from app.core.config import (
from app.clients.aws.client import S3Client, get_s3_client
from app.clients.aws.s3_document import S3Document
from app.clients.db.session import get_db
from app.config import (
AWS_REGION,
CDN_DOMAIN,
DOCUMENT_CACHE_BUCKET,
Expand All @@ -36,7 +38,6 @@
process_result_into_csv,
process_vespa_search_response,
)
from app.db.session import get_db

_LOGGER = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion app/api/api_v1/routers/summaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from fastapi import APIRouter, Depends, HTTPException, Request, status

from app.api.api_v1.schemas.search import GeographySummaryFamilyResponse
from app.clients.db.session import get_db
from app.core.browse import BrowseArgs, browse_rds_families
from app.core.lookups import get_country_slug_from_country_code, is_country_code
from app.db.session import get_db

_LOGGER = logging.getLogger(__name__)

Expand Down
32 changes: 2 additions & 30 deletions app/core/aws.py → app/clients/aws/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import logging
import os
import re
import typing as t
from datetime import datetime

Expand All @@ -11,39 +10,12 @@
from botocore.exceptions import ClientError, UnauthorizedSSOTokenError
from botocore.response import StreamingBody

from app.core.config import AWS_REGION, DEVELOPMENT_MODE
from app.clients.aws.s3_document import S3Document
from app.config import AWS_REGION, DEVELOPMENT_MODE

logger = logging.getLogger(__name__)


class S3Document:
    """
    A class representing an S3 document.

    A document is identified by its bucket name, AWS region and object key,
    and can be converted to and from its virtual-hosted-style HTTPS URL.
    """

    # Virtual-hosted-style URL: https://<bucket>.s3.<region>.amazonaws.com/<key>
    # The separator dots are escaped so they only match a literal ".", and the
    # key group accepts "/" so that prefixed keys ("folder/file.pdf") are
    # captured in full rather than being cut off at the first slash.
    _URL_PATTERN = re.compile(
        r"https://([\w-]+)\.s3\.([\w-]+)\.amazonaws\.com/([\w./-]+)"
    )

    def __init__(self, bucket_name: str, region: str, key: str):  # noqa: D107
        self.bucket_name = bucket_name
        self.region = region
        self.key = key

    @property
    def url(self) -> str:
        """Return the URL for this S3 document."""
        return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{self.key}"

    @classmethod
    def from_url(cls, url: str) -> "S3Document":
        """
        Create an S3 document from a URL.

        The first S3 document URL found anywhere in the given string is used.

        :param [str] url: The URL of the document to create
        :return [S3Document]: document representing given URL
        :raises IndexError: if no S3 document URL is found in the given string
        """
        match = cls._URL_PATTERN.search(url)
        if match is None:
            # IndexError is what the previous `re.findall(...)[0]` lookup
            # raised on a non-matching URL; the type is kept so existing
            # callers still work, but the message now says what went wrong.
            raise IndexError(f"Not a recognised S3 document URL: {url}")

        bucket_name, region, key = match.groups()
        # Build via `cls` so subclasses get an instance of their own type.
        return cls(bucket_name=bucket_name, region=region, key=key)


class S3Client:
"""Helper class to connect to S3 and perform actions on buckets and documents."""

Expand Down
34 changes: 34 additions & 0 deletions app/clients/aws/s3_document.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""AWS Helper classes."""

import logging
import re

logger = logging.getLogger(__name__)


class S3Document:
    """
    A class representing an S3 document.

    A document is identified by its bucket name, AWS region and object key,
    and can be converted to and from its virtual-hosted-style HTTPS URL.
    """

    # Virtual-hosted-style URL: https://<bucket>.s3.<region>.amazonaws.com/<key>
    # The separator dots are escaped so they only match a literal ".", and the
    # key group accepts "/" so that prefixed keys ("folder/file.pdf") are
    # captured in full rather than being cut off at the first slash.
    _URL_PATTERN = re.compile(
        r"https://([\w-]+)\.s3\.([\w-]+)\.amazonaws\.com/([\w./-]+)"
    )

    def __init__(self, bucket_name: str, region: str, key: str):  # noqa: D107
        self.bucket_name = bucket_name
        self.region = region
        self.key = key

    @property
    def url(self) -> str:
        """Return the URL for this S3 document."""
        return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{self.key}"

    @classmethod
    def from_url(cls, url: str) -> "S3Document":
        """
        Create an S3 document from a URL.

        The first S3 document URL found anywhere in the given string is used.

        :param [str] url: The URL of the document to create
        :return [S3Document]: document representing given URL
        :raises IndexError: if no S3 document URL is found in the given string
        """
        match = cls._URL_PATTERN.search(url)
        if match is None:
            # IndexError is what the previous `re.findall(...)[0]` lookup
            # raised on a non-matching URL; the type is kept so existing
            # callers still work, but the message now says what went wrong.
            raise IndexError(f"Not a recognised S3 document URL: {url}")

        bucket_name, region, key = match.groups()
        # Build via `cls` so subclasses get an instance of their own type.
        return cls(bucket_name=bucket_name, region=region, key=key)
2 changes: 1 addition & 1 deletion app/db/session.py → app/clients/db/session.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from app.core import config
from app import config

engine = create_engine(
config.SQLALCHEMY_DATABASE_URI,
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion app/core/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pandas as pd
from fastapi import Depends

from app.db.session import get_db
from app.clients.db.session import get_db

_LOGGER = getLogger(__name__)

Expand Down
4 changes: 2 additions & 2 deletions app/core/health.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from sqlalchemy.orm import Session

from app.api.api_v1.routers.search import _VESPA_CONNECTION
from app.core.config import DEVELOPMENT_MODE
from app.db.session import get_db
from app.clients.db.session import get_db
from app.config import DEVELOPMENT_MODE

_LOGGER = logging.getLogger(__file__)

Expand Down
2 changes: 1 addition & 1 deletion app/core/organisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from sqlalchemy import func
from sqlalchemy.orm import Session

from app import config
from app.api.api_v1.schemas.metadata import CorpusData, OrganisationConfig
from app.core import config


def _to_corpus_data(row) -> CorpusData:
Expand Down
2 changes: 1 addition & 1 deletion app/core/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
SearchResponseFamily,
SearchResponseFamilyDocument,
)
from app.core.config import PUBLIC_APP_URL
from app.config import PUBLIC_APP_URL
from app.core.lookups import (
doc_type_from_family_document_metadata,
get_countries_for_region,
Expand Down
5 changes: 3 additions & 2 deletions app/core/validation/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from typing import Any, Collection, Mapping, Optional, Sequence, Union

from app.api.api_v1.schemas.document import DocumentParserInput
from app.core.aws import S3Client, S3Document
from app.core.config import INGEST_TRIGGER_ROOT, PIPELINE_BUCKET
from app.clients.aws.client import S3Client
from app.clients.aws.s3_document import S3Document
from app.config import INGEST_TRIGGER_ROOT, PIPELINE_BUCKET

_LOGGER = logging.getLogger(__file__)

Expand Down
Empty file removed app/db/__init__.py
Empty file.
4 changes: 2 additions & 2 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from fastapi_pagination import add_pagination
from starlette.requests import Request

from app import config
from app.api.api_v1.routers.admin import admin_document_router
from app.api.api_v1.routers.auth import auth_router
from app.api.api_v1.routers.documents import documents_router
Expand All @@ -20,10 +21,9 @@
from app.api.api_v1.routers.pipeline_trigger import pipeline_trigger_router
from app.api.api_v1.routers.search import search_router
from app.api.api_v1.routers.summaries import summary_router
from app.core import config
from app.clients.db.session import SessionLocal, engine
from app.core.auth import get_superuser_details
from app.core.health import is_database_online
from app.db.session import SessionLocal, engine

os.environ["SKIP_ALEMBIC_LOGGING"] = "1"

Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
from sqlalchemy.orm import sessionmaker
from sqlalchemy_utils import create_database, database_exists, drop_database

from app.clients.aws.client import S3Client, get_s3_client
from app.clients.db.session import get_db
from app.core import custom_app, security
from app.core.aws import S3Client, get_s3_client
from app.core.custom_app import AppTokenFactory
from app.db.session import get_db
from app.main import app


Expand Down
2 changes: 1 addition & 1 deletion tests/non_search/app/lookups/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
)
from db_client.models.organisation import Corpus, Organisation

from app.clients.db.session import SessionLocal
from app.core.util import tree_table_to_json
from app.db.session import SessionLocal

LEN_ORG_CONFIG = 3
EXPECTED_CCLW_TAXONOMY = {
Expand Down
2 changes: 1 addition & 1 deletion tests/non_search/core/validation/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

from app.api.api_v1.schemas.document import DocumentParserInput
from app.core.config import PIPELINE_BUCKET
from app.config import PIPELINE_BUCKET
from app.core.validation import IMPORT_ID_MATCHER
from app.core.validation.util import _flatten_maybe_tree, write_documents_to_s3

Expand Down
8 changes: 4 additions & 4 deletions tests/search/vespa/test_whole_database_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_whole_database_download_fails_when_decoding_token_raises_PyJWTError(
), patch(
"app.api.api_v1.routers.search.DOCUMENT_CACHE_BUCKET", "test_cdn_bucket"
), patch(
"app.core.aws.S3Client.is_connected", return_value=True
"app.clients.aws.client.S3Client.is_connected", return_value=True
):
response = data_client.get(
ALL_DATA_DOWNLOAD_ENDPOINT,
Expand Down Expand Up @@ -68,7 +68,7 @@ def test_whole_database_download_fails_when_corpus_ids_in_token_not_in_db(
"app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db",
return_value=False,
), patch(
"app.core.aws.S3Client.is_connected", return_value=True
"app.clients.aws.client.S3Client.is_connected", return_value=True
), patch(
"app.api.api_v1.routers.search.PIPELINE_BUCKET", "test_pipeline_bucket"
), patch(
Expand Down Expand Up @@ -96,7 +96,7 @@ def test_all_data_download(mock_corpora_exist_in_db, data_db, data_client, valid
with (
patch("app.api.api_v1.routers.search.PIPELINE_BUCKET", "test_pipeline_bucket"),
patch("app.api.api_v1.routers.search.DOCUMENT_CACHE_BUCKET", "test_cdn_bucket"),
patch("app.core.aws.S3Client.is_connected", return_value=True),
patch("app.clients.aws.client.S3Client.is_connected", return_value=True),
):
data_client.follow_redirects = False
download_response = data_client.get(ALL_DATA_DOWNLOAD_ENDPOINT, headers=headers)
Expand Down Expand Up @@ -126,7 +126,7 @@ def test_all_data_download_fails_when_s3_upload_failed(
with (
patch("app.api.api_v1.routers.search.PIPELINE_BUCKET", "test_pipeline_bucket"),
patch("app.api.api_v1.routers.search.DOCUMENT_CACHE_BUCKET", "test_cdn_bucket"),
patch("app.core.aws.S3Client.is_connected", return_value=True),
patch("app.clients.aws.client.S3Client.is_connected", return_value=True),
patch(
"app.api.api_v1.routers.search._get_s3_doc_url_from_cdn", return_value=None
),
Expand Down

0 comments on commit 56e0275

Please sign in to comment.