diff --git a/configs/data/im2im/segmentation_plugin.yaml b/configs/data/im2im/segmentation_plugin.yaml index 56536f13..26bbb7ef 100644 --- a/configs/data/im2im/segmentation_plugin.yaml +++ b/configs/data/im2im/segmentation_plugin.yaml @@ -74,12 +74,12 @@ transforms: - ${target_col1} - ${target_col2} base_image_key: ${base_image_col} - output_name: target + output_name: seg - _target_: monai.transforms.ToTensord keys: - ${source_col} - - target + - seg - ${exclude_mask_col} dtype: float16 @@ -87,7 +87,7 @@ transforms: - _target_: cyto_dl.image.transforms.RandomMultiScaleCropd keys: - ${source_col} - - target + - seg - ${exclude_mask_col} patch_shape: ${data._aux.patch_shape} patch_per_image: ${data._aux.patch_per_image} @@ -143,12 +143,12 @@ transforms: - ${target_col1} - ${target_col2} base_image_key: ${base_image_col} - output_name: target + output_name: seg - _target_: monai.transforms.ToTensord keys: - ${source_col} - - target + - seg - ${exclude_mask_col} dtype: float16 @@ -232,12 +232,12 @@ transforms: - ${target_col1} - ${target_col2} base_image_key: ${base_image_col} - output_name: target + output_name: seg - _target_: monai.transforms.ToTensord keys: - ${source_col} - - target + - seg - ${exclude_mask_col} dtype: float16 @@ -245,7 +245,7 @@ transforms: - _target_: cyto_dl.image.transforms.RandomMultiScaleCropd keys: - ${source_col} - - target + - seg - ${exclude_mask_col} patch_shape: ${data._aux.patch_shape} patch_per_image: ${data._aux.patch_per_image} @@ -254,7 +254,7 @@ transforms: _aux: patch_per_image: 1 _scales_dict: - - - target + - - seg - [1] - - ${source_col} - [1] diff --git a/configs/model/im2im/segmentation_plugin.yaml b/configs/model/im2im/segmentation_plugin.yaml index c6c35a2d..8c4211c9 100644 --- a/configs/model/im2im/segmentation_plugin.yaml +++ b/configs/model/im2im/segmentation_plugin.yaml @@ -18,7 +18,7 @@ backbone: filters: ${model._aux.filters} task_heads: - target: + seg: _target_: cyto_dl.nn.head.MaskHead mask_key: ${exclude_mask_col} loss: diff --git a/cyto_dl/models/im2im/multi_task.py b/cyto_dl/models/im2im/multi_task.py index 00b891f8..9ac05c05 100644 --- a/cyto_dl/models/im2im/multi_task.py +++ b/cyto_dl/models/im2im/multi_task.py @@ -58,6 +58,12 @@ def __init__( } for head in task_heads.keys(): + assert head not in ( + "loss", + "pred", + "target", + "input", + ), "Task head name cannot be 'loss', 'pred', 'target', or 'input'" _DEFAULT_METRICS.update( { f"train/loss/{head}": MeanMetric(),