diff --git a/src/analytics/api/models/events.py b/src/analytics/api/models/events.py index f378cc2e68..f527889cc8 100644 --- a/src/analytics/api/models/events.py +++ b/src/analytics/api/models/events.py @@ -44,9 +44,8 @@ def __init__(self, tenant): @cache.memoize() def download_from_bigquery( cls, - devices, - sites, - airqlouds, + filter_type, # Either 'devices', 'sites', or 'airqlouds' + filter_value, # The actual list of values for the filter type start_date, end_date, frequency, @@ -63,14 +62,15 @@ def download_from_bigquery( sorting_cols = ["site_id", "datetime", "device_name"] - if frequency == "raw": - data_table = cls.BIGQUERY_RAW_DATA - elif frequency == "daily": - data_table = cls.BIGQUERY_DAILY_DATA - elif frequency == "hourly": - data_table = cls.BIGQUERY_HOURLY_DATA - else: - raise Exception("Invalid frequency") + # Define table mapping for dynamic selection based on frequency + data_table = { + "raw": cls.BIGQUERY_RAW_DATA, + "daily": cls.BIGQUERY_DAILY_DATA, + "hourly": cls.BIGQUERY_HOURLY_DATA, + }.get(frequency) + + if not data_table: + raise ValueError("Invalid frequency") pollutant_columns = [] bam_pollutant_columns = [] @@ -116,7 +116,7 @@ def download_from_bigquery( f" FORMAT_DATETIME('%Y-%m-%d %H:%M:%S', {cls.BIGQUERY_BAM_DATA}.timestamp) AS datetime " ) - if len(devices) != 0: + if filter_type == "devices": # Adding device information, start and end times query = ( f" {pollutants_query} , " @@ -129,7 +129,7 @@ def download_from_bigquery( f" JOIN {devices_table} ON {devices_table}.device_id = {data_table}.device_id " f" WHERE {data_table}.timestamp >= '{start_date}' " f" AND {data_table}.timestamp <= '{end_date}' " - f" AND {devices_table}.device_id IN UNNEST({devices}) " + f" AND {devices_table}.device_id IN UNNEST(@filter_value) " ) bam_query = ( @@ -143,7 +143,7 @@ def download_from_bigquery( f" JOIN {devices_table} ON {devices_table}.device_id = {cls.BIGQUERY_BAM_DATA}.device_id " f" WHERE {cls.BIGQUERY_BAM_DATA}.timestamp >= '{start_date}' " f" AND {cls.BIGQUERY_BAM_DATA}.timestamp <= '{end_date}' " - f" AND {devices_table}.device_id IN UNNEST({devices}) " + f" AND {devices_table}.device_id IN UNNEST(@filter_value) " ) # Adding site information @@ -170,7 +170,7 @@ def download_from_bigquery( if frequency == "hourly": query = f"{query} UNION ALL {bam_query}" - elif len(sites) != 0: + elif filter_type == "sites": # Adding site information, start and end times query = ( f" {pollutants_query} , " @@ -184,7 +184,7 @@ def download_from_bigquery( f" JOIN {sites_table} ON {sites_table}.id = {data_table}.site_id " f" WHERE {data_table}.timestamp >= '{start_date}' " f" AND {data_table}.timestamp <= '{end_date}' " - f" AND {sites_table}.id IN UNNEST({sites}) " + f" AND {sites_table}.id IN UNNEST(@filter_value) " ) # Adding device information @@ -197,7 +197,7 @@ def download_from_bigquery( f" FROM {devices_table} " f" RIGHT JOIN ({query}) data ON data.device_name = {devices_table}.device_id " ) - else: + elif filter_type == "airqlouds": sorting_cols = ["airqloud_id", "site_id", "datetime", "device_name"] meta_data_query = ( @@ -205,7 +205,7 @@ def download_from_bigquery( f" {airqlouds_sites_table}.airqloud_id , " f" {airqlouds_sites_table}.site_id , " f" FROM {airqlouds_sites_table} " - f" WHERE {airqlouds_sites_table}.airqloud_id IN UNNEST({airqlouds}) " + f" WHERE {airqlouds_sites_table}.airqloud_id IN UNNEST(@filter_value) " ) # Adding airqloud information @@ -250,7 +250,11 @@ def download_from_bigquery( f" ORDER BY {data_table}.timestamp " ) - job_config = bigquery.QueryJobConfig() + job_config = bigquery.QueryJobConfig( + query_parameters=[ + bigquery.ArrayQueryParameter("filter_value", "STRING", filter_value), + ] + ) job_config.use_query_cache = True dataframe = ( bigquery.Client() diff --git a/src/analytics/api/utils/data_formatters.py b/src/analytics/api/utils/data_formatters.py index 3d281b8869..e609953e5e 100644 --- a/src/analytics/api/utils/data_formatters.py +++ b/src/analytics/api/utils/data_formatters.py @@ -320,7 +320,7 @@ def filter_non_private_sites(sites: List[str]) -> Dict[str, Any]: else: raise RuntimeError(response.get("message")) except RuntimeError as rex: - logger.exception(f"Error while filtering non private entities {rex}") + raise RuntimeError(f"Error while filtering non private sites {rex}") def filter_non_private_devices(devices: List[str]) -> Dict[str, Any]: @@ -348,5 +348,4 @@ def filter_non_private_devices(devices: List[str]) -> Dict[str, Any]: else: raise RuntimeError(response.get("message")) except RuntimeError as rex: - logger.exception(f"Error while filtering non private devices {rex}") - return {} + raise RuntimeError(f"Error while filtering non private devices {rex}") diff --git a/src/analytics/api/views/data.py b/src/analytics/api/views/data.py index e8848efc75..528bb228e4 100644 --- a/src/analytics/api/views/data.py +++ b/src/analytics/api/views/data.py @@ -1,5 +1,6 @@ import datetime import traceback +import logging import flask_excel as excel import pandas as pd @@ -33,6 +34,8 @@ from api.utils.request_validators import validate_request_json, validate_request_params from main import rest_api_v2 +logger = logging.getLogger(__name__) + @rest_api_v2.errorhandler(ExportRequestNotFound) def batch_not_found_exception(error): @@ -86,47 +89,42 @@ def post(self): start_date = json_data["startDateTime"] end_date = json_data["endDateTime"] - sites = filter_non_private_sites(sites=json_data.get("sites", {})).get( - "sites", [] - ) - devices = filter_non_private_devices(devices=json_data.get("devices", {})).get( - "devices", [] - ) - airqlouds = json_data.get("airqlouds", []) - weather_fields = json_data.get("weatherFields", None) - minimum_output = json_data.get("minimum", True) - frequency = self._get_valid_option( - json_data.get("frequency"), valid_options["frequencies"] - ) - download_type = self._get_valid_option( - json_data.get("downloadType"), valid_options["download_types"] - ) - output_format = self._get_valid_option( - json_data.get("outputFormat"), valid_options["output_formats"] - ) - data_type = self._get_valid_option( - json_data.get("datatype"), valid_options["data_types"] - ) - pollutants = json_data.get("pollutants", valid_options["pollutants"]) - if sum([len(sites) == 0, len(devices) == 0, len(airqlouds) == 0]) == 3: + try: + filter_type, filter_value = self._get_validated_filter(json_data) + except ValueError as e: return ( - AirQoRequests.create_response( - f"Specify either a list of airqlouds, sites or devices in the request body", - success=False, - ), + AirQoRequests.create_response(f"An error occured: {e}", success=False), AirQoRequests.Status.HTTP_400_BAD_REQUEST, ) - if sum([len(sites) != 0, len(devices) != 0, len(airqlouds) != 0]) != 1: + try: + frequency = self._get_valid_option( + json_data.get("frequency"), valid_options["frequencies"], "frequency" + ) + download_type = self._get_valid_option( + json_data.get("downloadType"), + valid_options["download_types"], + "downloadType", + ) + output_format = self._get_valid_option( + json_data.get("outputFormat"), + valid_options["output_formats"], + "outputFormat", + ) + data_type = self._get_valid_option( + json_data.get("datatype"), valid_options["data_types"], "datatype" + ) + except ValueError as e: return ( - AirQoRequests.create_response( - f"You cannot specify airqlouds, sites and devices in one go", - success=False, - ), + AirQoRequests.create_response(f"An error occured: {e}", success=False), AirQoRequests.Status.HTTP_400_BAD_REQUEST, ) + pollutants = json_data.get("pollutants", valid_options["pollutants"]) + weather_fields = json_data.get("weatherFields", None) + minimum_output = json_data.get("minimum", True) + if not all(p in valid_options["pollutants"] for p in pollutants): return ( AirQoRequests.create_response( @@ -140,9 +138,8 @@ def post(self): try: data_frame = EventsModel.download_from_bigquery( - sites=sites, - devices=devices, - airqlouds=airqlouds, + filter_type=filter_type, # Pass one filter[sites, airqlouds, devices] that has been passed in the api query + filter_value=filter_value, start_date=start_date, end_date=end_date, frequency=frequency, @@ -196,22 +193,67 @@ def post(self): AirQoRequests.Status.HTTP_500_INTERNAL_SERVER_ERROR, ) - def _get_valid_option(self, option, valid_options): + def _get_validated_filter(self, json_data): + """ + Ensures that only one of 'airqlouds', 'sites', or 'devices' is provided in the request. + Calls filter_non_private_* only after confirming exclusivity. + + Args: + json_data (dict): JSON payload from the request. + + Returns: + tuple: The name of the filter ("sites", "devices", or "airqlouds") and its validated value if valid. + + Raises: + ValueError: If more than one or none of the filters are provided. + """ + provided_filters = [ + key for key in ["sites", "devices", "airqlouds"] if json_data.get(key) + ] + + if len(provided_filters) != 1: + raise ValueError( + "Specify exactly one of 'airqlouds', 'sites', or 'devices' in the request body." + ) + + filter_type = provided_filters[0] + filter_value = json_data.get(filter_type) + + if filter_type == "sites": + validated_value = filter_non_private_sites(sites=filter_value).get( + "sites", [] + ) + elif filter_type == "devices": + validated_value = filter_non_private_devices(devices=filter_value).get( + "devices", [] + ) + else: + # No additional processing is needed for 'airqlouds' + validated_value = filter_value + + return filter_type, validated_value + + def _get_valid_option(self, option, valid_options, option_name): """ - Returns a validated option, defaulting to the first valid option if not provided or invalid. + Returns a validated option, raising an error with valid options if invalid. Args: option (str): Option provided in the request. valid_options (list): List of valid options. + option_name (str): The name of the option being validated. Returns: str: A validated option from the list. + + Raises: + ValueError: If the provided option is invalid. """ - return ( - option.lower() - if option and option.lower() in valid_options - else valid_options[0] - ) + if option and option.lower() in valid_options: + return option.lower() + if option: + raise ValueError( + f"Invalid {option_name}. Valid values are: {', '.join(valid_options)}." + ) @rest_api_v2.route("/data-export") diff --git a/src/analytics/requirements.txt b/src/analytics/requirements.txt index e3f5c21b85..f1e5911031 100644 --- a/src/analytics/requirements.txt +++ b/src/analytics/requirements.txt @@ -18,6 +18,7 @@ python-decouple celery google-cloud-storage gunicorn +google-cloud-bigquery-storage==2.27.0 # Ports for stable python 3 functionality dataclasses~=0.6