Skip to content

Commit

Permalink
Merge pull request #42 from janelia-cellmap/actions/black
Browse files Browse the repository at this point in the history
Format Python code with psf/black push
  • Loading branch information
mzouink authored Feb 9, 2024
2 parents 58c7abe + 353b8cb commit 4a54e31
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 9 deletions.
27 changes: 18 additions & 9 deletions dacapo/experiments/trainers/gunpowder_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ def __init__(self, trainer_config):
self.mask_integral_downsample_factor = 4
self.clip_raw = trainer_config.clip_raw

# Testing out if calculating multiple times and multiplying is necessary
self.add_predictor_nodes_to_dataset = (
trainer_config.add_predictor_nodes_to_dataset
)

self.scheduler = None

def create_optimizer(self, model):
Expand Down Expand Up @@ -146,13 +151,14 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
for augment in self.augments:
dataset_source += augment.node(raw_key, gt_key, mask_key)

# Add predictor nodes to dataset_source
dataset_source += DaCapoTargetFilter(
task.predictor,
gt_key=gt_key,
weights_key=dataset_weight_key,
mask_key=mask_key,
)
if self.add_predictor_nodes_to_dataset:
# Add predictor nodes to dataset_source
dataset_source += DaCapoTargetFilter(
task.predictor,
gt_key=gt_key,
weights_key=dataset_weight_key,
mask_key=mask_key,
)

dataset_sources.append(dataset_source)
pipeline = tuple(dataset_sources) + gp.RandomProvider(weights)
Expand All @@ -162,11 +168,14 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
task.predictor,
gt_key=gt_key,
target_key=target_key,
weights_key=datasets_weight_key,
weights_key=datasets_weight_key
if self.add_predictor_nodes_to_dataset
else weight_key,
mask_key=mask_key,
)

pipeline += Product(dataset_weight_key, datasets_weight_key, weight_key)
if self.add_predictor_nodes_to_dataset:
pipeline += Product(dataset_weight_key, datasets_weight_key, weight_key)

# Trainer attributes:
if self.num_data_fetchers > 1:
Expand Down
7 changes: 7 additions & 0 deletions dacapo/experiments/trainers/gunpowder_trainer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,10 @@ class GunpowderTrainerConfig(TrainerConfig):
)
min_masked: Optional[float] = attr.ib(default=0.15)
clip_raw: bool = attr.ib(default=True)

add_predictor_nodes_to_dataset: Optional[bool] = attr.ib(
default=True,
metadata={
"help_text": "Whether to add a predictor node to dataset_source and apply product of weights"
},
)

0 comments on commit 4a54e31

Please sign in to comment.