Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Config endpoint returns only orgs that correspond to allowed corpora #426

Merged
10 changes: 5 additions & 5 deletions app/repository/organisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ def get_organisation_config(db: Session, org: Organisation) -> OrganisationConfi


def get_organisations(db: Session, allowed_corpora: list[str]) -> list[Organisation]:
return (
db.query(Organisation)
.join(Corpus, Corpus.organisation_id == Organisation.id)
.filter(Corpus.import_id.in_(allowed_corpora))
.all()
query = db.query(Organisation).join(
Corpus, Corpus.organisation_id == Organisation.id
)
if allowed_corpora != []:
query = query.filter(Corpus.import_id.in_(allowed_corpora))
return query.all()
1 change: 1 addition & 0 deletions app/service/custom_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,4 +260,5 @@ def decode_and_validate(

# First corpora validation is app token against DB. At least one of the app token
# corpora IDs must be present in the DB to continue the search request.
any_exist = False if not self.allowed_corpora_ids else True
self.validate(db, any_exist)
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def app_token_factory(monkeypatch):
def mock_return(_, __, ___):
return True

def _app_token(allowed_corpora_ids: list[str]):
def _app_token(allowed_corpora_ids):
subject = "CCLW"
audience = "localhost"
input_str = f"{allowed_corpora_ids};{subject};{audience}"
Expand Down
68 changes: 66 additions & 2 deletions tests/non_search/routers/lookups/test_config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import os
from datetime import datetime
from http.client import OK
from typing import Any
from unittest.mock import MagicMock

import jwt
import pytest
from dateutil.relativedelta import relativedelta
from db_client.models.dfce.family import (
Family,
FamilyCategory,
Expand All @@ -12,6 +16,7 @@
from db_client.models.organisation import Corpus, CorpusType, Organisation

from app.clients.db.session import SessionLocal
from app.service import security
from app.service.util import tree_table_to_json

LEN_ORG_CONFIG = 3
Expand Down Expand Up @@ -210,8 +215,8 @@ def test_config_endpoint_cclw_stats(data_client, data_db, valid_token):
@pytest.mark.parametrize(
"allowed_corpora_ids, expected_organisation, other_organisation",
[
(["UNFCCC.corpus.i00000001.n0000"], "UNFCCC", "CCLW"),
(["CCLW.corpus.i00000001.n0000"], "CCLW", "UNFCCC"),
("UNFCCC.corpus.i00000001.n0000", "UNFCCC", "CCLW"),
("CCLW.corpus.i00000001.n0000", "CCLW", "UNFCCC"),
],
)
def test_config_endpoint_returns_stats_for_allowed_corpora_only(
Expand Down Expand Up @@ -285,6 +290,65 @@ def test_config_endpoint_returns_stats_for_allowed_corpora_only(
assert org_config == expected_org_config


def test_config_endpoint_returns_stats_for_all_orgs_if_no_allowed_corpora_in_app_token(
data_client,
data_db,
):
issued_at = datetime.utcnow()
to_encode = {
"allowed_corpora_ids": [],
"exp": issued_at + relativedelta(years=10),
"iat": int(datetime.timestamp(issued_at.replace(microsecond=0))),
"iss": "Climate Policy Radar",
"sub": "CPR",
"aud": "localhost",
}
app_token = jwt.encode(
to_encode, os.environ["TOKEN_SECRET_KEY"], algorithm=security.ALGORITHM
)
# app_token = app_token_factory(",")
annaCPR marked this conversation as resolved.
Show resolved Hide resolved
url_under_test = "/api/v1/config"

cclw_corpus = (
data_db.query(Corpus)
.join(Organisation, Organisation.id == Corpus.organisation_id)
.filter(Organisation.name == "CCLW")
.one()
)

unfccc_corpus = (
data_db.query(Corpus)
.join(Organisation, Organisation.id == Corpus.organisation_id)
.filter(Organisation.name == "UNFCCC")
.one()
)

_add_family(data_db, "T.0.0.1", FamilyCategory.EXECUTIVE, cclw_corpus.import_id)
_add_family(data_db, "T.0.0.2", FamilyCategory.LEGISLATIVE, unfccc_corpus.import_id)
data_db.flush()

response = data_client.get(url_under_test, headers={"app-token": app_token})

response_json = response.json()
org_config = response_json["organisations"]

assert list(org_config.keys()) == ["CCLW", "UNFCCC"]
assert org_config["CCLW"]["total"] == 1
assert org_config["UNFCCC"]["total"] == 1
assert org_config["UNFCCC"]["count_by_category"] == {
"Executive": 0,
"Legislative": 1,
"MCF": 0,
"UNFCCC": 0,
}
assert org_config["CCLW"]["count_by_category"] == {
"Executive": 1,
"Legislative": 0,
"MCF": 0,
"UNFCCC": 0,
}


class _MockColumn:
def __init__(self, name):
self.name = name
Expand Down
Loading