Skip to content

Commit

Permalink
include weighting argument for affinities+lsd loss
Browse files Browse the repository at this point in the history
  • Loading branch information
davidackerman committed Feb 9, 2024
1 parent 55a3892 commit 58c7abe
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 24 deletions.
8 changes: 2 additions & 6 deletions dacapo/experiments/tasks/affinities_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,8 @@ def __init__(self, task_config):
"""Create a `DummyTask` from a `DummyTaskConfig`."""

self.predictor = AffinitiesPredictor(
neighborhood=task_config.neighborhood,
lsds=task_config.lsds,
num_voxels=task_config.num_voxels,
downsample_lsds=task_config.downsample_lsds,
grow_boundary_iterations=task_config.grow_boundary_iterations,
neighborhood=task_config.neighborhood, lsds=task_config.lsds
)
self.loss = AffinitiesLoss(len(task_config.neighborhood))
self.loss = AffinitiesLoss(len(task_config.neighborhood), task_config.lsds_to_affs_weight_ratio)
self.post_processor = WatershedPostProcessor(offsets=task_config.neighborhood)
self.evaluator = InstanceEvaluator()
18 changes: 2 additions & 16 deletions dacapo/experiments/tasks/affinities_task_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,9 @@ class AffinitiesTaskConfig(TaskConfig):
"It has been shown that lsds as an auxiliary task can help affinity predictions."
},
)
num_voxels: int = attr.ib(
default=20,
metadata={
"help_text": "The number of voxels to use for the gaussian sigma when computing lsds."
},
)
downsample_lsds: int = attr.ib(
lsds_to_affs_weight_ratio: float = attr.ib(
default=1,
metadata={
"help_text": "The amount to downsample the lsds. "
"This is useful for speeding up training and inference."
},
)
grow_boundary_iterations: int = attr.ib(
default=0,
metadata={
"help_text": "The number of iterations to run the grow boundaries algorithm. "
"This is useful for refining the boundaries of the affinities, and reducing merging of adjacent objects."
"help_text": "If training with lsds, set how much they should be weighted compared to affs."
},
)
5 changes: 3 additions & 2 deletions dacapo/experiments/tasks/losses/affinities_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@


class AffinitiesLoss(Loss):
def __init__(self, num_affinities: int):
def __init__(self, num_affinities: int, lsds_to_affs_weight_ratio: float):
self.num_affinities = num_affinities
self.lsds_to_affs_weight_ratio = lsds_to_affs_weight_ratio

def compute(self, prediction, target, weight):
affs, affs_target, affs_weight = (
Expand All @@ -21,7 +22,7 @@ def compute(self, prediction, target, weight):
return (
torch.nn.BCEWithLogitsLoss(reduction="none")(affs, affs_target)
* affs_weight
).mean() + (
).mean() + self.lsds_to_affs_weight_ratio * (
torch.nn.MSELoss(reduction="none")(torch.nn.Sigmoid()(aux), aux_target)
* aux_weight
).mean()

0 comments on commit 58c7abe

Please sign in to comment.