From 25f4ee470b847ff388dd454c5e1b3c88e3e7edf4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Farcy?= Date: Thu, 8 Jun 2023 11:21:21 +0200 Subject: [PATCH] Refactoring source reporting - Updating & saving the SourceReport instance only in refresh_data method - Use exception in lower methods to handle errors & to raise them to upper method (refresh_data) --- project/geosource/models.py | 131 +++++++++--------- project/geosource/serializers.py | 4 +- project/geosource/tasks.py | 2 +- project/geosource/tests/test_api.py | 6 +- .../geosource/tests/test_model_exceptions.py | 32 ++--- project/geosource/tests/test_models.py | 50 ++++--- project/geosource/tests/test_tasks.py | 61 +++++++- 7 files changed, 168 insertions(+), 118 deletions(-) diff --git a/project/geosource/models.py b/project/geosource/models.py index 8899ea18..b8488f31 100644 --- a/project/geosource/models.py +++ b/project/geosource/models.py @@ -28,6 +28,7 @@ from .mixins import CeleryCallMethodsMixin from .signals import refresh_data_done + # Decimal fields must be returned as float DEC2FLOAT = psycopg2.extensions.new_type( psycopg2.extensions.DECIMAL.values, @@ -63,6 +64,19 @@ def get_type_from_data(cls, data): return types.get(type(data), cls.Undefined) +class SourceException(Exception): + """Generic source exception to be catched by the generic Source Model""" + + def __init__(self, message): + self.message = message + + +class CSVSourceException(SourceException): + """CSVSource exception raised by the CSVSource model""" + + pass + + class SourceReporting(models.Model): class Status(models.IntegerChoices): SUCCESS = 0, "Success" @@ -161,6 +175,12 @@ def refresh_data(self): es_index.index() response = self._refresh_data(es_index) return response + except SourceException as exc: + self.report.status = self.report.Status.ERROR.value + self.report.message = exc.message + self.report.ended = timezone.now() + self.report.save(update_fields=["status", "message", "ended"]) + finally: self.last_refresh = timezone.now() self.status = self.Status.DONE.value @@ -184,7 +204,9 @@ def _refresh_data(self, es_index=None): added_rows = 0 modified_rows = 0 total = 0 - for i, row in enumerate(self._get_records()): + records, records_errors = self._get_records() + self.report.errors += records_errors + for i, row in enumerate(records): total += 1 geometry = row.pop(self.SOURCE_GEOM_ATTRIBUTE) try: @@ -204,15 +226,12 @@ def _refresh_data(self, es_index=None): except Exception: transaction.savepoint_rollback(sid) msg = "An error occured on feature(s)" - self.report.status = SourceReporting.Status.WARNING.value self.report.errors.append( f"Row n°{i}: {msg}. {self.id_field} - {identifier}" ) continue except KeyError: - msg = "Can't find identifier field for this record" - self.report.status = SourceReporting.Status.WARNING.value - self.report.errors.append(f"Line {i}: {msg}") + self.report.errors.append(f"Line {i}: Can't find identifier field for this record") continue row_count += 1 deleted, _ = self.clear_features(layer, begin_date) @@ -223,15 +242,12 @@ def _refresh_data(self, es_index=None): if not row_count: self.report.status = SourceReporting.Status.ERROR.value self.report.message = "Failed to refresh data" - # Pending is added at the start of a refresh, so if not pending, it has to be Warning or Error - # From _get_records(), called earlier - elif ( - row_count == total - and self.report.status == SourceReporting.Status.PENDING.value - ): + elif row_count == total: self.report.status = SourceReporting.Status.SUCCESS.value + self.report.message = "Source refreshed successfully" else: self.report.status = SourceReporting.Status.WARNING.value + self.report.message = "Source refreshed partially" if self.id: self.report.ended = timezone.now() self.report.save() @@ -239,7 +255,7 @@ def _refresh_data(self, es_index=None): @transaction.atomic def update_fields(self): - records = self._get_records(50) + records, _ = self._get_records(50) fields = {} @@ -401,16 +417,12 @@ def get_file_as_dict(self): raise def _get_records(self, limit=None): - if not self.report: - self.report = SourceReporting( - started=timezone.now(), status=SourceReporting.Status.PENDING.value - ) - geojson = self.get_file_as_dict() limit = limit if limit else len(geojson["features"]) records = [] + errors = [] for i, record in enumerate(geojson["features"][:limit]): try: records.append( @@ -422,13 +434,10 @@ def _get_records(self, limit=None): } ) except (ValueError, GDALException): - msg = "The record geometry seems invalid." - self.report.status = SourceReporting.Status.WARNING.value - self.report.errors.append(f"Line {i}: {msg}") - if self.id: - self.report.ended = timezone.now() - self.report.save() - return records + feature_id = record.get("properties", {}).get("id", i) + msg = f"The record geometry seems invalid for feature {feature_id}." + errors.append(msg) + return (records, errors) class ShapefileSource(Source): @@ -443,7 +452,7 @@ def _get_records(self, limit=None): _, srid = shapefile.crs.get("init", "epsg:4326").split(":") # Return geometries with a hack to set the correct geometry srid - return [ + records = [ { self.SOURCE_GEOM_ATTRIBUTE: GEOSGeometry( GEOSGeometry(json.dumps(feature.get("geometry"))).wkt, @@ -453,6 +462,8 @@ def _get_records(self, limit=None): } for feature in shapefile[:limit] ] + # No errors catched for Shapefile + return (records, []) class CommandSource(Source): @@ -482,7 +493,7 @@ def _refresh_data(self, es_index=None): return {"count": layer.features.count()} def _get_records(self, limit=None): - return [] + return [None, None] class WMTSSource(Source): @@ -498,7 +509,7 @@ def refresh_data(self): return {} def _get_records(self, limit=None): - return [] + return [None, None] class CSVSource(Source): @@ -536,10 +547,6 @@ def get_file_as_sheet(self): raise def _get_records(self, limit=None): - # _get_records is used in the serializer to validate data - if not self.id or not self.report: - self.report = SourceReporting(started=timezone.now()) - sheet = self.get_file_as_sheet() if self.settings.get("use_header"): sheet.name_columns_by_row(0) @@ -551,18 +558,22 @@ def _get_records(self, limit=None): limit = limit if limit else len(sheet) records = [] + errors = [] srid = self._get_srid() - row_count = 0 - total = 0 + for i, row in enumerate(sheet): - total += 1 if self.settings["coordinates_field"] == "two_columns": lat_field = self.settings["latitude_field"] lng_field = self.settings["longitude_field"] - x, y = self._extract_coordinates( - row, sheet.colnames, [lng_field, lat_field] - ) + try: + x, y = self._extract_coordinates( + row, sheet.colnames, [lng_field, lat_field] + ) + except CSVSourceException as e: + errors.append(f"Sheet row {i} - {e.message}") + continue + ignored_field = (row.index(x), row.index(y), *ignored_columns) else: lnglat_field = self.settings["latlong_field"] @@ -570,8 +581,10 @@ def _get_records(self, limit=None): x, y = self._extract_coordinates( row, sheet.colnames, [lnglat_field] ) - except ValueError: + except CSVSourceException as e: + errors.append(f"Sheet row {i} - {e.message}") continue + coord_fields = ( (sheet.colnames.index(lnglat_field),) if self.settings.get("use_header") @@ -590,23 +603,15 @@ def _get_records(self, limit=None): } ) except (ValueError, GDALException): - msg = f"One of source's record has invalid geometry: Point({x} {y}) srid={srid}" - self.report.status = SourceReporting.Status.WARNING.value - self.report.errors.append(f"Line {i}: {msg}") + errors.append( + f"Sheet row {i} - One of source's record has invalid geometry: Point({x} {y}) srid={srid}" + ) continue - row_count += 1 - if not row_count: - self.report.status = SourceReporting.Status.ERROR.value - self.report.message = "No record could be imported, check the report" - elif row_count == total: - self.report.status = SourceReporting.Status.SUCCESS.value - if self.id: - self.report.save() - return records + return (records, errors) - def _extract_coordinates(self, row, colnames, fields): + def _extract_coordinates(self, row, colnames, coord_fields): coords = [] - for field in fields: + for field in coord_fields: # if no header, we expect index for the columns has been provided try: field_index = ( @@ -614,15 +619,9 @@ def _extract_coordinates(self, row, colnames, fields): if self.settings.get("use_header") else int(field) ) - except ValueError as err: - msg = f"{field} is not a valid coordinate field" - if not self.report: - self.report = SourceReporting(started=timezone.now()) - self.report.status = SourceReporting.Status.WARNING.value - self.report.errors.append(msg) - self.report.save() - err.args = (msg,) - raise + except ValueError: + raise CSVSourceException(f"{field} is not a valid coordinate field") + c = row[field_index] coords.append(c) if len(coords) == 2: @@ -674,14 +673,8 @@ def _get_srid(self): coordinate_reference_system = self.settings["coordinate_reference_system"] try: return int(coordinate_reference_system.split("_")[1]) - except (IndexError, ValueError) as err: - msg = f"Invalid SRID: {coordinate_reference_system}" - self.report.status = SourceReporting.Status.ERROR.value - self.report.message = msg - # self.report.setdefault("message", []).append(msg) - self.report.save() - err.args = (msg,) - raise + except (IndexError, ValueError): + raise CSVSourceException(f"Invalid SRID: {coordinate_reference_system}") # properties are use by serializer for representation (reading operation) @property diff --git a/project/geosource/serializers.py b/project/geosource/serializers.py index 829aae14..1a5df5c0 100644 --- a/project/geosource/serializers.py +++ b/project/geosource/serializers.py @@ -285,7 +285,7 @@ def _validate_field_infos(self, data): except TypeError: return # file field is empty in update no get_records try: - records = instance._get_records(1) + records, errors = instance._get_records(1) except Exception as err: raise ValidationError(err.args[0]) @@ -439,7 +439,7 @@ def _validate_field_infos(self, data): # create an instance without saving data instance = self.Meta.model(**data_copy) try: - records = instance._get_records(1) + records, errors = instance._get_records(1) except (ValueError, GDALException) as err: raise ValidationError(err.args[0]) diff --git a/project/geosource/tasks.py b/project/geosource/tasks.py index 1dbdb1f3..6a2f8168 100644 --- a/project/geosource/tasks.py +++ b/project/geosource/tasks.py @@ -44,7 +44,7 @@ def run_model_object_method(self, app, model, pk, method, success_state=states.S self.update_state(state=success_state, meta=state) if obj and hasattr(obj, "report"): - if not obj.report: + if obj.report is not None: text = "" for key, value in state.items(): text += f"{key}: {value}," diff --git a/project/geosource/tests/test_api.py b/project/geosource/tests/test_api.py index ed182261..29e076e6 100644 --- a/project/geosource/tests/test_api.py +++ b/project/geosource/tests/test_api.py @@ -267,9 +267,9 @@ def run_sync_method_result(cmd, success_state): @patch( "project.geosource.models.Source._get_records", MagicMock( - return_value=[ + return_value=([ {"a": "b", "c": 42, "d": b"4", "e": b"\xe8", "_geom_": "POINT(0 0)"} - ] + ], []) ), ) def test_update_fields_method(self): @@ -283,7 +283,7 @@ def test_update_fields_method(self): @patch( "project.geosource.models.Source._get_records", - MagicMock(return_value=[{"a": "b", "c": 42, "_geom_": "POINT(0 0)"}]), + MagicMock(return_value=([{"a": "b", "c": 42, "_geom_": "POINT(0 0)"}], [])), ) def test_update_fields_with_delete_method(self): obj = Source.objects.create(geom_type=10) diff --git a/project/geosource/tests/test_model_exceptions.py b/project/geosource/tests/test_model_exceptions.py index d29a6697..44e6782b 100644 --- a/project/geosource/tests/test_model_exceptions.py +++ b/project/geosource/tests/test_model_exceptions.py @@ -14,6 +14,7 @@ PostGISSource, ShapefileSource, SourceReporting, + CSVSourceException, ) from project.geosource.tests.helpers import get_file @@ -47,11 +48,8 @@ def test_csv_with_wrong_x_coord(self, mocked_es_delete, mocked_es_create): "latitude_field": "YCOORDS", }, ) - msg = "X is not a valid coordinate field" - with self.assertRaisesMessage(ValueError, msg): - source._get_records() - self.assertEqual(source.report.status, SourceReporting.Status.Warning.value) - self.assertIn(msg, source.report.get("message", [])) + _, errors = source._get_records() + self.assertIn("Sheet row 0 - X is not a valid coordinate field", errors) def test_csv_with_wrong_y_coord(self, mocked_es_delete, mocked_es_create): source = CSVSource.objects.create( @@ -70,10 +68,8 @@ def test_csv_with_wrong_y_coord(self, mocked_es_delete, mocked_es_create): "latitude_field": "Y", # Wrong on purpose }, ) - msg = "Y is not a valid coordinate field" - with self.assertRaisesMessage(ValueError, msg): - source._get_records() - self.assertIn(msg, source.report.get("message", [])) + _, errors = source._get_records() + self.assertIn("Sheet row 0 - Y is not a valid coordinate field", errors) def test_invalid_csv_file_raise_value_error( self, mocked_es_delete, mocked_es_create @@ -104,7 +100,7 @@ def test_invalid_csv_file_raise_value_error( self.assertIsInstance(source.report, SourceReporting) self.assertIn(msg, source.report.get("message", [])) - def test_invalid_coordinate_format_raise_error( + def test_invalid_coordinate_format_error_handle( self, mocked_es_delete, mocked_es_create ): source = CSVSource.objects.create( @@ -124,8 +120,8 @@ def test_invalid_coordinate_format_raise_error( "coordinates_field_count": "xy", }, ) - source._get_records() - self.assertIn("coordxy is not a valid coordinate field", source.report.errors) + _, errors = source._get_records() + self.assertIn("Sheet row 0 - coordxy is not a valid coordinate field", errors) def test_coordinates_system_without_digit_srid_raise_value_error( self, mocked_es_delete, mocked_es_create @@ -146,7 +142,7 @@ def test_coordinates_system_without_digit_srid_raise_value_error( "latitude_field": "YCOORD", }, ) - with self.assertRaises(ValueError): + with self.assertRaisesMessage(CSVSourceException, "Invalid SRID: EPSG_SRID"): source._get_records() def test_coordinates_systems_malformed_raise_index_error( @@ -168,7 +164,7 @@ def test_coordinates_systems_malformed_raise_index_error( "latitude_field": "YCOORD", }, ) - with self.assertRaises(IndexError): + with self.assertRaisesMessage(CSVSourceException, "Invalid SRID: 4326"): source._get_records() # TODO: Move to test_models.py instead, since no exception is raised anymore @@ -238,9 +234,9 @@ def test_gdal_exception_set_report_to_warning( "latitude_field": "YCOORD", }, ) - source._get_records() - msg = "Line 0: One of source's record has invalid geometry: Point(930077.50743 6922202.67316) srid=4326" - self.assertIn(msg, source.report.errors) + _, errors = source._get_records() + msg = "Sheet row 0 - One of source's record has invalid geometry: Point(930077.50743 6922202.67316) srid=4326" + self.assertIn(msg, errors) @patch("project.geosource.models.pyexcel.get_sheet", side_effect=Exception()) def test_get_file_as_sheet_exception_create_new_report_if_none( @@ -278,7 +274,7 @@ def test_valueerror_raised_in_extract_coordinate_create_report_if_none( id_field="identifier", ) self.assertIsNone(source.report) - with self.assertRaises(ValueError): + with self.assertRaises(CSVSourceException): # put some nonsens data to trigger ValueError raise source._extract_coordinates( ["a", "b", "c"], [1, 2, 3], ["foo", "bar", "foobar"] diff --git a/project/geosource/tests/test_models.py b/project/geosource/tests/test_models.py index 97260214..806ab845 100644 --- a/project/geosource/tests/test_models.py +++ b/project/geosource/tests/test_models.py @@ -208,10 +208,10 @@ def test_get_records_wrong_geom_file(self): ) # with self.assertRaises(ValueError) as m: - source._get_records(1) + records, errors = source._get_records(1) self.assertIn( - "Line 0: The record geometry seems invalid.", - source.report.errors, + "The record geometry seems invalid for feature 1.", + errors, ) @@ -223,7 +223,7 @@ def test_get_records(self): file=get_file("test.zip"), ) - records = source._get_records(1) + records, errors = source._get_records(1) self.assertEqual(records[0]["NOM"], "Trifouilli-les-Oies") self.assertEqual(records[0]["Insee"], 99999) self.assertEqual(records[0]["_geom_"].geom_typeid, GeometryTypes.Polygon) @@ -244,7 +244,7 @@ def test_refresh_data(self, mocked_stdout, mock_index): self.assertIn("Start refresh", mocked_stdout.getvalue()) def test_get_records(self): - self.assertEqual([], self.source._get_records()) + self.assertEqual([None, None], self.source._get_records()) class ModelWMTSSourceTestCase(TestCase): @@ -257,7 +257,7 @@ def setUp(self): ) def test_get_records(self): - self.assertEqual([], self.source._get_records()) + self.assertEqual([None, None], self.source._get_records()) def test_get_status(self): self.assertEqual({"state": "DONT_NEED"}, self.source.get_status()) @@ -297,9 +297,9 @@ def test_get_records_with_two_columns_coordinates( }, ) - records = source._get_records() + records, errors = source._get_records() self.assertEqual(len(records), 6, len(records)) - with self.assertNumQueries(72): + with self.assertNumQueries(70): row_count = source.refresh_data() self.assertEqual(row_count["count"], len(records), row_count) @@ -323,9 +323,9 @@ def test_get_records_with_one_column_coordinates( }, ) - records = source._get_records() + records, errors = source._get_records() self.assertEqual(len(records), 9, len(records)) - with self.assertNumQueries(96): + with self.assertNumQueries(94): row_count = source.refresh_data() self.assertEqual(row_count["count"], len(records), row_count) @@ -353,9 +353,9 @@ def test_get_records_with_decimal_separator_as_comma( "coordinates_field_count": "xy", }, ) - records = source._get_records() + records, errors = source._get_records() self.assertEqual(len(records), 9, len(records)) - with self.assertNumQueries(96): + with self.assertNumQueries(94): row_count = source.refresh_data() self.assertEqual(row_count["count"], len(records), row_count) @@ -378,7 +378,7 @@ def test_get_records_with_nulled_columns_ignored( "latitude_field": "YCOORD", }, ) - records = source._get_records() + records, errors = source._get_records() # this entry as an empty column and should not be in records empty_entry = [ record.get("photoEtablissement") @@ -386,7 +386,7 @@ def test_get_records_with_nulled_columns_ignored( if record.get("photoEtablissement") ] self.assertEqual(len(empty_entry), 0, empty_entry) - with self.assertNumQueries(72): + with self.assertNumQueries(70): row_count = source.refresh_data() self.assertEqual(row_count["count"], len(records), row_count) @@ -410,9 +410,9 @@ def test_get_records_with_no_header_and_yx_csv( "coordinates_field_count": "yx", }, ) - records = source._get_records() + records, errors = source._get_records() self.assertEqual(len(records), 9, len(records)) - with self.assertNumQueries(96): + with self.assertNumQueries(94): row_count = source.refresh_data() self.assertEqual(row_count["count"], len(records), row_count) @@ -436,9 +436,9 @@ def test_get_records_with_no_header_and_two_columns_csv( "longitude_field": "0", }, ) - records = source._get_records() + records, errors = source._get_records() self.assertEqual(len(records), 10, len(records)) - with self.assertNumQueries(104): + with self.assertNumQueries(102): row_count = source.refresh_data() self.assertEqual(row_count["count"], len(records), row_count) @@ -458,7 +458,7 @@ def test_update_fields_keep_order(self): sheet = source.get_file_as_sheet() sheet.name_columns_by_row(0) colnames = [name for name in sheet.colnames if name not in ("XCOORD", "YCOORD")] - with self.assertNumQueries(699): + with self.assertNumQueries(698): source.update_fields() fields = [f.name for f in Field.objects.filter(source=source)] self.assertTrue(fields == colnames) @@ -564,13 +564,17 @@ def test_partial_refresh_trigger_warning( The report Status shoule be WARNING""" # Mocking _get_records to return some incorret row - mocked_rows = [ - {"_geom_": Point(2, 42, srid=4326), "id": 1, "test": 5}, - {"_geom_": "wrong geom"}, - ] + mocked_rows = ( + [ + {"_geom_": Point(2, 42, srid=4326), "id": 1, "test": 5}, + {"_geom_": "wrong geom"}, + ], + [], + ) self.source._get_records = mock.MagicMock(return_value=mocked_rows) self.source.refresh_data() self.assertEqual( self.source.report.status, SourceReporting.Status.WARNING.value, + self.source.report.get_status_display(), ) diff --git a/project/geosource/tests/test_tasks.py b/project/geosource/tests/test_tasks.py index 66f6c5e0..a1bf1c27 100644 --- a/project/geosource/tests/test_tasks.py +++ b/project/geosource/tests/test_tasks.py @@ -1,12 +1,13 @@ import logging +from datetime import datetime from unittest import mock from django.contrib.auth.models import Group from django.test import TestCase from geostore.models import Feature, Layer -from project.geosource.models import GeoJSONSource, GeometryTypes -from project.geosource.tasks import run_model_object_method +from project.geosource.models import GeoJSONSource, GeometryTypes, SourceReporting +from project.geosource.tasks import run_model_object_method, set_failure_state from project.geosource.tests.helpers import get_file @@ -83,3 +84,59 @@ def test_task_good_method_error( ) ) self.assertEqual(Layer.objects.count(), 0) + + def test_run_model_object_update_report( + self, mock_index_feature, mock_es_delete, mock_es_create + ): + source = self.element + source_report = SourceReporting.objects.create( + status=SourceReporting.Status.PENDING.value, + ) + source.report = source_report + source.id = None + source.save() + + self.assertIsNone(source.report.ended) + run_model_object_method.apply( + ( + source._meta.app_label, + source._meta.model_name, + source.pk, + "refresh_data", + ) + ) + source.report.refresh_from_db() + self.assertIsInstance(source.report.message, str) + self.assertIsInstance(source.report.ended, datetime) + + def test_set_failure_state_task_update_report( + self, + mock_index_feature, + mock_es_delete, + mock_es_create, + ): + logging.disable(logging.ERROR) + source_report = SourceReporting.objects.create( + status=SourceReporting.Status.PENDING.value, + ) + source = GeoJSONSource.objects.create( + name="exception-test", + geom_type=GeometryTypes.Point, + file=get_file("test.geojson"), + report=source_report, + ) + + self.assertIsNone(source.report.ended) + try: + run_model_object_method.apply( + ( + source._meta.app_label, + source._meta.model_name, + source.pk, + "method_that_does_not_exist", + ) + ) + except AttributeError: + source.report.refresh_from_db() + self.assertIsInstance(source.report.message, str) + self.assertIsInstance(source.report.ended, datetime)