Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Do not allow to select and render corrupted batch plans #1015

Merged
merged 20 commits into from
Aug 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ include package/tests/test_data/napari_measurements_profile.json
include package/tests/test_data/notebook/*.json
include package/tests/test_data/old_saves/*/*/*.json
include package/tests/test_data/sample_batch_output.xlsx
include package/tests/test_data/problematic_excel_batch.xlsx
include package/PartSeg/napari.yaml
include Readme.md
include changelog.md
Expand Down
2 changes: 1 addition & 1 deletion package/PartSeg/_roi_analysis/batch_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def prepare_calculation(self):

def _refresh_batch_list(self):
current_calc = str(self.calculation_choose.currentText())
new_list = ["<no calculation>", *sorted(self.settings.batch_plans.keys())]
new_list = ["<no calculation>", *sorted(n for n, p in self.settings.batch_plans.items() if not p.is_bad())]
try:
index = new_list.index(current_calc)
except ValueError:
Expand Down
46 changes: 35 additions & 11 deletions package/PartSeg/_roi_analysis/prepare_plan_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from enum import Enum

from qtpy.QtCore import Qt, Signal
from qtpy.QtGui import QIcon
from qtpy.QtWidgets import (
QAction,
QCheckBox,
Expand Down Expand Up @@ -64,6 +65,7 @@
from PartSegCore.analysis.save_functions import save_dict
from PartSegCore.io_utils import LoadPlanExcel, LoadPlanJson, SaveBase
from PartSegCore.universal_const import Units
from PartSegData import icons_dir

group_sheet = (
"QGroupBox {border: 1px solid gray; border-radius: 9px; margin-top: 0.5em;} "
Expand Down Expand Up @@ -916,8 +918,12 @@
else:
self.information.setText(item.pretty_print(AnalysisAlgorithmSelection))

def edit_plan(self):
plan = self.sender().plan_to_edit # type: CalculationPlan
def edit_plan(self, plan: CalculationPlan):
if plan.is_bad():
QMessageBox().warning(
self, "Cannot edit broken plan", f"Cannot edit broken plan. {plan.get_error_source()}"
)
return
self.calculation_plan = copy(plan)
self.plan.set_plan(self.calculation_plan)
self.mask_set.clear()
Expand All @@ -934,10 +940,12 @@

def __init__(self, parent=None, calculation_plan=None):
super().__init__(parent)
self.calculation_plan = calculation_plan
self.calculation_plan = None
self.header().close()
self.itemSelectionChanged.connect(self.set_path)
self.setContextMenuPolicy(Qt.CustomContextMenu)
if calculation_plan is not None:
self.set_plan(calculation_plan)

def restore_path(self, widget, path):
"""
Expand Down Expand Up @@ -971,6 +979,11 @@
self.set_plan(calculation_plan)

def set_plan(self, calculation_plan):
if calculation_plan is not None and calculation_plan.is_bad():
QMessageBox().warning(
self, "Cannot preview broken plan", f"Cannot preview broken plan. {calculation_plan.get_error_source()}"
)
return
self.calculation_plan = calculation_plan
self.setCurrentItem(self.topLevelItem(0))
self.update_view(True)
Expand Down Expand Up @@ -1065,7 +1078,7 @@
:type settings: Settings
"""

plan_to_edit_signal = Signal()
plan_to_edit_signal = Signal(object)

def __init__(self, settings: PartSettings):
super().__init__()
Expand All @@ -1092,7 +1105,6 @@
info_chose_layout.addWidget(self.plan_view)
info_layout.addLayout(info_chose_layout)
self.setLayout(info_layout)
self.calculate_plans.addItems(sorted(self.settings.batch_plans.keys()))
self.protect = False
self.plan_to_edit = None

Expand All @@ -1104,6 +1116,7 @@
self.import_plans_btn.clicked.connect(self.import_plans)
self.settings.batch_plans_changed.connect(self.update_plan_list)
self.plan_view.customContextMenuRequested.connect(self._context_menu)
self.update_plan_list()

def _context_menu(self, point):
item = self.plan_view.itemAt(point)
Expand Down Expand Up @@ -1146,18 +1159,24 @@
return None

def update_plan_list(self):
new_plan_list = sorted(self.settings.batch_plans.keys())
new_plan_list = sorted(self.settings.batch_plans.items(), key=lambda x: x[0])
if self.calculate_plans.currentItem() is not None:
text = str(self.calculate_plans.currentItem().text())
try:
index = new_plan_list.index(text)
index = [x[0] for x in new_plan_list].index(text)
except ValueError:
index = -1
else:
index = -1
self.protect = True
self.calculate_plans.clear()
self.calculate_plans.addItems(new_plan_list)

for name, plan in new_plan_list:
item = QListWidgetItem(name)
if plan.is_bad():
item.setIcon(QIcon(os.path.join(icons_dir, "task-reject.png")))
item.setToolTip(plan.get_error_source())
self.calculate_plans.addItem(item)
if index != -1:
self.calculate_plans.setCurrentRow(index)
self.protect = False
Expand Down Expand Up @@ -1190,7 +1209,11 @@
res = dial.get_result()
plans, err = res.load_class.load(res.load_location)
if err:
show_warning("Import error", f"error during importing, part of data were filtered. {err}")
error_str = "\n".join(err)
show_warning("Import error", f"error during importing, part of data were filtered. {error_str}")
if not plans:
show_warning("Import error", "No plans were imported")
return

Check warning on line 1216 in package/PartSeg/_roi_analysis/prepare_plan_widget.py

View check run for this annotation

Codecov / codecov/patch

package/PartSeg/_roi_analysis/prepare_plan_widget.py#L1212-L1216

Added lines #L1212 - L1216 were not covered by tests
choose = ImportDialog(plans, self.settings.batch_plans, PlanPreview, CalculationPlan)
if choose.exec_():
for original_name, final_name in choose.get_import_list():
Expand All @@ -1214,7 +1237,7 @@
return # pragma: no cover
if text in self.settings.batch_plans:
self.plan_to_edit = self.settings.batch_plans[text]
self.plan_to_edit_signal.emit()
self.plan_to_edit_signal.emit(self.plan_to_edit)

def plan_preview(self, text):
if self.protect:
Expand All @@ -1223,7 +1246,8 @@
if not text.strip():
return
plan = self.settings.batch_plans[text]
self.plan_view.set_plan(plan)
if not plan.is_bad():
self.plan_view.set_plan(plan)


class CalculatePlaner(QSplitter):
Expand Down
12 changes: 7 additions & 5 deletions package/PartSeg/common_gui/custom_load_dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,20 +86,22 @@ def __init__(
history: typing.Optional[typing.List[str]] = None,
):
super().__init__(load_register, caption, parent)
self.setOption(QFileDialog.DontUseNativeDialog, True)
self.setFileMode(QFileDialog.ExistingFile)
self.setAcceptMode(QFileDialog.AcceptOpen)
self.setOption(QFileDialog.Option.DontUseNativeDialog, True)
self.setFileMode(QFileDialog.FileMode.ExistingFile)
self.setAcceptMode(QFileDialog.AcceptMode.AcceptOpen)
self.files_list = []
self.setWindowTitle("Open File")
if history is not None:
history = self.history() + history
self.setHistory(history)

def accept(self):
selected_files = [x for x in self.selectedFiles() if self.fileMode == QFileDialog.Directory or isfile(x)]
selected_files = [
x for x in self.selectedFiles() if self.fileMode == QFileDialog.FileMode.Directory or isfile(x)
]
if not selected_files:
return
if len(selected_files) == 1 and self.fileMode != QFileDialog.Directory and isdir(selected_files[0]):
if len(selected_files) == 1 and self.fileMode != QFileDialog.FileMode.Directory and isdir(selected_files[0]):
super().accept()
return

Expand Down
2 changes: 1 addition & 1 deletion package/PartSeg/common_gui/searchable_list_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ class SearchableListWidget(QWidget):
def __init__(self, parent=None):
super().__init__(parent)

self.list_widget = QListWidget()
self.list_widget = QListWidget(self)

self.filter_widget = QLineEdit()
self.filter_widget.textChanged.connect(self.update_visible)
Expand Down
35 changes: 34 additions & 1 deletion package/PartSegCore/analysis/calculation_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from copy import copy, deepcopy
from enum import Enum

import local_migrator
from local_migrator import register_class, rename_key
from pydantic import BaseModel as PydanticBaseModel

Expand Down Expand Up @@ -289,7 +290,7 @@
replace_node = 3 #:


@register_class(old_paths=["PartSeg.utils.analysis.calculation_plan.CalculationTree"])
@register_class(old_paths=["PartSeg.utils.analysis.calculation_plan.CalculationTree"], allow_errors_in_values=True)
class CalculationTree:
"""
Structure for describe calculation structure
Expand All @@ -314,6 +315,31 @@
def as_dict(self):
return {"operation": self.operation, "children": self.children}

def is_bad(self):
return any(el.is_bad() for el in self.children) or isinstance(self.operation, dict)

def get_error_source(self):
res = []
for el in self.children:
res.extend(el.get_error_source())
if isinstance(self.operation, dict):
res.extend(self.get_source_error_dict(self.operation))
return res

@classmethod
def get_source_error_dict(cls, dkt):
if not isinstance(dkt, dict) or "__error__" not in dkt:
return []

Check warning on line 332 in package/PartSegCore/analysis/calculation_plan.py

View check run for this annotation

Codecov / codecov/patch

package/PartSegCore/analysis/calculation_plan.py#L332

Added line #L332 was not covered by tests
if "not found in register" in dkt["__error__"]:
return [dkt["__error__"]]
if "__values__" not in dkt:
return []

Check warning on line 336 in package/PartSegCore/analysis/calculation_plan.py

View check run for this annotation

Codecov / codecov/patch

package/PartSegCore/analysis/calculation_plan.py#L336

Added line #L336 was not covered by tests
fields = local_migrator.check_for_errors_in_dkt_values(dkt["__values__"])
res = []
for field in fields:
res.extend(cls.get_source_error_dict(dkt["__values__"][field]))
return res


class NodeType(Enum):
"""Type of node in calculation"""
Expand Down Expand Up @@ -466,6 +492,7 @@
return f"FileCalculation(file_path={self.file_path}, calculation={self.calculation})"


@register_class(allow_errors_in_values=True)
class CalculationPlan:
"""
Clean description Calculation plan.
Expand Down Expand Up @@ -502,6 +529,12 @@
self.changes = []
self.current_node = None

def is_bad(self):
return self.execution_tree.is_bad()

def get_error_source(self):
return ", ".join(self.execution_tree.get_error_source())

def as_dict(self):
return {"tree": self.execution_tree, "name": self.name}

Expand Down
4 changes: 2 additions & 2 deletions package/PartSegCore/analysis/load_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
SegmentationType,
WrongFileTypeException,
check_segmentation_type,
load_matadata_part,
load_metadata_base,
load_metadata_part,
open_tar_file,
proxy_callback,
tar_to_buff,
Expand Down Expand Up @@ -382,7 +382,7 @@ def load(
step_changed: typing.Optional[typing.Callable[[int], typing.Any]] = None,
metadata: typing.Optional[dict] = None,
) -> typing.Tuple[dict, list]:
return load_matadata_part(load_locations[0])
return load_metadata_part(load_locations[0])

@classmethod
def get_name(cls) -> str:
Expand Down
26 changes: 23 additions & 3 deletions package/PartSegCore/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def load_metadata_base(data: typing.Union[str, Path]):
return decoded_data


def load_matadata_part(data: typing.Union[str, Path]) -> typing.Tuple[typing.Any, typing.List[typing.Tuple[str, dict]]]:
def load_metadata_part(data: typing.Union[str, Path]) -> typing.Tuple[typing.Any, typing.List[typing.Tuple[str, dict]]]:
"""
Load serialized data. Get valid entries.

Expand All @@ -220,13 +220,20 @@ def load_matadata_part(data: typing.Union[str, Path]) -> typing.Tuple[typing.Any
# TODO extract to function
data = load_metadata_base(data)
bad_key = []
if isinstance(data, typing.MutableMapping) and "__error__" in data:
bad_key.append(data)
data = {}
if isinstance(data, typing.MutableMapping) and not check_loaded_dict(data):
bad_key.extend((k, data.pop(k)) for k, v in list(data.items()) if not check_loaded_dict(v))
elif isinstance(data, ProfileDict) and not data.verify_data():
bad_key = data.pop_errors()
return data, bad_key


load_matadata_part = load_metadata_part
# backward compatibility


def find_problematic_entries(data: typing.Any) -> typing.List[typing.MutableMapping]:
"""
Find top nodes with ``"__error__"`` key. If node found
Expand Down Expand Up @@ -463,7 +470,17 @@ def load(
step_changed: typing.Optional[typing.Callable[[int], typing.Any]] = None,
metadata: typing.Optional[dict] = None,
):
return load_matadata_part(load_locations[0])
from PartSegCore.analysis.calculation_plan import CalculationPlan

res, err = load_metadata_part(load_locations[0])
res_dkt = {}
err_li = []
for key, value in res.items():
if isinstance(value, CalculationPlan) and value.is_bad():
err_li.append(f"Problem with load {value.name} because of {value.get_error_source()}")
else:
res_dkt[key] = value
return res_dkt, err + err_li

@classmethod
def get_name(cls) -> str:
Expand Down Expand Up @@ -496,7 +513,7 @@ def load(
index += 1

try:
data, err = load_matadata_part(data)
data, err = load_metadata_part(data)
data_list.append(data)
error_list.extend(err)
except ValueError: # pragma: no cover
Expand All @@ -505,6 +522,9 @@ def load(
xlsx.close()
data_dict = {}
for calc_plan in data_list:
if calc_plan.is_bad():
error_list.append(f"Problem with load {calc_plan.name} because of {calc_plan.get_error_source()}")
continue
new_name = iterate_names(calc_plan.name, data_dict)
if new_name is None: # pragma: no cover
error_list.append(f"Cannot determine proper name for {calc_plan.name}")
Expand Down
15 changes: 15 additions & 0 deletions package/PartSegCore/json_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,28 @@

from PartSegCore._old_json_hooks import part_hook

PLUGINS_STR = "plugins."
PLUGINS_STR_LEN = len(PLUGINS_STR)


class PartSegEncoder(local_migrator.Encoder):
pass


def _validate_plugin_prefix(dkt: dict):
if dkt["__class__"].startswith(PLUGINS_STR):
# workaround for plans exported from an old PartSeg bundle
dkt["__class__"] = dkt["__class__"][PLUGINS_STR_LEN:]
if "__class_version_dkt__" in dkt:
for name, value in list(dkt["__class_version_dkt__"].items()):
if name.startswith(PLUGINS_STR):
dkt["__class_version_dkt__"][name[PLUGINS_STR_LEN:]] = value
del dkt["__class_version_dkt__"][name]


def partseg_object_hook(dkt: dict):
if "__class__" in dkt:
_validate_plugin_prefix(dkt)
return local_migrator.object_hook(dkt)

if "__ReadOnly__" in dkt or "__Serializable__" in dkt or "__Enum__" in dkt:
Expand Down
Loading
Loading