Skip to content

Commit

Permalink
Refactoring source reporting
Browse files Browse the repository at this point in the history
- Updating & saving the SourceReport instance only in refresh_data method
 - Use exceptions in lower-level methods to report errors and propagate them up to the calling method (refresh_data)
  • Loading branch information
Cédric Farcy committed Jul 10, 2023
1 parent 5e55135 commit 9e7ebdf
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 118 deletions.
132 changes: 63 additions & 69 deletions project/geosource/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,19 @@ def get_type_from_data(cls, data):
return types.get(type(data), cls.Undefined)


class SourceException(Exception):
    """Generic source exception, meant to be caught by the generic Source model.

    The description is exposed as ``message`` and also through the standard
    exception protocol (``args`` and ``str()``).

    Args:
        message: Human-readable description of the failure.
    """

    def __init__(self, message):
        # Forward to Exception so that ``args`` and ``str(exc)`` carry the
        # message even when it is passed by keyword; without this call,
        # ``SourceException(message="...")`` has empty ``args``.
        super().__init__(message)
        self.message = message


class CSVSourceException(SourceException):
    """Exception raised by the CSVSource model."""


class SourceReporting(models.Model):
class Status(models.IntegerChoices):
SUCCESS = 0, "Success"
Expand Down Expand Up @@ -161,6 +174,12 @@ def refresh_data(self):
es_index.index()
response = self._refresh_data(es_index)
return response
except SourceException as exc:
self.report.status = self.report.Status.ERROR.value
self.report.message = exc.message
self.report.ended = timezone.now()
self.report.save(update_fields=["status", "message", "ended"])

finally:
self.last_refresh = timezone.now()
self.status = self.Status.DONE.value
Expand All @@ -184,7 +203,9 @@ def _refresh_data(self, es_index=None):
added_rows = 0
modified_rows = 0
total = 0
for i, row in enumerate(self._get_records()):
records, records_errors = self._get_records()
self.report.errors += records_errors
for i, row in enumerate(records):
total += 1
geometry = row.pop(self.SOURCE_GEOM_ATTRIBUTE)
try:
Expand All @@ -204,15 +225,14 @@ def _refresh_data(self, es_index=None):
except Exception:
transaction.savepoint_rollback(sid)
msg = "An error occured on feature(s)"
self.report.status = SourceReporting.Status.WARNING.value
self.report.errors.append(
f"Row n°{i}: {msg}. {self.id_field} - {identifier}"
)
continue
except KeyError:
msg = "Can't find identifier field for this record"
self.report.status = SourceReporting.Status.WARNING.value
self.report.errors.append(f"Line {i}: {msg}")
self.report.errors.append(
f"Line {i}: Can't find identifier field for this record"
)
continue
row_count += 1
deleted, _ = self.clear_features(layer, begin_date)
Expand All @@ -223,23 +243,20 @@ def _refresh_data(self, es_index=None):
if not row_count:
self.report.status = SourceReporting.Status.ERROR.value
self.report.message = "Failed to refresh data"
# Pending is added at the start of a refresh, so if not pending, it has to be Warning or Error
# From _get_records(), called earlier
elif (
row_count == total
and self.report.status == SourceReporting.Status.PENDING.value
):
elif row_count == total:
self.report.status = SourceReporting.Status.SUCCESS.value
self.report.message = "Source refreshed successfully"
else:
self.report.status = SourceReporting.Status.WARNING.value
self.report.message = "Source refreshed partially"
if self.id:
self.report.ended = timezone.now()
self.report.save()
return {"count": row_count, "total": total}

@transaction.atomic
def update_fields(self):
records = self._get_records(50)
records, _ = self._get_records(50)

fields = {}

Expand Down Expand Up @@ -401,16 +418,12 @@ def get_file_as_dict(self):
raise

def _get_records(self, limit=None):
if not self.report:
self.report = SourceReporting(
started=timezone.now(), status=SourceReporting.Status.PENDING.value
)

geojson = self.get_file_as_dict()

limit = limit if limit else len(geojson["features"])

records = []
errors = []
for i, record in enumerate(geojson["features"][:limit]):
try:
records.append(
Expand All @@ -422,13 +435,10 @@ def _get_records(self, limit=None):
}
)
except (ValueError, GDALException):
msg = "The record geometry seems invalid."
self.report.status = SourceReporting.Status.WARNING.value
self.report.errors.append(f"Line {i}: {msg}")
if self.id:
self.report.ended = timezone.now()
self.report.save()
return records
feature_id = record.get("properties", {}).get("id", i)
msg = f"The record geometry seems invalid for feature {feature_id}."
errors.append(msg)
return (records, errors)


class ShapefileSource(Source):
Expand All @@ -443,7 +453,7 @@ def _get_records(self, limit=None):
_, srid = shapefile.crs.get("init", "epsg:4326").split(":")

# Return geometries with a hack to set the correct geometry srid
return [
records = [
{
self.SOURCE_GEOM_ATTRIBUTE: GEOSGeometry(
GEOSGeometry(json.dumps(feature.get("geometry"))).wkt,
Expand All @@ -453,6 +463,8 @@ def _get_records(self, limit=None):
}
for feature in shapefile[:limit]
]
# No errors are caught for Shapefile
return (records, [])


class CommandSource(Source):
Expand Down Expand Up @@ -482,7 +494,7 @@ def _refresh_data(self, es_index=None):
return {"count": layer.features.count()}

def _get_records(self, limit=None):
return []
return [None, None]


class WMTSSource(Source):
Expand All @@ -498,7 +510,7 @@ def refresh_data(self):
return {}

def _get_records(self, limit=None):
return []
return [None, None]


class CSVSource(Source):
Expand Down Expand Up @@ -536,10 +548,6 @@ def get_file_as_sheet(self):
raise

def _get_records(self, limit=None):
# _get_records is used in the serializer to validate data
if not self.id or not self.report:
self.report = SourceReporting(started=timezone.now())

sheet = self.get_file_as_sheet()
if self.settings.get("use_header"):
sheet.name_columns_by_row(0)
Expand All @@ -551,27 +559,33 @@ def _get_records(self, limit=None):
limit = limit if limit else len(sheet)

records = []
errors = []
srid = self._get_srid()
row_count = 0
total = 0

for i, row in enumerate(sheet):
total += 1
if self.settings["coordinates_field"] == "two_columns":
lat_field = self.settings["latitude_field"]
lng_field = self.settings["longitude_field"]

x, y = self._extract_coordinates(
row, sheet.colnames, [lng_field, lat_field]
)
try:
x, y = self._extract_coordinates(
row, sheet.colnames, [lng_field, lat_field]
)
except CSVSourceException as e:
errors.append(f"Sheet row {i} - {e.message}")
continue

ignored_field = (row.index(x), row.index(y), *ignored_columns)
else:
lnglat_field = self.settings["latlong_field"]
try:
x, y = self._extract_coordinates(
row, sheet.colnames, [lnglat_field]
)
except ValueError:
except CSVSourceException as e:
errors.append(f"Sheet row {i} - {e.message}")
continue

coord_fields = (
(sheet.colnames.index(lnglat_field),)
if self.settings.get("use_header")
Expand All @@ -590,39 +604,25 @@ def _get_records(self, limit=None):
}
)
except (ValueError, GDALException):
msg = f"One of source's record has invalid geometry: Point({x} {y}) srid={srid}"
self.report.status = SourceReporting.Status.WARNING.value
self.report.errors.append(f"Line {i}: {msg}")
errors.append(
f"Sheet row {i} - One of source's record has invalid geometry: Point({x} {y}) srid={srid}"
)
continue
row_count += 1
if not row_count:
self.report.status = SourceReporting.Status.ERROR.value
self.report.message = "No record could be imported, check the report"
elif row_count == total:
self.report.status = SourceReporting.Status.SUCCESS.value
if self.id:
self.report.save()
return records
return (records, errors)

def _extract_coordinates(self, row, colnames, fields):
def _extract_coordinates(self, row, colnames, coord_fields):
coords = []
for field in fields:
for field in coord_fields:
# if no header, we expect index for the columns has been provided
try:
field_index = (
colnames.index(field)
if self.settings.get("use_header")
else int(field)
)
except ValueError as err:
msg = f"{field} is not a valid coordinate field"
if not self.report:
self.report = SourceReporting(started=timezone.now())
self.report.status = SourceReporting.Status.WARNING.value
self.report.errors.append(msg)
self.report.save()
err.args = (msg,)
raise
except ValueError:
raise CSVSourceException(f"{field} is not a valid coordinate field")

c = row[field_index]
coords.append(c)
if len(coords) == 2:
Expand Down Expand Up @@ -674,14 +674,8 @@ def _get_srid(self):
coordinate_reference_system = self.settings["coordinate_reference_system"]
try:
return int(coordinate_reference_system.split("_")[1])
except (IndexError, ValueError) as err:
msg = f"Invalid SRID: {coordinate_reference_system}"
self.report.status = SourceReporting.Status.ERROR.value
self.report.message = msg
# self.report.setdefault("message", []).append(msg)
self.report.save()
err.args = (msg,)
raise
except (IndexError, ValueError):
raise CSVSourceException(f"Invalid SRID: {coordinate_reference_system}")

# properties are use by serializer for representation (reading operation)
@property
Expand Down
4 changes: 2 additions & 2 deletions project/geosource/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def _validate_field_infos(self, data):
except TypeError:
return # file field is empty in update no get_records
try:
records = instance._get_records(1)
records, errors = instance._get_records(1)
except Exception as err:
raise ValidationError(err.args[0])

Expand Down Expand Up @@ -439,7 +439,7 @@ def _validate_field_infos(self, data):
# create an instance without saving data
instance = self.Meta.model(**data_copy)
try:
records = instance._get_records(1)
records, errors = instance._get_records(1)
except (ValueError, GDALException) as err:
raise ValidationError(err.args[0])

Expand Down
2 changes: 1 addition & 1 deletion project/geosource/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def run_model_object_method(self, app, model, pk, method, success_state=states.S

self.update_state(state=success_state, meta=state)
if obj and hasattr(obj, "report"):
if not obj.report:
if obj.report is not None:
text = ""
for key, value in state.items():
text += f"{key}: {value},"
Expand Down
9 changes: 5 additions & 4 deletions project/geosource/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,9 +267,10 @@ def run_sync_method_result(cmd, success_state):
@patch(
"project.geosource.models.Source._get_records",
MagicMock(
return_value=[
{"a": "b", "c": 42, "d": b"4", "e": b"\xe8", "_geom_": "POINT(0 0)"}
]
return_value=(
[{"a": "b", "c": 42, "d": b"4", "e": b"\xe8", "_geom_": "POINT(0 0)"}],
[],
)
),
)
def test_update_fields_method(self):
Expand All @@ -283,7 +284,7 @@ def test_update_fields_method(self):

@patch(
"project.geosource.models.Source._get_records",
MagicMock(return_value=[{"a": "b", "c": 42, "_geom_": "POINT(0 0)"}]),
MagicMock(return_value=([{"a": "b", "c": 42, "_geom_": "POINT(0 0)"}], [])),
)
def test_update_fields_with_delete_method(self):
obj = Source.objects.create(geom_type=10)
Expand Down
Loading

0 comments on commit 9e7ebdf

Please sign in to comment.