Commit a16448c

remove extra irrelevant stuff
mzouink committed Feb 14, 2024
1 parent 9d2df1a commit a16448c
Showing 10 changed files with 43 additions and 201 deletions.
@@ -96,7 +96,6 @@ def num_channels(self):
return len(self.channels)

def __getitem__(self, roi: Roi) -> np.ndarray:
logger.info(f"Concat Array: Get Item {self.name} {roi}")
default = (
np.zeros_like(self.source_array[roi])
if self.default_array is None
41 changes: 8 additions & 33 deletions dacapo/experiments/run.py
@@ -6,10 +6,8 @@
from .validation_scores import ValidationScores
from .starts import Start
from .model import Model
import logging
import torch

logger = logging.getLogger(__file__)
import torch


class Run:
@@ -55,37 +53,14 @@ def __init__(self, run_config):
self.task.parameters, self.datasplit.validate, self.task.evaluation_scores
)

if run_config.start_config is None:
return
try:
from ..store import create_config_store

start_config_store = create_config_store()
starter_config = start_config_store.retrieve_run_config(
run_config.start_config.run
)
except Exception as e:
logger.error(
f"could not load start config: {e} Should be added to the database config store RUN"
)
raise e

# preloaded weights from previous run
if run_config.task_config.name == starter_config.task_config.name:
self.start = Start(run_config.start_config)
else:
# Match labels between old and new head
if hasattr(run_config.task_config, "channels"):
# Map old head and new head
old_head = starter_config.task_config.channels
new_head = run_config.task_config.channels
self.start = Start(
run_config.start_config, old_head=old_head, new_head=new_head
)
else:
logger.warning("Not implemented channel match for this task")
self.start = Start(run_config.start_config, remove_head=True)
self.start.initialize_weights(self.model)
self.start = (
Start(run_config.start_config)
if run_config.start_config is not None
else None
)
if self.start is not None:
self.start.initialize_weights(self.model)

@staticmethod
def get_validation_scores(run_config) -> ValidationScores:
97 changes: 18 additions & 79 deletions dacapo/experiments/starts/start.py
@@ -3,94 +3,33 @@

logger = logging.getLogger(__file__)

# self.old_head =["ecs","plasma_membrane","mito","mito_membrane","vesicle","vesicle_membrane","mvb","mvb_membrane","er","er_membrane","eres","nucleus","microtubules","microtubules_out"]
# self.new_head = ["mito","nucleus","ld","ecs","peroxisome"]
head_keys = [
"prediction_head.weight",
"prediction_head.bias",
"chain.1.weight",
"chain.1.bias",
]

# Hack
# if label is mito_peroxisome or peroxisome then change it to mito
mitos = ["mito_proxisome", "peroxisome"]


def match_heads(model, head_weights, old_head, new_head):
# match the heads
for label in new_head:
old_label = label
if label in mitos:
old_label = "mito"
if old_label in old_head:
logger.warning(f"matching head for {label}")
# find the index of the label in the old_head
old_index = old_head.index(old_label)
# find the index of the label in the new_head
new_index = new_head.index(label)
# get the weight and bias of the old head
for key in head_keys:
if key in model.state_dict().keys():
n_val = head_weights[key][old_index]
model.state_dict()[key][new_index] = n_val
logger.warning(f"matched head for {label} with {old_label}")


class Start(ABC):
def __init__(self, start_config, remove_head=False, old_head=None, new_head=None):
def __init__(self, start_config):
self.run = start_config.run
self.criterion = start_config.criterion
self.remove_head = remove_head
self.old_head = old_head
self.new_head = new_head

def initialize_weights(self, model):
from dacapo.store.create_store import create_weights_store

weights_store = create_weights_store()
weights = weights_store._retrieve_weights(self.run, self.criterion)

logger.warning(
f"loading weights from run {self.run}, criterion: {self.criterion}"
)

logger.info(f"loading weights from run {self.run}, criterion: {self.criterion}")
# load the model weights (taken from torch load_state_dict source)
try:
if self.old_head and self.new_head:
try:
self.load_model_using_head_matching(model, weights)
except RuntimeError as e:
logger.error(f"ERROR starter matching head: {e}")
self.load_model_using_head_removal(model, weights)
elif self.remove_head:
self.load_model_using_head_removal(model, weights)
else:
model.load_state_dict(weights.model)
model.load_state_dict(weights.model)
except RuntimeError as e:
logger.warning(f"ERROR starter: {e}")

def load_model_using_head_removal(self, model, weights):
logger.warning(
f"removing head from run {self.run}, criterion: {self.criterion}"
)
for key in head_keys:
weights.model.pop(key, None)
logger.warning(f"removed head from run {self.run}, criterion: {self.criterion}")
model.load_state_dict(weights.model, strict=False)
logger.warning(
f"loaded weights in non strict mode from run {self.run}, criterion: {self.criterion}"
)

def load_model_using_head_matching(self, model, weights):
logger.warning(
f"matching heads from run {self.run}, criterion: {self.criterion}"
)
logger.warning(f"old head: {self.old_head}")
logger.warning(f"new head: {self.new_head}")
head_weights = {}
for key in head_keys:
head_weights[key] = weights.model[key]
for key in head_keys:
weights.model.pop(key, None)
model.load_state_dict(weights.model, strict=False)
model = match_heads(model, head_weights, self.old_head, self.new_head)
logger.warning(e)
# if the model is not the same, we can try to load the weights
# of the common layers
model_dict = model.state_dict()
pretrained_dict = {
k: v
for k, v in weights.model.items()
if k in model_dict and v.size() == model_dict[k].size()
}
model_dict.update(
pretrained_dict
) # update only the existing and matching layers
model.load_state_dict(model_dict)
logger.warning(f"loaded only common layers from weights")
1 change: 0 additions & 1 deletion dacapo/experiments/tasks/distance_task.py
@@ -15,7 +15,6 @@ def __init__(self, task_config):
channels=task_config.channels,
scale_factor=task_config.scale_factor,
mask_distances=task_config.mask_distances,
extra_conv=task_config.extra_conv,
)
self.loss = MSELoss()
self.post_processor = ThresholdPostProcessor()
7 changes: 0 additions & 7 deletions dacapo/experiments/tasks/distance_task_config.py
@@ -46,10 +46,3 @@ class DistanceTaskConfig(TaskConfig):
"is less than the distance to object boundary."
},
)

extra_conv: bool = attr.ib(
default=False,
metadata={
"help_text": "Whether or not to add an extra conv layer before the head"
},
)
56 changes: 9 additions & 47 deletions dacapo/experiments/tasks/predictors/distance_predictor.py
@@ -27,13 +27,7 @@ class DistancePredictor(Predictor):
in the channels argument.
"""

def __init__(
self,
channels: List[str],
scale_factor: float,
mask_distances: bool,
extra_conv: bool,
):
def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool):
self.channels = channels
self.norm = "tanh"
self.dt_scale_factor = scale_factor
@@ -42,52 +36,20 @@ def __init__(self):
self.max_distance = 1 * scale_factor
self.epsilon = 5e-2
self.threshold = 0.8
self.extra_conv = extra_conv
self.extra_conv_dims = len(self.channels) * 2

@property
def embedding_dims(self):
return len(self.channels)

def create_model(self, architecture):
if self.extra_conv:
if architecture.dims == 2:
head = torch.nn.Sequential(
torch.nn.Conv2d(
architecture.num_out_channels,
self.extra_conv_dims,
kernel_size=3,
padding=1,
),
torch.nn.Conv2d(
self.extra_conv_dims,
self.embedding_dims,
kernel_size=1,
),
)
elif architecture.dims == 3:
head = torch.nn.Sequential(
torch.nn.Conv3d(
architecture.num_out_channels,
self.extra_conv_dims,
kernel_size=3,
padding=1,
),
torch.nn.Conv3d(
self.extra_conv_dims,
self.embedding_dims,
kernel_size=1,
),
)
else:
if architecture.dims == 2:
head = torch.nn.Conv2d(
architecture.num_out_channels, self.embedding_dims, kernel_size=1
)
elif architecture.dims == 3:
head = torch.nn.Conv3d(
architecture.num_out_channels, self.embedding_dims, kernel_size=1
)
if architecture.dims == 2:
head = torch.nn.Conv2d(
architecture.num_out_channels, self.embedding_dims, kernel_size=1
)
elif architecture.dims == 3:
head = torch.nn.Conv3d(
architecture.num_out_channels, self.embedding_dims, kernel_size=1
)

return Model(architecture, head)

20 changes: 4 additions & 16 deletions dacapo/experiments/trainers/gunpowder_trainer.py
@@ -46,22 +46,11 @@ def __init__(self, trainer_config):
self.add_predictor_nodes_to_dataset = (
trainer_config.add_predictor_nodes_to_dataset
)
self.finetune_head_only = trainer_config.finetune_head_only

self.scheduler = None

def create_optimizer(self, model):
if self.finetune_head_only:
logger.warning("Finetuning head only")
parameters = []
for name, param in model.named_parameters():
if "prediction_head" in name:
parameters.append(param)
else:
param.requires_grad = False
else:
parameters = model.parameters()
optimizer = torch.optim.RAdam(lr=self.learning_rate, params=parameters)
optimizer = torch.optim.RAdam(lr=self.learning_rate, params=model.parameters())
self.scheduler = torch.optim.lr_scheduler.LinearLR(
optimizer,
start_factor=0.01,
@@ -228,15 +217,15 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
def iterate(self, num_iterations, model, optimizer, device):
t_start_fetch = time.time()

logger.info("Starting iteration!")

for iteration in range(self.iteration, self.iteration + num_iterations):
raw, gt, target, weight, mask = self.next()
logger.debug(
f"Trainer fetch batch took {time.time() - t_start_fetch} seconds"
)

for (
param
) in model.parameters(): # TODO: get parameters from optimizer instead
for param in model.parameters():
param.grad = None

t_start_prediction = time.time()
@@ -247,7 +236,6 @@ def iterate(self, num_iterations, model, optimizer, device):
torch.as_tensor(target[target.roi]).to(device).float(),
torch.as_tensor(weight[weight.roi]).to(device).float(),
)

loss.backward()
optimizer.step()

5 changes: 0 additions & 5 deletions dacapo/experiments/trainers/gunpowder_trainer_config.py
@@ -36,8 +36,3 @@ class GunpowderTrainerConfig(TrainerConfig):
"help_text": "Whether to add a predictor node to dataset_source and apply product of weights"
},
)

finetune_head_only: Optional[bool] = attr.ib(
default=False,
metadata={"help_text": "Whether to fine-tune head only or all layers"},
)
10 changes: 2 additions & 8 deletions dacapo/train.py
@@ -12,9 +12,7 @@
logger = logging.getLogger(__name__)


def train(
run_name: str, compute_context: ComputeContext = LocalTorch(), force_cuda=False
):
def train(run_name: str, compute_context: ComputeContext = LocalTorch()):
"""Train a run"""

if compute_context.train(run_name):
@@ -104,10 +102,6 @@ def train_run(
f"Found weights for iteration {latest_weights_iteration}, but "
f"run {run.name} was only trained until {trained_until}. "
)
# raise RuntimeError(
# f"Found weights for iteration {latest_weights_iteration}, but "
# f"run {run.name} was only trained until {trained_until}."
# )

# start/resume training

@@ -167,7 +161,7 @@

run.model.eval()
# free up optimizer memory to allow larger validation blocks
# run.model = run.model.to(torch.device("cpu"))
run.model = run.model.to(torch.device("cpu"))
run.move_optimizer(torch.device("cpu"), empty_cuda_cache=True)

stats_store.store_training_stats(run.name, run.training_stats)
6 changes: 2 additions & 4 deletions dacapo/validate.py
@@ -79,7 +79,6 @@ def validate_run(
evaluator.set_best(run.validation_scores)

for validation_dataset in run.datasplit.validate:
logger.warning("Validating on dataset %s", validation_dataset.name)
assert (
validation_dataset.gt is not None
), "We do not yet support validating on datasets without ground truth"
@@ -99,7 +98,7 @@
f"{input_gt_array_identifier.container}/{input_gt_array_identifier.dataset}"
).exists()
):
logger.warning("Copying validation inputs!")
logger.info("Copying validation inputs!")
input_voxel_size = validation_dataset.raw.voxel_size
output_voxel_size = run.model.scale(input_voxel_size)
input_shape = run.model.eval_input_shape
@@ -137,13 +136,12 @@
)
input_gt[output_roi] = validation_dataset.gt[output_roi]
else:
logger.warning("validation inputs already copied!")
logger.info("validation inputs already copied!")

prediction_array_identifier = array_store.validation_prediction_array(
run.name, iteration, validation_dataset
)
logger.info("Predicting on dataset %s", validation_dataset.name)

predict(
run.model,
validation_dataset.raw,
