diff --git a/pepys_import/core/store/common_db.py b/pepys_import/core/store/common_db.py index 2569b5399..488c5df85 100644 --- a/pepys_import/core/store/common_db.py +++ b/pepys_import/core/store/common_db.py @@ -226,7 +226,7 @@ def get_sensor( class TaskMixin: @declared_attr def parent(self): - return relationship("Task", lazy="joined", join_depth=1, innerjoin=True, uselist=False) + return relationship("Task") @declared_attr def parent_name(self): diff --git a/pepys_import/core/store/data_store.py b/pepys_import/core/store/data_store.py index 803da9336..1662d9915 100644 --- a/pepys_import/core/store/data_store.py +++ b/pepys_import/core/store/data_store.py @@ -1849,7 +1849,7 @@ def merge_generic(self, table_name, id_list, master_id) -> bool: if table_name == constants.PLATFORM: self.merge_platforms(id_list, master_id, change_id) - if table_name in [constants.SENSOR, constants.DATAFILE]: + elif table_name in [constants.SENSOR, constants.DATAFILE]: self.merge_measurements(table_name, id_list, master_id, change_id) elif table_name in reference_table_names + [constants.TAG, constants.TASK]: self.merge_objects(table_name, id_list, master_id, change_id) diff --git a/tests/test_data_store_merge_operations.py b/tests/test_data_store_merge_operations.py index 876ba71ab..978044b94 100644 --- a/tests/test_data_store_merge_operations.py +++ b/tests/test_data_store_merge_operations.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import datetime, timedelta from unittest import TestCase from uuid import UUID @@ -1383,6 +1383,73 @@ def test_merge_generic_privacies(self): .scalar() ) + def test_merge_generic_tasks(self): + Task = self.store.db_classes.Task + Geometry1 = self.store.db_classes.Geometry1 + GeometrySubType = self.store.db_classes.GeometrySubType + Participant = self.store.db_classes.Participant + with self.store.session_scope(): + start = datetime.now() + end = start + timedelta(seconds=100) + old_task = Task( + name="TEST TASK", + start=start, + end=end, + privacy_id=self.privacy_id, + ) + geo_type = self.store.db_classes.GeometryType(name="Test GeoType") + self.store.session.add(old_task) + self.store.session.add(geo_type) + self.store.session.flush() + geo_sub_type = GeometrySubType(name="Test GeoSubType", parent=geo_type.geo_type_id) + self.store.session.add(geo_sub_type) + self.store.session.flush() + geometry = Geometry1( + subject_platform_id=self.platform_2.platform_id, + _geometry=WKTElement("POINT(123456 123456)", srid=4326), + geo_type_id=geo_type.geo_type_id, + geo_sub_type_id=geo_sub_type.geo_sub_type_id, + source_id=self.file.datafile_id, + task_id=old_task.task_id, + ) + participant = Participant( + platform_id=self.platform_2.platform_id, + task_id=old_task.task_id, + privacy_id=self.privacy_id, + ) + self.store.session.add(geometry) + self.store.session.add(participant) + self.store.session.flush() + + new_task = Task( + name="NEW TASK", + start=start, + end=end, + privacy_id=self.privacy_id, + ) + self.store.session.add(new_task) + self.store.session.commit() + + # Assert that target task doesn't have any dependent objects + dependent_objs = list(dependent_objects(new_task)) + assert len(dependent_objs) == 0 + source_dependent_objs = list(dependent_objects(old_task)) + assert len(source_dependent_objs) == 2 # 1 Geometry, 1 Participant + + # Merge old_task to new_task + assert self.store.merge_generic(constants.TASK, [old_task.task_id], new_task.task_id) + + # Assert that target task has all dependent objects + dependent_objs = list(dependent_objects(new_task)) + assert len(dependent_objs) == 2 + source_dependent_objs = list(dependent_objects(old_task)) + assert len(source_dependent_objs) == 0 + + # Assert that merged task deleted + assert ( + not self.store.session.query(Task).filter(Task.task_id == old_task.task_id).scalar() + ) + def test_merge_generic_wrong_table_name(self): with self.store.session_scope(): assert (