Skip to content

Commit

Permalink
⚡️ Use manual savepoints, add report exception
Browse files Browse the repository at this point in the history
  • Loading branch information
LePetitTim committed Apr 12, 2023
1 parent 6fe927b commit 3a162b2
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 51 deletions.
41 changes: 25 additions & 16 deletions project/geosource/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,31 +138,40 @@ def _refresh_data(self, es_index=None):
geometry = row.pop(self.SOURCE_GEOM_ATTRIBUTE)
try:
identifier = row[self.id_field]
sid = transaction.savepoint()
try:
feature = self.update_feature(layer, identifier, geometry, row)
if es_index and feature:
es_index.index_feature(layer, feature)
transaction.savepoint_commit(sid)
except Exception as exc:
transaction.savepoint_rollback(sid)
msg = "An error occured on feature(s)"
report["status"] = "Warning"
report.setdefault("message", []).append(msg)
report.setdefault("lines", {})[
f"Error on row n°{i} {self.id_field} : {identifier}"
] = str(exc)
continue
except KeyError:
msg = "Can't find identifier field for this record"
report["status"] = "Warning"
report.setdefault("message", []).append(msg)
report.setdefault("lines", {}).setdefault(f"{i}", []).append(msg)
report.setdefault("lines", {})[i] = msg
continue
with transaction.atomic(savepoint=False):
try:
feature = self.update_feature(layer, identifier, geometry, row)
if es_index and feature:
es_index.index_feature(layer, feature)
except Exception:
pass
row_count += 1
self.clear_features(layer, begin_date)

self.report = report
if not row_count:
self.report["status"] = "Error"
self.save(update_fields=["report"])
raise Exception("Failed to refresh data")

if row_count == total:
self.report["status"] = "SUCCESS"
self.save(update_fields=["report"])
if not self.report:
if not row_count:
self.report["status"] = "Error"
self.save(update_fields=["report"])
raise Exception("Failed to refresh data")

if row_count == total:
self.report["status"] = "SUCCESS"
self.save(update_fields=["report"])
return {"count": row_count, "total": total}

@transaction.atomic
Expand Down
11 changes: 6 additions & 5 deletions project/geosource/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,12 @@ def run_model_object_method(self, app, model, pk, method, success_state=states.S

self.update_state(state=success_state, meta=state)
if obj and hasattr(obj, "report"):
text = ""
for key, value in state.items():
text += f"{key}: {value},"
obj.report["lines"] = text
obj.save(update_fields=["report"])
if not obj.report:
text = ""
for key, value in state.items():
text += f"{key}: {value},"
obj.report["lines"] = text
obj.save(update_fields=["report"])

except Model.DoesNotExist:
set_failure_state(
Expand Down
21 changes: 9 additions & 12 deletions project/geosource/tests/test_model_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def test_coordinates_systems_malformed_raise_index_error(
with self.assertRaises(IndexError):
source._get_records()

def test_invalid_id_field_raise_value_error_when_refreshing_data(
def test_invalid_id_field_report_message_when_refreshing_data(
self, mocked_es_delete, mocked_es_create
):
source = CSVSource.objects.create(
Expand All @@ -189,15 +189,14 @@ def test_invalid_id_field_raise_value_error_when_refreshing_data(
},
)
msg = "Can't find identifier field for this record"
with self.assertRaisesMessage(Exception, "Failed to refresh data"):
source.refresh_data()
self.assertIn(msg, source.report.get("message", []))
source.refresh_data()
self.assertIn(msg, source.report.get("message", []))


@patch("elasticsearch.client.IndicesClient.create")
@patch("elasticsearch.client.IndicesClient.delete")
class GeoJSONSourceExceptionsTestCase(TestCase):
def test_source_geojson_with_wrong_id_raise_value_error(
def test_source_geojson_with_wrong_id_report_message(
self, mocked_es_delete, mocked_es_create
):
geodict = {
Expand All @@ -221,9 +220,8 @@ def test_source_geojson_with_wrong_id_raise_value_error(
id_field="gid", # wrong id field
)
msg = "Can't find identifier field for this record"
with self.assertRaisesMessage(Exception, "Failed to refresh data"):
source.refresh_data()
self.assertIn(msg, source.report.get("message", []))
source.refresh_data()
self.assertIn(msg, source.report.get("message", []))


@patch("elasticsearch.client.IndicesClient.create")
Expand All @@ -250,7 +248,7 @@ def test_operationalerror_on_db_connect_is_reported(
@patch("elasticsearch.client.IndicesClient.create")
@patch("elasticsearch.client.IndicesClient.delete")
class ShapefileSourceExceptionsTestCase(TestCase):
def test_wrong_id_raise_exception_on_refresh_and_get_reported(
def test_wrong_id_report_message_on_refresh_and_get_reported(
self, mocked_es_delete, mocked_es_create
):
source = ShapefileSource.objects.create(
Expand All @@ -260,6 +258,5 @@ def test_wrong_id_raise_exception_on_refresh_and_get_reported(
id_field="wrongid",
)
msg = "Can't find identifier field for this record"
with self.assertRaisesMessage(Exception, "Failed to refresh data"):
source.refresh_data()
self.assertIn(msg, source.report.get("message", []))
source.refresh_data()
self.assertIn(msg, source.report.get("message", []))
53 changes: 39 additions & 14 deletions project/geosource/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,9 @@ def test_geojsonsource_type(self):
def test_wrong_identifier_refresh(self, mocked_es_delete, mocked_es_create):
self.geojson_source.id_field = "wrong_identifier"
self.geojson_source.save()
with self.assertRaisesRegexp(Exception, "Failed to refresh data"):
self.geojson_source.refresh_data()
self.geojson_source.refresh_data()
msg = "Can't find identifier field for this record"
self.assertIn(msg, self.geojson_source.report.get("message", []))

@mock.patch("project.geosource.elasticsearch.index.LayerESIndex.index")
def test_delete(self, mock_index):
Expand Down Expand Up @@ -276,9 +277,13 @@ def setUpTestData(cls):
}

@mock.patch("project.geosource.elasticsearch.index.LayerESIndex.index")
def test_get_records_with_two_columns_coordinates(self, mock_index):
@mock.patch("project.geosource.elasticsearch.index.LayerESIndex.index_feature")
def test_get_records_with_two_columns_coordinates(
self, mock_index_feature, mock_index
):
mock_index.return_value = True
source = CSVSource.objects.create(
name="source",
file=get_file("source.csv"),
geom_type=GeometryTypes.Point,
id_field="ID",
Expand All @@ -292,14 +297,18 @@ def test_get_records_with_two_columns_coordinates(self, mock_index):

records = source._get_records()
self.assertEqual(len(records), 6, len(records))
with self.assertNumQueries(54):
with self.assertNumQueries(66):
row_count = source.refresh_data()
self.assertEqual(row_count["count"], len(records), row_count)

@mock.patch("project.geosource.elasticsearch.index.LayerESIndex.index")
def test_get_records_with_one_column_coordinates(self, mock_index):
@mock.patch("project.geosource.elasticsearch.index.LayerESIndex.index_feature")
def test_get_records_with_one_column_coordinates(
self, mock_index_feature, mock_index
):
mock_index.return_value = True
source = CSVSource.objects.create(
name="source_xy",
file=get_file("source_xy.csv"),
geom_type=GeometryTypes.Point,
id_field="ID",
Expand All @@ -314,12 +323,15 @@ def test_get_records_with_one_column_coordinates(self, mock_index):

records = source._get_records()
self.assertEqual(len(records), 9, len(records))
with self.assertNumQueries(72):
with self.assertNumQueries(90):
row_count = source.refresh_data()
self.assertEqual(row_count["count"], len(records), row_count)

@mock.patch("project.geosource.elasticsearch.index.LayerESIndex.index")
def test_get_records_with_decimal_separator_as_comma(self, mock_index):
@mock.patch("project.geosource.elasticsearch.index.LayerESIndex.index_feature")
def test_get_records_with_decimal_separator_as_comma(
self, mock_index_feature, mock_index
):
mock_index.return_value = True
source = CSVSource.objects.create(
name="csv-source",
Expand All @@ -341,14 +353,18 @@ def test_get_records_with_decimal_separator_as_comma(self, mock_index):
)
records = source._get_records()
self.assertEqual(len(records), 9, len(records))
with self.assertNumQueries(72):
with self.assertNumQueries(90):
row_count = source.refresh_data()
self.assertEqual(row_count["count"], len(records), row_count)

@mock.patch("project.geosource.elasticsearch.index.LayerESIndex.index")
def test_get_records_with_nulled_columns_ignored(self, mock_index):
@mock.patch("project.geosource.elasticsearch.index.LayerESIndex.index_feature")
def test_get_records_with_nulled_columns_ignored(
self, mock_index_feature, mock_index
):
mock_index.return_value = True
source = CSVSource.objects.create(
name="source",
file=get_file("source.csv"),
geom_type=GeometryTypes.Point,
id_field="ID",
Expand All @@ -368,14 +384,18 @@ def test_get_records_with_nulled_columns_ignored(self, mock_index):
if record.get("photoEtablissement")
]
self.assertEqual(len(empty_entry), 0, empty_entry)
with self.assertNumQueries(54):
with self.assertNumQueries(66):
row_count = source.refresh_data()
self.assertEqual(row_count["count"], len(records), row_count)

@mock.patch("project.geosource.elasticsearch.index.LayerESIndex.index")
def test_get_records_with_no_header_and_yx_csv(self, mock_index):
@mock.patch("project.geosource.elasticsearch.index.LayerESIndex.index_feature")
def test_get_records_with_no_header_and_yx_csv(
self, mock_index_feature, mock_index
):
mock_index.return_value = True
source = CSVSource.objects.create(
name="source_xy_noheader",
file=get_file("source_xy_noheader.csv"),
geom_type=GeometryTypes.Point,
id_field="1",
Expand All @@ -390,14 +410,18 @@ def test_get_records_with_no_header_and_yx_csv(self, mock_index):
)
records = source._get_records()
self.assertEqual(len(records), 9, len(records))
with self.assertNumQueries(72):
with self.assertNumQueries(90):
row_count = source.refresh_data()
self.assertEqual(row_count["count"], len(records), row_count)

@mock.patch("project.geosource.elasticsearch.index.LayerESIndex.index")
def test_get_records_with_no_header_and_two_columns_csv(self, mock_index):
@mock.patch("project.geosource.elasticsearch.index.LayerESIndex.index_feature")
def test_get_records_with_no_header_and_two_columns_csv(
self, mock_index_feature, mock_index
):
mock_index.return_value = True
source = CSVSource.objects.create(
name="source_noheader",
file=get_file("source_noheader.csv"),
geom_type=0,
id_field="2",
Expand All @@ -412,12 +436,13 @@ def test_get_records_with_no_header_and_two_columns_csv(self, mock_index):
)
records = source._get_records()
self.assertEqual(len(records), 10, len(records))
with self.assertNumQueries(78):
with self.assertNumQueries(98):
row_count = source.refresh_data()
self.assertEqual(row_count["count"], len(records), row_count)

def test_update_fields_keep_order(self):
source = CSVSource.objects.create(
name="source",
file=get_file("source.csv"),
geom_type=GeometryTypes.Point,
id_field="ID",
Expand Down
13 changes: 9 additions & 4 deletions project/geosource/tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

@mock.patch("elasticsearch.client.IndicesClient.create")
@mock.patch("elasticsearch.client.IndicesClient.delete")
@mock.patch("project.geosource.elasticsearch.index.LayerESIndex.index_feature")
class TaskTestCase(TestCase):
@classmethod
def setUpTestData(cls):
Expand All @@ -23,7 +24,9 @@ def setUpTestData(cls):
settings={"groups": [cls.group.pk]},
)

def test_task_refresh_data_method(self, mocked_es_delete, mocked_es_create):
def test_task_refresh_data_method(
self, mocked_index_feature, mocked_es_delete, mocked_es_create
):
run_model_object_method.apply(
(
self.element._meta.app_label,
Expand All @@ -38,7 +41,7 @@ def test_task_refresh_data_method(self, mocked_es_delete, mocked_es_create):
self.assertEqual(Layer.objects.first().authorized_groups.first().name, "Group")

def test_task_refresh_data_method_wrong_pk(
self, mocked_es_delete, mocked_es_create
self, mocked_index_feature, mocked_es_delete, mocked_es_create
):
logging.disable(logging.WARNING)
run_model_object_method.apply(
Expand All @@ -51,7 +54,9 @@ def test_task_refresh_data_method_wrong_pk(
)
self.assertEqual(Layer.objects.count(), 0)

def test_task_wrong_method(self, mocked_es_delete, mocked_es_create):
def test_task_wrong_method(
self, mocked_index_feature, mocked_es_delete, mocked_es_create
):
logging.disable(logging.ERROR)
run_model_object_method.apply(
(
Expand All @@ -65,7 +70,7 @@ def test_task_wrong_method(self, mocked_es_delete, mocked_es_create):

@mock.patch("project.geosource.models.Source.objects")
def test_task_good_method_error(
self, mocked_es_delete, mocked_es_create, mock_source
self, mock_source, mocked_index_feature, mocked_es_delete, mocked_es_create
):
mock_source.get.side_effect = ValueError
logging.disable(logging.ERROR)
Expand Down

0 comments on commit 3a162b2

Please sign in to comment.