From aa535d25a9d8416b7565654af7a05936ffef3de9 Mon Sep 17 00:00:00 2001 From: Anna Pokorska Date: Tue, 3 Dec 2024 17:04:57 +0000 Subject: [PATCH] Return all orgs if allowed corpora list is empty --- app/repository/organisation.py | 10 +-- app/service/custom_app.py | 1 + .../non_search/routers/lookups/test_config.py | 64 +++++++++++++++++++ 3 files changed, 70 insertions(+), 5 deletions(-) diff --git a/app/repository/organisation.py b/app/repository/organisation.py index 1a710821..f2947066 100644 --- a/app/repository/organisation.py +++ b/app/repository/organisation.py @@ -88,9 +88,9 @@ def get_organisation_config(db: Session, org: Organisation) -> OrganisationConfi def get_organisations(db: Session, allowed_corpora: list[str]) -> list[Organisation]: - return ( - db.query(Organisation) - .join(Corpus, Corpus.organisation_id == Organisation.id) - .filter(Corpus.import_id.in_(allowed_corpora)) - .all() + query = db.query(Organisation).join( + Corpus, Corpus.organisation_id == Organisation.id ) + if allowed_corpora != []: + query = query.filter(Corpus.import_id.in_(allowed_corpora)) + return query.all() diff --git a/app/service/custom_app.py b/app/service/custom_app.py index 6e89dd63..82cb791c 100644 --- a/app/service/custom_app.py +++ b/app/service/custom_app.py @@ -260,4 +260,5 @@ def decode_and_validate( # First corpora validation is app token against DB. At least one of the app token # corpora IDs must be present in the DB to continue the search request. + any_exist = False if not self.allowed_corpora_ids else True self.validate(db, any_exist) diff --git a/tests/non_search/routers/lookups/test_config.py b/tests/non_search/routers/lookups/test_config.py index 774157d1..734bc835 100644 --- a/tests/non_search/routers/lookups/test_config.py +++ b/tests/non_search/routers/lookups/test_config.py @@ -1,8 +1,12 @@ +import os +from datetime import datetime from http.client import OK from typing import Any from unittest.mock import MagicMock +import jwt import pytest +from dateutil.relativedelta import relativedelta from db_client.models.dfce.family import ( Family, FamilyCategory, @@ -12,6 +16,7 @@ from db_client.models.organisation import Corpus, CorpusType, Organisation from app.clients.db.session import SessionLocal +from app.service import security from app.service.util import tree_table_to_json LEN_ORG_CONFIG = 3 @@ -285,6 +290,65 @@ def test_config_endpoint_returns_stats_for_allowed_corpora_only( assert org_config == expected_org_config +def test_config_endpoint_returns_stats_for_all_orgs_if_no_allowed_corpora_in_app_token( + data_client, + data_db, +): + issued_at = datetime.utcnow() + to_encode = { + "allowed_corpora_ids": [], + "exp": issued_at + relativedelta(years=10), + "iat": int(datetime.timestamp(issued_at.replace(microsecond=0))), + "iss": "Climate Policy Radar", + "sub": "CPR", + "aud": "localhost", + } + app_token = jwt.encode( + to_encode, os.environ["TOKEN_SECRET_KEY"], algorithm=security.ALGORITHM + ) + # app_token = app_token_factory(",") + url_under_test = "/api/v1/config" + + cclw_corpus = ( + data_db.query(Corpus) + .join(Organisation, Organisation.id == Corpus.organisation_id) + .filter(Organisation.name == "CCLW") + .one() + ) + + unfccc_corpus = ( + data_db.query(Corpus) + .join(Organisation, Organisation.id == Corpus.organisation_id) + .filter(Organisation.name == "UNFCCC") + .one() + ) + + _add_family(data_db, "T.0.0.1", FamilyCategory.EXECUTIVE, cclw_corpus.import_id) + _add_family(data_db, "T.0.0.2", FamilyCategory.LEGISLATIVE, unfccc_corpus.import_id) + data_db.flush() + + response = data_client.get(url_under_test, headers={"app-token": app_token}) + + response_json = response.json() + org_config = response_json["organisations"] + + assert list(org_config.keys()) == ["CCLW", "UNFCCC"] + assert org_config["CCLW"]["total"] == 1 + assert org_config["UNFCCC"]["total"] == 1 + assert org_config["UNFCCC"]["count_by_category"] == { + "Executive": 0, + "Legislative": 1, + "MCF": 0, + "UNFCCC": 0, + } + assert org_config["CCLW"]["count_by_category"] == { + "Executive": 1, + "Legislative": 0, + "MCF": 0, + "UNFCCC": 0, + } + + class _MockColumn: def __init__(self, name): self.name = name