diff --git a/app/helpers/utils.py b/app/helpers/utils.py index 7bacad1..2741a7b 100644 --- a/app/helpers/utils.py +++ b/app/helpers/utils.py @@ -18,6 +18,7 @@ from flask import request from app.settings import KML_FILE_CONTENT_TYPE +from app.settings import KML_MAX_SIZE from app.settings import KML_STORAGE_HOST_URL logger = logging.getLogger(__name__) @@ -78,24 +79,14 @@ def wrapped(*args, **kwargs): return inner_decorator -def validate_content_length(max_length): - - def inner_decorator(func): - - @wraps(func) - def wrapped(*args, **kwargs): - if request.content_length > max_length: - logger.error( - 'Payload too large: payload=%s MB, max_allowed=%s MB', - bytes_conversion(request.content_length, 'MB'), - bytes_conversion(max_length, 'MB'), - ) - abort(413, "Payload too large") - return func(*args, **kwargs) - - return wrapped - - return inner_decorator +def validate_content_length(file_content, max_length): + if len(file_content) > max_length: + logger.error( + 'Payload too large: payload=%s MB, max_allowed=%s MB', + bytes_conversion(request.content_length, 'MB'), + bytes_conversion(max_length, 'MB'), + ) + abort(413, "KML file too large") def validate_kml_string(kml_string): @@ -147,7 +138,9 @@ def validate_kml_file(): KML_FILE_CONTENT_TYPE ) abort(415, "Unsupported KML media type") - file_content = decompress_if_gzipped(file) + file_content = file.read() + validate_content_length(file_content, KML_MAX_SIZE) + file_content = decompress_if_gzipped(file_content) try: if 'charset' in file.mimetype_params: quoted_data = file_content.decode(file.mimetype_params['charset']) @@ -196,10 +189,9 @@ def gzip_string(string): return gzipped_data -def decompress_if_gzipped(file): +def decompress_if_gzipped(file_content): '''Returns the file content as bytes object, after unzipping the file if necessary''' - file_content = file.read() try: ret = gzip.decompress(file_content) except OSError as error: diff --git a/app/routes.py b/app/routes.py index 97c94a8..d8cf7d8 100644 --- a/app/routes.py +++ b/app/routes.py @@ -14,11 +14,9 @@ from app.helpers.dynamodb import get_db from app.helpers.s3 import get_storage from app.helpers.utils import get_kml_file_link -from app.helpers.utils import validate_content_length from app.helpers.utils import validate_content_type from app.helpers.utils import validate_kml_file from app.helpers.utils import validate_permissions -from app.settings import KML_MAX_SIZE from app.settings import SCRIPT_NAME from app.version import APP_VERSION @@ -35,7 +33,6 @@ def checker(): @app.route('/admin', methods=['POST']) @validate_content_type("multipart/form-data") -@validate_content_length(KML_MAX_SIZE) def create_kml(): # Get the kml file data kml_string_gzip, empty = validate_kml_file() @@ -124,7 +121,6 @@ def get_kml_metadata(kml_id): @app.route('/admin/', methods=['PUT']) @validate_content_type("multipart/form-data") -@validate_content_length(KML_MAX_SIZE) def update_kml(kml_id): db = get_db() diff --git a/tests/unit_tests/base.py b/tests/unit_tests/base.py index b380d8c..215741d 100644 --- a/tests/unit_tests/base.py +++ b/tests/unit_tests/base.py @@ -198,7 +198,8 @@ def assertKml(self, response, expected_kml_file): expected_kml_path = f'./tests/samples/{expected_kml_file}' # read the expected kml file with open(expected_kml_path, 'rb') as fd: - expected_kml = decompress_if_gzipped(fd).decode('utf-8') + content = fd.read() + expected_kml = decompress_if_gzipped(content).decode('utf-8') kml_id = response.json['id'] item = self.dynamodb.Table(AWS_DB_TABLE_NAME).get_item(Key={ 'kml_id': kml_id @@ -230,7 +231,7 @@ def assertKml(self, response, expected_kml_file): else: self.fail(f'S3 client error: {error}') - body = decompress_if_gzipped((obj['Body'])) + body = decompress_if_gzipped((obj['Body'].read())) self.assertEqual(body.decode('utf-8'), expected_kml) def get_s3_object(self, file_key): diff --git a/tests/unit_tests/test_routes.py b/tests/unit_tests/test_routes.py index cb05cee..e71e2ce 100644 --- a/tests/unit_tests/test_routes.py +++ b/tests/unit_tests/test_routes.py @@ -3,6 +3,7 @@ import logging import uuid from datetime import timedelta +from unittest.mock import patch from flask import url_for @@ -49,6 +50,19 @@ def test_valid_gzipped_kml_post(self): self.assertEqual(response.content_type, "application/json") # pylint: disable=no-member self.assertKml(response, kml_file) + @patch('app.helpers.utils.KML_MAX_SIZE', 10) + def test_too_big_kml_post(self): + kml_file = 'valid-kml.xml' + response = self.app.post( + url_for('create_kml'), + data=prepare_kml_payload(kml_file=kml_file), + content_type="multipart/form-data", + headers=self.origin_headers["allowed"] + ) + self.assertEqual(response.status_code, 413) + self.assertCors(response, ['GET', 'HEAD', 'POST', 'OPTIONS']) + self.assertEqual(response.content_type, "application/json") # pylint: disable=no-member + def test_invalid_kml_post(self): response = self.app.post( url_for('create_kml'),