From dcde52817e0979213407ac65feb9f86d22255256 Mon Sep 17 00:00:00 2001 From: Kegan Maher Date: Fri, 3 Jan 2025 01:53:50 +0000 Subject: [PATCH] refactor(models): split into multiple files Take all models defined in benefits/core/models.py and split into files to group models by domain: * benefits/core/models/common.py: common fields/models, helper functions * benefits/core/models/claims.py: ClaimProvider model * benefits/core/models/enrollment.py: EnrollmentFlow, EnrollmentEvent models * benefits/core/models/transit.py: TransitProvider, TransitAgency models Maintain existing imports via top-level benefits/core/models/__init__.py --- .../core/migrations/0033_pemdata_helptext.py | 36 ++ benefits/core/models/__init__.py | 18 + benefits/core/models/claims.py | 29 ++ benefits/core/models/common.py | 91 +++++ .../core/{models.py => models/enrollment.py} | 373 +----------------- benefits/core/models/transit.py | 266 +++++++++++++ tests/pytest/conftest.py | 12 +- tests/pytest/core/models/__init__.py | 0 tests/pytest/core/models/test_claims.py | 20 + tests/pytest/core/models/test_common.py | 112 ++++++ .../test_enrollment.py} | 344 +--------------- tests/pytest/core/models/test_transit.py | 206 ++++++++++ 12 files changed, 796 insertions(+), 711 deletions(-) create mode 100644 benefits/core/migrations/0033_pemdata_helptext.py create mode 100644 benefits/core/models/__init__.py create mode 100644 benefits/core/models/claims.py create mode 100644 benefits/core/models/common.py rename benefits/core/{models.py => models/enrollment.py} (51%) create mode 100644 benefits/core/models/transit.py create mode 100644 tests/pytest/core/models/__init__.py create mode 100644 tests/pytest/core/models/test_claims.py create mode 100644 tests/pytest/core/models/test_common.py rename tests/pytest/core/{test_models.py => models/test_enrollment.py} (56%) create mode 100644 tests/pytest/core/models/test_transit.py diff --git a/benefits/core/migrations/0033_pemdata_helptext.py b/benefits/core/migrations/0033_pemdata_helptext.py new file mode 100644 index 0000000000..09bfb9bb31 --- /dev/null +++ b/benefits/core/migrations/0033_pemdata_helptext.py @@ -0,0 +1,36 @@ +# Generated by Django 5.1.4 on 2025-01-03 01:59 + +import benefits.core.models.common +import benefits.secrets +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("core", "0032_optionalfields"), + ] + + operations = [ + migrations.AlterField( + model_name="pemdata", + name="label", + field=models.TextField(help_text="Human description of the PEM data"), + ), + migrations.AlterField( + model_name="pemdata", + name="remote_url", + field=models.TextField(blank=True, default="", help_text="Public URL hosting the utf-8 encoded PEM text"), + ), + migrations.AlterField( + model_name="pemdata", + name="text_secret_name", + field=benefits.core.models.common.SecretNameField( + blank=True, + default="", + help_text="The name of a secret with data in utf-8 encoded PEM text format", + max_length=127, + validators=[benefits.secrets.SecretNameValidator()], + ), + ), + ] diff --git a/benefits/core/models/__init__.py b/benefits/core/models/__init__.py new file mode 100644 index 0000000000..ddc3f32c89 --- /dev/null +++ b/benefits/core/models/__init__.py @@ -0,0 +1,18 @@ +from .common import template_path, SecretNameField, PemData +from .claims import ClaimsProvider +from .transit import agency_logo_large, agency_logo_small, TransitProcessor, TransitAgency +from .enrollment import EnrollmentMethods, EnrollmentFlow, EnrollmentEvent + +__all__ = [ + "template_path", + "SecretNameField", + "PemData", + "ClaimsProvider", + "agency_logo_large", + "agency_logo_small", + "TransitProcessor", + "TransitAgency", + "EnrollmentMethods", + "EnrollmentFlow", + "EnrollmentEvent", +] diff --git a/benefits/core/models/claims.py b/benefits/core/models/claims.py new file mode 100644 index 0000000000..8e34b200b0 --- /dev/null +++ b/benefits/core/models/claims.py @@ -0,0 +1,29 @@ +from django.db import models + +from .common import SecretNameField + + +class ClaimsProvider(models.Model): + """An entity that provides claims for eligibility verification.""" + + id = models.AutoField(primary_key=True) + sign_out_button_template = models.TextField(default="", blank=True, help_text="Template that renders sign-out button") + sign_out_link_template = models.TextField(default="", blank=True, help_text="Template that renders sign-out link") + client_name = models.TextField(help_text="Unique identifier used to register this claims provider with Authlib registry") + client_id_secret_name = SecretNameField( + help_text="The name of the secret containing the client ID for this claims provider" + ) + authority = models.TextField(help_text="The fully qualified HTTPS domain name for an OAuth authority server") + scheme = models.TextField(help_text="The authentication scheme to use") + + @property + def supports_sign_out(self): + return bool(self.sign_out_button_template) or bool(self.sign_out_link_template) + + @property + def client_id(self): + secret_name_field = self._meta.get_field("client_id_secret_name") + return secret_name_field.secret_value(self) + + def __str__(self) -> str: + return self.client_name diff --git a/benefits/core/models/common.py b/benefits/core/models/common.py new file mode 100644 index 0000000000..ee39c77651 --- /dev/null +++ b/benefits/core/models/common.py @@ -0,0 +1,91 @@ +from functools import cached_property +import logging +from pathlib import Path + +from django import template +from django.conf import settings +from django.db import models + +import requests + +from benefits.secrets import NAME_VALIDATOR, get_secret_by_name + +logger = logging.getLogger(__name__) + + +def template_path(template_name: str) -> Path: + """Get a `pathlib.Path` for the named template, or None if it can't be found. + + A `template_name` is the app-local name, e.g. `enrollment/success.html`. + + Adapted from https://stackoverflow.com/a/75863472. + """ + if template_name: + for engine in template.engines.all(): + for loader in engine.engine.template_loaders: + for origin in loader.get_template_sources(template_name): + path = Path(origin.name) + if path.exists() and path.is_file(): + return path + return None + + +class SecretNameField(models.SlugField): + """Field that stores the name of a secret held in a secret store. + + The secret value itself MUST NEVER be stored in this field. + """ + + description = """Field that stores the name of a secret held in a secret store. + + Secret names must be between 1-127 alphanumeric ASCII characters or hyphen characters. + + The secret value itself MUST NEVER be stored in this field. + """ + + def __init__(self, *args, **kwargs): + kwargs["validators"] = [NAME_VALIDATOR] + # although the validator also checks for a max length of 127 + # this setting enforces the length at the database column level as well + kwargs["max_length"] = 127 + # the default is False, but this is more explicit + kwargs["allow_unicode"] = False + super().__init__(*args, **kwargs) + + def secret_value(self, instance): + """Get the secret value from the secret store.""" + secret_name = getattr(instance, self.attname) + return get_secret_by_name(secret_name) + + +class PemData(models.Model): + """API Certificate or Key in PEM format.""" + + id = models.AutoField(primary_key=True) + label = models.TextField(help_text="Human description of the PEM data") + text_secret_name = SecretNameField( + default="", blank=True, help_text="The name of a secret with data in utf-8 encoded PEM text format" + ) + remote_url = models.TextField(default="", blank=True, help_text="Public URL hosting the utf-8 encoded PEM text") + + def __str__(self): + return self.label + + @cached_property + def data(self): + """ + Attempts to get data from `remote_url` or `text_secret_name`, with the latter taking precendence if both are defined. + """ + remote_data = None + secret_data = None + + if self.remote_url: + remote_data = requests.get(self.remote_url, timeout=settings.REQUESTS_TIMEOUT).text + if self.text_secret_name: + try: + secret_field = self._meta.get_field("text_secret_name") + secret_data = secret_field.secret_value(self) + except Exception: + secret_data = None + + return secret_data if secret_data is not None else remote_data diff --git a/benefits/core/models.py b/benefits/core/models/enrollment.py similarity index 51% rename from benefits/core/models.py rename to benefits/core/models/enrollment.py index f90f216396..3ba748afe4 100644 --- a/benefits/core/models.py +++ b/benefits/core/models/enrollment.py @@ -1,383 +1,19 @@ -""" -The core application: Common model definitions. -""" - -from functools import cached_property import importlib import logging -import os -from pathlib import Path import uuid -from django import template -from django.conf import settings from django.core.exceptions import ValidationError -from django.contrib.auth.models import Group, User from django.db import models -from django.urls import reverse from django.utils import timezone - -import requests - -from benefits.routes import routes -from benefits.secrets import NAME_VALIDATOR, get_secret_by_name from multiselectfield import MultiSelectField +from .common import PemData, SecretNameField, template_path +from .claims import ClaimsProvider +from .transit import TransitAgency logger = logging.getLogger(__name__) -def template_path(template_name: str) -> Path: - """Get a `pathlib.Path` for the named template, or None if it can't be found. - - A `template_name` is the app-local name, e.g. `enrollment/success.html`. - - Adapted from https://stackoverflow.com/a/75863472. - """ - if template_name: - for engine in template.engines.all(): - for loader in engine.engine.template_loaders: - for origin in loader.get_template_sources(template_name): - path = Path(origin.name) - if path.exists() and path.is_file(): - return path - return None - - -class SecretNameField(models.SlugField): - """Field that stores the name of a secret held in a secret store. - - The secret value itself MUST NEVER be stored in this field. - """ - - description = """Field that stores the name of a secret held in a secret store. - - Secret names must be between 1-127 alphanumeric ASCII characters or hyphen characters. - - The secret value itself MUST NEVER be stored in this field. - """ - - def __init__(self, *args, **kwargs): - kwargs["validators"] = [NAME_VALIDATOR] - # although the validator also checks for a max length of 127 - # this setting enforces the length at the database column level as well - kwargs["max_length"] = 127 - # the default is False, but this is more explicit - kwargs["allow_unicode"] = False - super().__init__(*args, **kwargs) - - -class PemData(models.Model): - """API Certificate or Key in PEM format.""" - - id = models.AutoField(primary_key=True) - # Human description of the PEM data - label = models.TextField() - # The name of a secret with data in utf-8 encoded PEM text format - text_secret_name = SecretNameField(default="", blank=True) - # Public URL hosting the utf-8 encoded PEM text - remote_url = models.TextField(default="", blank=True) - - def __str__(self): - return self.label - - @cached_property - def data(self): - """ - Attempts to get data from `remote_url` or `text_secret_name`, with the latter taking precendence if both are defined. - """ - remote_data = None - secret_data = None - - if self.remote_url: - remote_data = requests.get(self.remote_url, timeout=settings.REQUESTS_TIMEOUT).text - if self.text_secret_name: - try: - secret_data = get_secret_by_name(self.text_secret_name) - except Exception: - secret_data = None - - return secret_data if secret_data is not None else remote_data - - -class ClaimsProvider(models.Model): - """An entity that provides claims for eligibility verification.""" - - id = models.AutoField(primary_key=True) - sign_out_button_template = models.TextField(default="", blank=True, help_text="Template that renders sign-out button") - sign_out_link_template = models.TextField(default="", blank=True, help_text="Template that renders sign-out link") - client_name = models.TextField(help_text="Unique identifier used to register this claims provider with Authlib registry") - client_id_secret_name = SecretNameField( - help_text="The name of the secret containing the client ID for this claims provider" - ) - authority = models.TextField(help_text="The fully qualified HTTPS domain name for an OAuth authority server") - scheme = models.TextField(help_text="The authentication scheme to use") - - @property - def supports_sign_out(self): - return bool(self.sign_out_button_template) or bool(self.sign_out_link_template) - - @property - def client_id(self): - return get_secret_by_name(self.client_id_secret_name) - - def __str__(self) -> str: - return self.client_name - - -class TransitProcessor(models.Model): - """An entity that applies transit agency fare rules to rider transactions.""" - - id = models.AutoField(primary_key=True) - name = models.TextField(help_text="Primary internal display name for this TransitProcessor instance, e.g. in the Admin.") - api_base_url = models.TextField(help_text="The absolute base URL for the TransitProcessor's API, including https://.") - card_tokenize_url = models.TextField( - help_text="The absolute URL for the client-side card tokenization library provided by the TransitProcessor." - ) - card_tokenize_func = models.TextField( - help_text="The function from the card tokenization library to call on the client to initiate the process." - ) - card_tokenize_env = models.TextField(help_text="The environment in which card tokenization is occurring.") - portal_url = models.TextField( - default="", - blank=True, - help_text="The absolute base URL for the TransitProcessor's control portal, including https://.", - ) - - def __str__(self): - return self.name - - -def _agency_logo(instance, filename, size): - base, ext = os.path.splitext(filename) - return f"agencies/{instance.slug}-{size}" + ext - - -def agency_logo_small(instance, filename): - return _agency_logo(instance, filename, "sm") - - -def agency_logo_large(instance, filename): - return _agency_logo(instance, filename, "lg") - - -class TransitAgency(models.Model): - """An agency offering transit service.""" - - id = models.AutoField(primary_key=True) - active = models.BooleanField(default=False, help_text="Determines if this Agency is enabled for users") - slug = models.SlugField(help_text="Used for URL navigation for this agency, e.g. the agency homepage url is /{slug}") - short_name = models.TextField( - default="", blank=True, help_text="The user-facing short name for this agency. Often an uppercase acronym." - ) - long_name = models.TextField( - default="", - blank=True, - help_text="The user-facing long name for this agency. Often the short_name acronym, spelled out.", - ) - info_url = models.URLField( - default="", - blank=True, - help_text="URL of a website/page with more information about the agency's discounts", - ) - phone = models.TextField(default="", blank=True, help_text="Agency customer support phone number") - index_template_override = models.TextField( - help_text="Override the default template used for this agency's landing page", - blank=True, - default="", - ) - eligibility_index_template_override = models.TextField( - help_text="Override the default template used for this agency's eligibility landing page", - blank=True, - default="", - ) - eligibility_api_id = models.TextField( - help_text="The identifier for this agency used in Eligibility API calls.", - blank=True, - default="", - ) - eligibility_api_private_key = models.ForeignKey( - PemData, - related_name="+", - on_delete=models.PROTECT, - help_text="Private key used to sign Eligibility API tokens created on behalf of this Agency.", - null=True, - blank=True, - default=None, - ) - eligibility_api_public_key = models.ForeignKey( - PemData, - related_name="+", - on_delete=models.PROTECT, - help_text="Public key corresponding to the agency's private key, used by Eligibility Verification servers to encrypt responses.", # noqa: E501 - null=True, - blank=True, - default=None, - ) - transit_processor = models.ForeignKey( - TransitProcessor, - on_delete=models.PROTECT, - null=True, - blank=True, - default=None, - help_text="This agency's TransitProcessor.", - ) - transit_processor_audience = models.TextField( - help_text="This agency's audience value used to access the TransitProcessor's API.", default="", blank=True - ) - transit_processor_client_id = models.TextField( - help_text="This agency's client_id value used to access the TransitProcessor's API.", default="", blank=True - ) - transit_processor_client_secret_name = SecretNameField( - help_text="The name of the secret containing this agency's client_secret value used to access the TransitProcessor's API.", # noqa: E501 - default="", - blank=True, - ) - staff_group = models.OneToOneField( - Group, - on_delete=models.PROTECT, - null=True, - blank=True, - default=None, - help_text="The group of users associated with this TransitAgency.", - related_name="transit_agency", - ) - sso_domain = models.TextField( - blank=True, - default="", - help_text="The email domain of users to automatically add to this agency's staff group upon login.", - ) - customer_service_group = models.OneToOneField( - Group, - on_delete=models.PROTECT, - null=True, - blank=True, - default=None, - help_text="The group of users who are allowed to do in-person eligibility verification and enrollment.", - related_name="+", - ) - logo_large = models.ImageField( - default="", - blank=True, - upload_to=agency_logo_large, - help_text="The large version of the transit agency's logo.", - ) - logo_small = models.ImageField( - default="", - blank=True, - upload_to=agency_logo_small, - help_text="The small version of the transit agency's logo.", - ) - - def __str__(self): - return self.long_name - - @property - def index_template(self): - return self.index_template_override or f"core/index--{self.slug}.html" - - @property - def index_url(self): - """Public-facing URL to the TransitAgency's landing page.""" - return reverse(routes.AGENCY_INDEX, args=[self.slug]) - - @property - def eligibility_index_template(self): - return self.eligibility_index_template_override or f"eligibility/index--{self.slug}.html" - - @property - def eligibility_index_url(self): - """Public facing URL to the TransitAgency's eligibility page.""" - return reverse(routes.AGENCY_ELIGIBILITY_INDEX, args=[self.slug]) - - @property - def eligibility_api_private_key_data(self): - """This Agency's private key as a string.""" - return self.eligibility_api_private_key.data - - @property - def eligibility_api_public_key_data(self): - """This Agency's public key as a string.""" - return self.eligibility_api_public_key.data - - @property - def transit_processor_client_secret(self): - return get_secret_by_name(self.transit_processor_client_secret_name) - - @property - def enrollment_flows(self): - return self.enrollmentflow_set - - def clean(self): - field_errors = {} - template_errors = [] - - if self.active: - for flow in self.enrollment_flows.all(): - try: - flow.clean() - except ValidationError: - raise ValidationError(f"Invalid EnrollmentFlow: {flow.label}") - - message = "This field is required for active transit agencies." - needed = dict( - short_name=self.short_name, - long_name=self.long_name, - phone=self.phone, - info_url=self.info_url, - logo_large=self.logo_large, - logo_small=self.logo_small, - ) - if self.transit_processor: - needed.update( - dict( - transit_processor_audience=self.transit_processor_audience, - transit_processor_client_id=self.transit_processor_client_id, - transit_processor_client_secret_name=self.transit_processor_client_secret_name, - ) - ) - field_errors.update({k: ValidationError(message) for k, v in needed.items() if not v}) - - # since templates are calculated from the pattern or the override field - # we can't add a field-level validation error - # so just create directly for a missing template - for t in [self.index_template, self.eligibility_index_template]: - if not template_path(t): - template_errors.append(ValidationError(f"Template not found: {t}")) - - if field_errors: - raise ValidationError(field_errors) - if template_errors: - raise ValidationError(template_errors) - - @staticmethod - def by_id(id): - """Get a TransitAgency instance by its ID.""" - logger.debug(f"Get {TransitAgency.__name__} by id: {id}") - return TransitAgency.objects.get(id=id) - - @staticmethod - def by_slug(slug): - """Get a TransitAgency instance by its slug.""" - logger.debug(f"Get {TransitAgency.__name__} by slug: {slug}") - return TransitAgency.objects.filter(slug=slug).first() - - @staticmethod - def all_active(): - """Get all TransitAgency instances marked active.""" - logger.debug(f"Get all active {TransitAgency.__name__}") - return TransitAgency.objects.filter(active=True) - - @staticmethod - def for_user(user: User): - for group in user.groups.all(): - if hasattr(group, "transit_agency"): - return group.transit_agency # this is looking at the TransitAgency's staff_group - - # the loop above returns the first match found. Return None if no match was found. - return None - - class EnrollmentMethods: DIGITAL = "digital" IN_PERSON = "in_person" @@ -537,7 +173,8 @@ def agency_card_name(self): @property def eligibility_api_auth_key(self): if self.eligibility_api_auth_key_secret_name is not None: - return get_secret_by_name(self.eligibility_api_auth_key_secret_name) + secret_field = self._meta.get_field("eligibility_api_auth_key_secret_name") + return secret_field.secret_value(self) else: return None diff --git a/benefits/core/models/transit.py b/benefits/core/models/transit.py new file mode 100644 index 0000000000..0cf780080f --- /dev/null +++ b/benefits/core/models/transit.py @@ -0,0 +1,266 @@ +import os +import logging + +from django.core.exceptions import ValidationError +from django.contrib.auth.models import Group, User +from django.db import models +from django.urls import reverse + +from benefits.routes import routes +from .common import PemData, SecretNameField, template_path + +logger = logging.getLogger(__name__) + + +def _agency_logo(instance, filename, size): + base, ext = os.path.splitext(filename) + return f"agencies/{instance.slug}-{size}" + ext + + +def agency_logo_small(instance, filename): + return _agency_logo(instance, filename, "sm") + + +def agency_logo_large(instance, filename): + return _agency_logo(instance, filename, "lg") + + +class TransitProcessor(models.Model): + """An entity that applies transit agency fare rules to rider transactions.""" + + id = models.AutoField(primary_key=True) + name = models.TextField(help_text="Primary internal display name for this TransitProcessor instance, e.g. in the Admin.") + api_base_url = models.TextField(help_text="The absolute base URL for the TransitProcessor's API, including https://.") + card_tokenize_url = models.TextField( + help_text="The absolute URL for the client-side card tokenization library provided by the TransitProcessor." + ) + card_tokenize_func = models.TextField( + help_text="The function from the card tokenization library to call on the client to initiate the process." + ) + card_tokenize_env = models.TextField(help_text="The environment in which card tokenization is occurring.") + portal_url = models.TextField( + default="", + blank=True, + help_text="The absolute base URL for the TransitProcessor's control portal, including https://.", + ) + + def __str__(self): + return self.name + + +class TransitAgency(models.Model): + """An agency offering transit service.""" + + id = models.AutoField(primary_key=True) + active = models.BooleanField(default=False, help_text="Determines if this Agency is enabled for users") + slug = models.SlugField(help_text="Used for URL navigation for this agency, e.g. the agency homepage url is /{slug}") + short_name = models.TextField( + default="", blank=True, help_text="The user-facing short name for this agency. Often an uppercase acronym." + ) + long_name = models.TextField( + default="", + blank=True, + help_text="The user-facing long name for this agency. Often the short_name acronym, spelled out.", + ) + info_url = models.URLField( + default="", + blank=True, + help_text="URL of a website/page with more information about the agency's discounts", + ) + phone = models.TextField(default="", blank=True, help_text="Agency customer support phone number") + index_template_override = models.TextField( + help_text="Override the default template used for this agency's landing page", + blank=True, + default="", + ) + eligibility_index_template_override = models.TextField( + help_text="Override the default template used for this agency's eligibility landing page", + blank=True, + default="", + ) + eligibility_api_id = models.TextField( + help_text="The identifier for this agency used in Eligibility API calls.", + blank=True, + default="", + ) + eligibility_api_private_key = models.ForeignKey( + PemData, + related_name="+", + on_delete=models.PROTECT, + help_text="Private key used to sign Eligibility API tokens created on behalf of this Agency.", + null=True, + blank=True, + default=None, + ) + eligibility_api_public_key = models.ForeignKey( + PemData, + related_name="+", + on_delete=models.PROTECT, + help_text="Public key corresponding to the agency's private key, used by Eligibility Verification servers to encrypt responses.", # noqa: E501 + null=True, + blank=True, + default=None, + ) + transit_processor = models.ForeignKey( + TransitProcessor, + on_delete=models.PROTECT, + null=True, + blank=True, + default=None, + help_text="This agency's TransitProcessor.", + ) + transit_processor_audience = models.TextField( + help_text="This agency's audience value used to access the TransitProcessor's API.", default="", blank=True + ) + transit_processor_client_id = models.TextField( + help_text="This agency's client_id value used to access the TransitProcessor's API.", default="", blank=True + ) + transit_processor_client_secret_name = SecretNameField( + help_text="The name of the secret containing this agency's client_secret value used to access the TransitProcessor's API.", # noqa: E501 + default="", + blank=True, + ) + staff_group = models.OneToOneField( + Group, + on_delete=models.PROTECT, + null=True, + blank=True, + default=None, + help_text="The group of users associated with this TransitAgency.", + related_name="transit_agency", + ) + sso_domain = models.TextField( + blank=True, + default="", + help_text="The email domain of users to automatically add to this agency's staff group upon login.", + ) + customer_service_group = models.OneToOneField( + Group, + on_delete=models.PROTECT, + null=True, + blank=True, + default=None, + help_text="The group of users who are allowed to do in-person eligibility verification and enrollment.", + related_name="+", + ) + logo_large = models.ImageField( + default="", + blank=True, + upload_to=agency_logo_large, + help_text="The large version of the transit agency's logo.", + ) + logo_small = models.ImageField( + default="", + blank=True, + upload_to=agency_logo_small, + help_text="The small version of the transit agency's logo.", + ) + + def __str__(self): + return self.long_name + + @property + def index_template(self): + return self.index_template_override or f"core/index--{self.slug}.html" + + @property + def index_url(self): + """Public-facing URL to the TransitAgency's landing page.""" + return reverse(routes.AGENCY_INDEX, args=[self.slug]) + + @property + def eligibility_index_template(self): + return self.eligibility_index_template_override or f"eligibility/index--{self.slug}.html" + + @property + def eligibility_index_url(self): + """Public facing URL to the TransitAgency's eligibility page.""" + return reverse(routes.AGENCY_ELIGIBILITY_INDEX, args=[self.slug]) + + @property + def eligibility_api_private_key_data(self): + """This Agency's private key as a string.""" + return self.eligibility_api_private_key.data + + @property + def eligibility_api_public_key_data(self): + """This Agency's public key as a string.""" + return self.eligibility_api_public_key.data + + @property + def transit_processor_client_secret(self): + secret_field = self._meta.get_field("transit_processor_client_secret_name") + return secret_field.secret_value(self) + + @property + def enrollment_flows(self): + return self.enrollmentflow_set + + def clean(self): + field_errors = {} + template_errors = [] + + if self.active: + for flow in self.enrollment_flows.all(): + try: + flow.clean() + except ValidationError: + raise ValidationError(f"Invalid EnrollmentFlow: {flow.label}") + + message = "This field is required for active transit agencies." + needed = dict( + short_name=self.short_name, + long_name=self.long_name, + phone=self.phone, + info_url=self.info_url, + logo_large=self.logo_large, + logo_small=self.logo_small, + ) + if self.transit_processor: + needed.update( + dict( + transit_processor_audience=self.transit_processor_audience, + transit_processor_client_id=self.transit_processor_client_id, + transit_processor_client_secret_name=self.transit_processor_client_secret_name, + ) + ) + field_errors.update({k: ValidationError(message) for k, v in needed.items() if not v}) + + # since templates are calculated from the pattern or the override field + # we can't add a field-level validation error + # so just create directly for a missing template + for t in [self.index_template, self.eligibility_index_template]: + if not template_path(t): + template_errors.append(ValidationError(f"Template not found: {t}")) + + if field_errors: + raise ValidationError(field_errors) + if template_errors: + raise ValidationError(template_errors) + + @staticmethod + def by_id(id): + """Get a TransitAgency instance by its ID.""" + logger.debug(f"Get {TransitAgency.__name__} by id: {id}") + return TransitAgency.objects.get(id=id) + + @staticmethod + def by_slug(slug): + """Get a TransitAgency instance by its slug.""" + logger.debug(f"Get {TransitAgency.__name__} by slug: {slug}") + return TransitAgency.objects.filter(slug=slug).first() + + @staticmethod + def all_active(): + """Get all TransitAgency instances marked active.""" + logger.debug(f"Get all active {TransitAgency.__name__}") + return TransitAgency.objects.filter(active=True) + + @staticmethod + def for_user(user: User): + for group in user.groups.all(): + if hasattr(group, "transit_agency"): + return group.transit_agency # this is looking at the TransitAgency's staff_group + + # the loop above returns the first match found. Return None if no match was found. + return None diff --git a/tests/pytest/conftest.py b/tests/pytest/conftest.py index 620622ba4f..28a694f2d5 100644 --- a/tests/pytest/conftest.py +++ b/tests/pytest/conftest.py @@ -9,7 +9,13 @@ from pytest_socket import disable_socket from benefits.core import session -from benefits.core.models import ClaimsProvider, EnrollmentFlow, TransitProcessor, PemData, TransitAgency +from benefits.core.models import ( + ClaimsProvider, + EnrollmentFlow, + TransitProcessor, + PemData, + TransitAgency, +) def pytest_runtest_setup(): @@ -42,8 +48,8 @@ def model_User(): # autouse this fixture so we never call out to the real secret store @pytest.fixture(autouse=True) -def mock_models_get_secret_by_name(mocker): - return mocker.patch("benefits.core.models.get_secret_by_name", return_value="secret value!") +def mock_get_secret_by_name(mocker): + return mocker.patch("benefits.core.models.common.get_secret_by_name", return_value="secret value!") @pytest.fixture diff --git a/tests/pytest/core/models/__init__.py b/tests/pytest/core/models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/pytest/core/models/test_claims.py b/tests/pytest/core/models/test_claims.py new file mode 100644 index 0000000000..057b59449f --- /dev/null +++ b/tests/pytest/core/models/test_claims.py @@ -0,0 +1,20 @@ +import pytest + + +@pytest.mark.django_db +def test_model_ClaimsProvider(model_ClaimsProvider): + assert model_ClaimsProvider.supports_sign_out + assert str(model_ClaimsProvider) == model_ClaimsProvider.client_name + + +@pytest.mark.django_db +def test_model_ClaimsProvider_client_id(model_ClaimsProvider, mock_get_secret_by_name): + secret_value = model_ClaimsProvider.client_id + + mock_get_secret_by_name.assert_called_once_with(model_ClaimsProvider.client_id_secret_name) + assert secret_value == mock_get_secret_by_name.return_value + + +@pytest.mark.django_db +def test_model_ClaimsProvider_no_sign_out(model_ClaimsProvider_no_sign_out): + assert not model_ClaimsProvider_no_sign_out.supports_sign_out diff --git a/tests/pytest/core/models/test_common.py b/tests/pytest/core/models/test_common.py new file mode 100644 index 0000000000..0419fc3b9b --- /dev/null +++ b/tests/pytest/core/models/test_common.py @@ -0,0 +1,112 @@ +from pathlib import Path + +from django.conf import settings + +import pytest + +from benefits.core.models import template_path, SecretNameField +import benefits.secrets + + +@pytest.fixture +def mock_requests_get_pem_data(mocker): + # intercept and spy on the GET request + return mocker.patch("benefits.core.models.common.requests.get", return_value=mocker.Mock(text="PEM text")) + + +@pytest.mark.django_db +@pytest.mark.parametrize( + "input_template,expected_path", + [ + ("error.html", f"{settings.BASE_DIR}/benefits/templates/error.html"), + ("core/index.html", f"{settings.BASE_DIR}/benefits/core/templates/core/index.html"), + ("eligibility/start.html", f"{settings.BASE_DIR}/benefits/eligibility/templates/eligibility/start.html"), + ("", None), + ("nope.html", None), + ("core/not-there.html", None), + ], +) +def test_template_path(input_template, expected_path): + if expected_path: + assert template_path(input_template) == Path(expected_path) + else: + assert template_path(input_template) is None + + +def test_SecretNameField_init(): + field = SecretNameField() + + assert benefits.secrets.NAME_VALIDATOR in field.validators + assert field.max_length == 127 + assert field.blank is False + assert field.null is False + assert field.allow_unicode is False + assert field.description is not None + assert field.description != "" + + +def test_SecretNameField_init_null_blank(): + field = SecretNameField(blank=True, null=True) + + assert field.blank is True + assert field.null is True + + +@pytest.mark.django_db +def test_PemData_str(model_PemData): + assert str(model_PemData) == model_PemData.label + + +@pytest.mark.django_db +def test_PemData_data_text_secret_name(model_PemData, mock_get_secret_by_name): + # a secret name and not remote URL, should use secret value + + data = model_PemData.data + + mock_get_secret_by_name.assert_called_once_with(model_PemData.text_secret_name) + assert data == mock_get_secret_by_name.return_value + + +@pytest.mark.django_db +def test_PemData_data_remote(model_PemData, mock_requests_get_pem_data): + # a remote URL and no secret name, should use remote value + + model_PemData.text_secret_name = None + model_PemData.remote_url = "http://localhost/publickey" + + assert not model_PemData.text_secret_name + + data = model_PemData.data + + mock_requests_get_pem_data.assert_called_once_with(model_PemData.remote_url, timeout=settings.REQUESTS_TIMEOUT) + assert data == mock_requests_get_pem_data.return_value.text + + +@pytest.mark.django_db +def test_PemData_data_text_secret_name_and_remote__uses_text_secret( + model_PemData, mock_get_secret_by_name, mock_requests_get_pem_data +): + # a remote URL and the secret value is not None, should use the secret value + + model_PemData.remote_url = "http://localhost/publickey" + + data = model_PemData.data + + mock_get_secret_by_name.assert_called_once_with(model_PemData.text_secret_name) + mock_requests_get_pem_data.assert_called_once_with(model_PemData.remote_url, timeout=settings.REQUESTS_TIMEOUT) + assert data == mock_get_secret_by_name.return_value + + +@pytest.mark.django_db +def test_PemData_data_text_secret_name_and_remote__uses_remote( + model_PemData, mock_get_secret_by_name, mock_requests_get_pem_data +): + # a remote URL and the secret value is None, should use remote value + model_PemData.remote_url = "http://localhost/publickey" + mock_get_secret_by_name.return_value = None + + data = model_PemData.data + + mock_get_secret_by_name.assert_called_once_with(model_PemData.text_secret_name) + mock_requests_get_pem_data.assert_called_once_with(model_PemData.remote_url, timeout=settings.REQUESTS_TIMEOUT) + assert data == mock_requests_get_pem_data.return_value.text diff --git a/tests/pytest/core/test_models.py b/tests/pytest/core/models/test_enrollment.py similarity index 56% rename from tests/pytest/core/test_models.py rename to tests/pytest/core/models/test_enrollment.py index afbb4f0b3f..9ab56e41dd 100644 --- a/tests/pytest/core/test_models.py +++ b/tests/pytest/core/models/test_enrollment.py @@ -1,147 +1,11 @@ from datetime import timedelta -from pathlib import Path -from django.conf import settings -from django.contrib.auth.models import Group, User from django.core.exceptions import ValidationError from django.utils import timezone import pytest -from benefits.core.models import ( - template_path, - SecretNameField, - EnrollmentFlow, - TransitAgency, - EnrollmentEvent, - EnrollmentMethods, - agency_logo_small, - agency_logo_large, -) -import benefits.secrets - - -@pytest.fixture -def mock_requests_get_pem_data(mocker): - # intercept and spy on the GET request - return mocker.patch("benefits.core.models.requests.get", return_value=mocker.Mock(text="PEM text")) - - -@pytest.mark.django_db -@pytest.mark.parametrize( - "input_template,expected_path", - [ - ("error.html", f"{settings.BASE_DIR}/benefits/templates/error.html"), - ("core/index.html", f"{settings.BASE_DIR}/benefits/core/templates/core/index.html"), - ("eligibility/start.html", f"{settings.BASE_DIR}/benefits/eligibility/templates/eligibility/start.html"), - ("", None), - ("nope.html", None), - ("core/not-there.html", None), - ], -) -def test_template_path(input_template, expected_path): - if expected_path: - assert template_path(input_template) == Path(expected_path) - else: - assert template_path(input_template) is None - - -def test_SecretNameField_init(): - field = SecretNameField() - - assert benefits.secrets.NAME_VALIDATOR in field.validators - assert field.max_length == 127 - assert field.blank is False - assert field.null is False - assert field.allow_unicode is False - assert field.description is not None - assert field.description != "" - - -def test_SecretNameField_init_null_blank(): - field = SecretNameField(blank=True, null=True) - - assert field.blank is True - assert field.null is True - - -@pytest.mark.django_db -def test_PemData_str(model_PemData): - assert str(model_PemData) == model_PemData.label - - -@pytest.mark.django_db -def test_PemData_data_text_secret_name(model_PemData, mock_models_get_secret_by_name): - # a secret name and not remote URL, should use secret value - - data = model_PemData.data - - mock_models_get_secret_by_name.assert_called_once_with(model_PemData.text_secret_name) - assert data == mock_models_get_secret_by_name.return_value - - -@pytest.mark.django_db -def test_PemData_data_remote(model_PemData, mock_requests_get_pem_data): - # a remote URL and no secret name, should use remote value - - model_PemData.text_secret_name = None - model_PemData.remote_url = "http://localhost/publickey" - - assert not model_PemData.text_secret_name - - data = model_PemData.data - - mock_requests_get_pem_data.assert_called_once_with(model_PemData.remote_url, timeout=settings.REQUESTS_TIMEOUT) - assert data == mock_requests_get_pem_data.return_value.text - - -@pytest.mark.django_db -def test_PemData_data_text_secret_name_and_remote__uses_text_secret( - model_PemData, mock_models_get_secret_by_name, mock_requests_get_pem_data -): - # a remote URL and the secret value is not None, should use the secret value - - model_PemData.remote_url = "http://localhost/publickey" - - data = model_PemData.data - - mock_models_get_secret_by_name.assert_called_once_with(model_PemData.text_secret_name) - mock_requests_get_pem_data.assert_called_once_with(model_PemData.remote_url, timeout=settings.REQUESTS_TIMEOUT) - assert data == mock_models_get_secret_by_name.return_value - - -@pytest.mark.django_db -def test_PemData_data_text_secret_name_and_remote__uses_remote( - model_PemData, mock_models_get_secret_by_name, mock_requests_get_pem_data -): - # a remote URL and the secret value is None, should use remote value - model_PemData.remote_url = "http://localhost/publickey" - mock_models_get_secret_by_name.return_value = None - - data = model_PemData.data - - mock_models_get_secret_by_name.assert_called_once_with(model_PemData.text_secret_name) - mock_requests_get_pem_data.assert_called_once_with(model_PemData.remote_url, timeout=settings.REQUESTS_TIMEOUT) - assert data == mock_requests_get_pem_data.return_value.text - - -@pytest.mark.django_db -def test_model_ClaimsProvider(model_ClaimsProvider): - assert model_ClaimsProvider.supports_sign_out - assert str(model_ClaimsProvider) == model_ClaimsProvider.client_name - - -@pytest.mark.django_db -def test_model_ClaimsProvider_client_id(model_ClaimsProvider, mock_models_get_secret_by_name): - secret_value = model_ClaimsProvider.client_id - - mock_models_get_secret_by_name.assert_called_once_with(model_ClaimsProvider.client_id_secret_name) - assert secret_value == mock_models_get_secret_by_name.return_value - - -@pytest.mark.django_db -def test_model_ClaimsProvider_no_sign_out(model_ClaimsProvider_no_sign_out): - assert not model_ClaimsProvider_no_sign_out.supports_sign_out +from benefits.core.models import EnrollmentFlow, EnrollmentEvent, EnrollmentMethods @pytest.mark.django_db @@ -264,13 +128,13 @@ def test_EnrollmentFlow_no_scope_and_claim_no_sign_out(model_EnrollmentFlow, mod @pytest.mark.django_db -def test_EnrollmentFlow_eligibility_api_auth_key(model_EnrollmentFlow_with_eligibility_api, mock_models_get_secret_by_name): +def test_EnrollmentFlow_eligibility_api_auth_key(model_EnrollmentFlow_with_eligibility_api, mock_get_secret_by_name): secret_value = model_EnrollmentFlow_with_eligibility_api.eligibility_api_auth_key - mock_models_get_secret_by_name.assert_called_once_with( + mock_get_secret_by_name.assert_called_once_with( model_EnrollmentFlow_with_eligibility_api.eligibility_api_auth_key_secret_name ) - assert secret_value == mock_models_get_secret_by_name.return_value + assert secret_value == mock_get_secret_by_name.return_value @pytest.mark.django_db @@ -457,206 +321,6 @@ def test_EnrollmentFlow_clean_templates(model_EnrollmentFlow_with_scope_and_clai model_EnrollmentFlow_with_scope_and_claim.clean() -@pytest.mark.django_db -def test_TransitProcessor_str(model_TransitProcessor): - assert str(model_TransitProcessor) == model_TransitProcessor.name - - -@pytest.mark.django_db -def test_TransitAgency_defaults(): - agency = TransitAgency.objects.create(slug="test") - - assert agency.active is False - assert agency.slug == "test" - assert agency.short_name == "" - assert agency.long_name == "" - assert agency.phone == "" - assert agency.info_url == "" - assert agency.logo_large == "" - assert agency.logo_small == "" - # test fails if save fails - agency.save() - - -@pytest.mark.django_db -def test_TransitAgency_str(model_TransitAgency): - assert str(model_TransitAgency) == model_TransitAgency.long_name - - -@pytest.mark.django_db -def test_TransitAgency_template_overrides(model_TransitAgency): - assert model_TransitAgency.index_template == model_TransitAgency.index_template_override - assert model_TransitAgency.eligibility_index_template == model_TransitAgency.eligibility_index_template_override - - model_TransitAgency.index_template_override = "" - model_TransitAgency.eligibility_index_template_override = "" - model_TransitAgency.save() - - assert model_TransitAgency.index_template == f"core/index--{model_TransitAgency.slug}.html" - assert model_TransitAgency.eligibility_index_template == f"eligibility/index--{model_TransitAgency.slug}.html" - - -@pytest.mark.django_db -def test_TransitAgency_index_url(model_TransitAgency): - result = model_TransitAgency.index_url - - assert result.endswith(model_TransitAgency.slug) - - -@pytest.mark.django_db -def test_TransitAgency_by_id_matching(model_TransitAgency): - result = TransitAgency.by_id(model_TransitAgency.id) - - assert result == model_TransitAgency - - -@pytest.mark.django_db -def test_TransitAgency_by_id_nonmatching(): - with pytest.raises(TransitAgency.DoesNotExist): - TransitAgency.by_id(99999) - - -@pytest.mark.django_db -def test_TransitAgency_by_slug_matching(model_TransitAgency): - result = TransitAgency.by_slug(model_TransitAgency.slug) - - assert result == model_TransitAgency - - -@pytest.mark.django_db -def test_TransitAgency_by_slug_nonmatching(): - result = TransitAgency.by_slug("nope") - - assert not result - - -@pytest.mark.django_db -def test_TransitAgency_all_active(model_TransitAgency): - count = TransitAgency.objects.count() - assert count >= 1 - - inactive_agency = TransitAgency.by_id(model_TransitAgency.id) - inactive_agency.pk = None - inactive_agency.active = False - inactive_agency.save() - - assert TransitAgency.objects.count() == count + 1 - - result = TransitAgency.all_active() - - assert len(result) > 0 - assert model_TransitAgency in result - assert inactive_agency not in result - - -@pytest.mark.django_db -def test_TransitAgency_for_user_in_group(model_TransitAgency): - group = Group.objects.create(name="test_group") - - agency_for_user = TransitAgency.by_id(model_TransitAgency.id) - agency_for_user.pk = None - agency_for_user.staff_group = group - agency_for_user.save() - - user = User.objects.create_user(username="test_user", email="test_user@example.com", password="test", is_staff=True) - user.groups.add(group) - - assert TransitAgency.for_user(user) == agency_for_user - - -@pytest.mark.django_db -def test_TransitAgency_for_user_not_in_group(model_TransitAgency): - group = Group.objects.create(name="test_group") - - agency_for_user = TransitAgency.by_id(model_TransitAgency.id) - agency_for_user.pk = None - agency_for_user.staff_group = group - agency_for_user.save() - - user = User.objects.create_user(username="test_user", email="test_user@example.com", password="test", is_staff=True) - - assert TransitAgency.for_user(user) is None - - -@pytest.mark.django_db -def test_TransitAgency_for_user_in_group_not_linked_to_any_agency(): - group = Group.objects.create(name="another test group") - - user = User.objects.create_user(username="test_user", email="test_user@example.com", password="test", is_staff=True) - user.groups.add(group) - - assert TransitAgency.for_user(user) is None - - -@pytest.mark.django_db -def test_agency_logo_small(model_TransitAgency): - assert agency_logo_small(model_TransitAgency, "local_filename.png") == "agencies/test-sm.png" - - -@pytest.mark.django_db -def test_agency_logo_large(model_TransitAgency): - assert agency_logo_large(model_TransitAgency, "local_filename.png") == "agencies/test-lg.png" - - -@pytest.mark.django_db -def test_TransitAgency_clean(model_TransitAgency_inactive, model_TransitProcessor): - model_TransitAgency_inactive.transit_processor = model_TransitProcessor - - model_TransitAgency_inactive.short_name = "" - model_TransitAgency_inactive.long_name = "" - model_TransitAgency_inactive.phone = "" - model_TransitAgency_inactive.info_url = "" - model_TransitAgency_inactive.logo_large = "" - model_TransitAgency_inactive.logo_small = "" - model_TransitAgency_inactive.transit_processor_audience = "" - model_TransitAgency_inactive.transit_processor_client_id = "" - model_TransitAgency_inactive.transit_processor_client_secret_name = "" - # agency is inactive, OK to have incomplete fields - model_TransitAgency_inactive.clean() - - # now mark it active and expect failure on clean() - model_TransitAgency_inactive.active = True - with pytest.raises(ValidationError) as e: - model_TransitAgency_inactive.clean() - - errors = e.value.error_dict - - assert "short_name" in errors - assert "long_name" in errors - assert "phone" in errors - assert "info_url" in errors - assert "logo_large" in errors - assert "logo_small" in errors - assert "transit_processor_audience" in errors - assert "transit_processor_client_id" in errors - assert "transit_processor_client_secret_name" in errors - - -@pytest.mark.django_db -@pytest.mark.parametrize("template_attribute", ["index_template_override", "eligibility_index_template_override"]) -def test_TransitAgency_clean_templates(model_TransitAgency_inactive, template_attribute): - setattr(model_TransitAgency_inactive, template_attribute, "does/not/exist.html") - # agency is inactive, OK to have missing template - model_TransitAgency_inactive.clean() - - # now mark it active and expect failure on clean() - model_TransitAgency_inactive.active = True - with pytest.raises(ValidationError, match="Template not found: does/not/exist.html"): - model_TransitAgency_inactive.clean() - - -@pytest.mark.django_db -def test_TransitAgency_clean_dirty_flow(model_TransitAgency, model_EnrollmentFlow, model_ClaimsProvider): - # partially setup the EnrollmentFlow - # missing scope and claims - model_EnrollmentFlow.claims_provider = model_ClaimsProvider - model_EnrollmentFlow.transit_agency = model_TransitAgency - - # clean the agency, and expect an invalid EnrollmentFlow error - with pytest.raises(ValidationError, match=f"Invalid EnrollmentFlow: {model_EnrollmentFlow.label}"): - model_TransitAgency.clean() - - @pytest.mark.django_db def test_EnrollmentEvent_create(model_TransitAgency, model_EnrollmentFlow): ts = timezone.now() diff --git a/tests/pytest/core/models/test_transit.py b/tests/pytest/core/models/test_transit.py new file mode 100644 index 0000000000..75114552de --- /dev/null +++ b/tests/pytest/core/models/test_transit.py @@ -0,0 +1,206 @@ +from django.contrib.auth.models import Group, User +from django.core.exceptions import ValidationError + +import pytest + +from benefits.core.models import TransitAgency, agency_logo_small, agency_logo_large + + +@pytest.mark.django_db +def test_TransitProcessor_str(model_TransitProcessor): + assert str(model_TransitProcessor) == model_TransitProcessor.name + + +@pytest.mark.django_db +def test_TransitAgency_defaults(): + agency = TransitAgency.objects.create(slug="test") + + assert agency.active is False + assert agency.slug == "test" + assert agency.short_name == "" + assert agency.long_name == "" + assert agency.phone == "" + assert agency.info_url == "" + assert agency.logo_large == "" + assert agency.logo_small == "" + # test fails if save fails + agency.save() + + +@pytest.mark.django_db +def test_TransitAgency_str(model_TransitAgency): + assert str(model_TransitAgency) == model_TransitAgency.long_name + + +@pytest.mark.django_db +def test_TransitAgency_template_overrides(model_TransitAgency): + assert model_TransitAgency.index_template == model_TransitAgency.index_template_override + assert model_TransitAgency.eligibility_index_template == model_TransitAgency.eligibility_index_template_override + + model_TransitAgency.index_template_override = "" + model_TransitAgency.eligibility_index_template_override = "" + model_TransitAgency.save() + + assert model_TransitAgency.index_template == f"core/index--{model_TransitAgency.slug}.html" + assert model_TransitAgency.eligibility_index_template == f"eligibility/index--{model_TransitAgency.slug}.html" + + +@pytest.mark.django_db +def test_TransitAgency_index_url(model_TransitAgency): + result = model_TransitAgency.index_url + + assert result.endswith(model_TransitAgency.slug) + + +@pytest.mark.django_db +def test_TransitAgency_by_id_matching(model_TransitAgency): + result = TransitAgency.by_id(model_TransitAgency.id) + + assert result == model_TransitAgency + + +@pytest.mark.django_db +def test_TransitAgency_by_id_nonmatching(): + with pytest.raises(TransitAgency.DoesNotExist): + TransitAgency.by_id(99999) + + +@pytest.mark.django_db +def test_TransitAgency_by_slug_matching(model_TransitAgency): + result = TransitAgency.by_slug(model_TransitAgency.slug) + + assert result == model_TransitAgency + + +@pytest.mark.django_db +def test_TransitAgency_by_slug_nonmatching(): + result = TransitAgency.by_slug("nope") + + assert not result + + +@pytest.mark.django_db +def test_TransitAgency_all_active(model_TransitAgency): + count = TransitAgency.objects.count() + assert count >= 1 + + inactive_agency = TransitAgency.by_id(model_TransitAgency.id) + inactive_agency.pk = None + inactive_agency.active = False + inactive_agency.save() + + assert TransitAgency.objects.count() == count + 1 + + result = TransitAgency.all_active() + + assert len(result) > 0 + assert model_TransitAgency in result + assert inactive_agency not in result + + +@pytest.mark.django_db +def test_TransitAgency_for_user_in_group(model_TransitAgency): + group = Group.objects.create(name="test_group") + + agency_for_user = TransitAgency.by_id(model_TransitAgency.id) + agency_for_user.pk = None + agency_for_user.staff_group = group + agency_for_user.save() + + user = User.objects.create_user(username="test_user", email="test_user@example.com", password="test", is_staff=True) + user.groups.add(group) + + assert TransitAgency.for_user(user) == agency_for_user + + +@pytest.mark.django_db +def test_TransitAgency_for_user_not_in_group(model_TransitAgency): + group = Group.objects.create(name="test_group") + + agency_for_user = TransitAgency.by_id(model_TransitAgency.id) + agency_for_user.pk = None + agency_for_user.staff_group = group + agency_for_user.save() + + user = User.objects.create_user(username="test_user", email="test_user@example.com", password="test", is_staff=True) + + assert TransitAgency.for_user(user) is None + + +@pytest.mark.django_db +def test_TransitAgency_for_user_in_group_not_linked_to_any_agency(): + group = Group.objects.create(name="another test group") + + user = User.objects.create_user(username="test_user", email="test_user@example.com", password="test", is_staff=True) + user.groups.add(group) + + assert TransitAgency.for_user(user) is None + + +@pytest.mark.django_db +def test_agency_logo_small(model_TransitAgency): + assert agency_logo_small(model_TransitAgency, "local_filename.png") == "agencies/test-sm.png" + + +@pytest.mark.django_db +def test_agency_logo_large(model_TransitAgency): + assert agency_logo_large(model_TransitAgency, "local_filename.png") == "agencies/test-lg.png" + + +@pytest.mark.django_db +def test_TransitAgency_clean(model_TransitAgency_inactive, model_TransitProcessor): + model_TransitAgency_inactive.transit_processor = model_TransitProcessor + + model_TransitAgency_inactive.short_name = "" + model_TransitAgency_inactive.long_name = "" + model_TransitAgency_inactive.phone = "" + model_TransitAgency_inactive.info_url = "" + model_TransitAgency_inactive.logo_large = "" + model_TransitAgency_inactive.logo_small = "" + model_TransitAgency_inactive.transit_processor_audience = "" + model_TransitAgency_inactive.transit_processor_client_id = "" + model_TransitAgency_inactive.transit_processor_client_secret_name = "" + # agency is inactive, OK to have incomplete fields + model_TransitAgency_inactive.clean() + + # now mark it active and expect failure on clean() + model_TransitAgency_inactive.active = True + with pytest.raises(ValidationError) as e: + model_TransitAgency_inactive.clean() + + errors = e.value.error_dict + + assert "short_name" in errors + assert "long_name" in errors + assert "phone" in errors + assert "info_url" in errors + assert "logo_large" in errors + assert "logo_small" in errors + assert "transit_processor_audience" in errors + assert "transit_processor_client_id" in errors + assert "transit_processor_client_secret_name" in errors + + +@pytest.mark.django_db +@pytest.mark.parametrize("template_attribute", ["index_template_override", "eligibility_index_template_override"]) +def test_TransitAgency_clean_templates(model_TransitAgency_inactive, template_attribute): + setattr(model_TransitAgency_inactive, template_attribute, "does/not/exist.html") + # agency is inactive, OK to have missing template + model_TransitAgency_inactive.clean() + + # now mark it active and expect failure on clean() + model_TransitAgency_inactive.active = True + with pytest.raises(ValidationError, match="Template not found: does/not/exist.html"): + model_TransitAgency_inactive.clean() + + +@pytest.mark.django_db +def test_TransitAgency_clean_dirty_flow(model_TransitAgency, model_EnrollmentFlow, model_ClaimsProvider): + # partially setup the EnrollmentFlow + # missing scope and claims + model_EnrollmentFlow.claims_provider = model_ClaimsProvider + model_EnrollmentFlow.transit_agency = model_TransitAgency + + # clean the agency, and expect an invalid EnrollmentFlow error + with pytest.raises(ValidationError, match=f"Invalid EnrollmentFlow: {model_EnrollmentFlow.label}"): + model_TransitAgency.clean()