Skip to content

Commit

Permalink
Merge pull request #294 from AllenCellModeling/feat/provide_metriclist
Browse files Browse the repository at this point in the history
Feat/provide metriclist
  • Loading branch information
ritvikvasan authored Sep 28, 2023
2 parents a52a181 + ceb3e3f commit e1d1a77
Show file tree
Hide file tree
Showing 29 changed files with 170 additions and 110 deletions.
1 change: 0 additions & 1 deletion cyto_dl/callbacks/callback_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def _on_epoch_end(self, split, trainer, pl_module):


def log_artifacts(outputs, prefix, current_epoch, latent_dim):

input_key = [i for i in outputs[0].keys() if "z_parts_params" in i][0]
input_key = input_key.split("/")[1]
_bs = len(outputs[0][f"z_parts_params/{input_key}"])
Expand Down
6 changes: 1 addition & 5 deletions cyto_dl/callbacks/latent_walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def get_ranked_dims(
cutoff_kld_per_dim,
max_num_shapemodes,
):

stats = (
stats.loc[stats["test_kld_per_latent_dim"] > cutoff_kld_per_dim]
.sort_values(by=["test_kld_per_latent_dim"])
Expand Down Expand Up @@ -96,9 +95,7 @@ def __init__(
self.plot_limits = [-120, 120, -140, 140]

def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule):

with torch.no_grad():

with tempfile.TemporaryDirectory() as tmp_dir:
client = mlflow.tracking.MlflowClient(mlflow.get_tracking_uri())
stats = pd.read_csv(
Expand Down Expand Up @@ -158,7 +155,6 @@ def compute_projections(
compute_features: bool,
x_label: str,
):

matplotlib.rc("xtick", labelsize=3)
matplotlib.rc("ytick", labelsize=3)
matplotlib.rcParams["xtick.major.size"] = 0.1
Expand Down Expand Up @@ -319,6 +315,6 @@ def get_surface_area(input_img):
dy = np.array([0, 0, 1, 0, -1, 0])
dz = np.array([-1, 0, 0, 0, 0, 1])
surface_area = 0
for (k, j, i) in zip(pxl_z, pxl_y, pxl_x):
for k, j, i in zip(pxl_z, pxl_y, pxl_x):
surface_area += 6 - input_img_surface[k + dz, j + dy, i + dx].sum()
return int(surface_area)
3 changes: 0 additions & 3 deletions cyto_dl/callbacks/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,7 @@ def __init__(
self.cutoff_kld_per_dim = 0

def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule):

with torch.no_grad():

embeddings = get_all_embeddings(
trainer.datamodule.train_dataloader(),
trainer.datamodule.val_dataloader(),
Expand All @@ -79,7 +77,6 @@ def get_all_embeddings(
x_label: str,
id_label: None,
):

all_embeddings = []
cell_ids = []
split = []
Expand Down
6 changes: 5 additions & 1 deletion cyto_dl/callbacks/outlier_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ def __init__(self, n_epochs, layer_names, save_dir):
pd.DataFrame(init_activation_csv).to_csv(self.save_dir / "activations.csv", index=False)

def on_save_checkpoint(self, trainer, pl_module, checkpoint):
checkpoint["outlier_detection"] = {"md_cov": self.cov, "md_mu": self.mu, "md_n": self.n}
checkpoint["outlier_detection"] = {
"md_cov": self.cov,
"md_mu": self.mu,
"md_n": self.n,
}

def on_load_checkpoint(self, trainer, pl_module, checkpoint):
od_params = checkpoint.get("outlier_detection", {})
Expand Down
1 change: 0 additions & 1 deletion cyto_dl/dataframe/transforms/group_cols.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def _make_group(self, k, v, row):

def __call__(self, row):
res = {}

for k, v in self.groups.items():
if v is None:
res[k] = row[k]
Expand Down
2 changes: 1 addition & 1 deletion cyto_dl/datamodules/data_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def make_data_dict_dataloader(
data: Sequence[Union[DictConfig, dict]],
transforms: Union[Sequence[Callable], Callable],
cache_dir: Optional[Union[Path, str]] = None,
**dataloader_kwargs
**dataloader_kwargs,
):
"""Create a dataloader based on a dictionary of paths to images.
Expand Down
1 change: 0 additions & 1 deletion cyto_dl/datamodules/dataframe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ def _sampler_generator(self):

def __iter__(self) -> Iterator[List[int]]:
for sampler_ix in self.sampler_order:

try:
yield [next(self.sampler_iterators[sampler_ix]) for _ in range(self.batch_size)]
except StopIteration:
Expand Down
2 changes: 1 addition & 1 deletion cyto_dl/datamodules/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def make_folder_dataloader(
endswith: Optional[str] = None,
contains: Optional[str] = None,
excludes: Optional[str] = None,
**dataloader_kwargs
**dataloader_kwargs,
):
"""Create a dataloader based on a folder of samples. If no transforms are applied, each sample
is a dictionary with a key "input" containing the corresponding path and a key "orig_fname"
Expand Down
3 changes: 2 additions & 1 deletion cyto_dl/datamodules/smartcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ def setup(self, stage=None):
self.img_data["train"] = self.get_per_file_args(self.df_train)
self.img_data["val"] = self.get_per_file_args(self.df_val)
pd.DataFrame(self.img_data["train"]).to_csv(
f"{self.csv_path.parents[0]}/loaded_data/train_img_data.csv", index=False
f"{self.csv_path.parents[0]}/loaded_data/train_img_data.csv",
index=False,
)
pd.DataFrame(self.img_data["val"]).to_csv(
f"{self.csv_path.parents[0]}/loaded_data/val_img_data.csv", index=False
Expand Down
1 change: 0 additions & 1 deletion cyto_dl/image/transforms/multiscale_cropper.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ def generate_slices(self, image_dict: Dict) -> Dict:
}

def __call__(self, image_dict):

available_keys = self.keys
if self.allow_missing_keys:
available_keys = [k for k in self.keys if k in image_dict]
Expand Down
5 changes: 4 additions & 1 deletion cyto_dl/image/transforms/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ class MaxProjectd(Transform):
"""Monai-style transform to take max projection of an image."""

def __init__(
self, keys: Union[list, str], projection_dim: int = 1, allow_missing_keys: bool = False
self,
keys: Union[list, str],
projection_dim: int = 1,
allow_missing_keys: bool = False,
):
"""
Parameters
Expand Down
1 change: 1 addition & 0 deletions cyto_dl/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(
metrics=_DEFAULT_METRICS,
):
super().__init__()

self.metrics = tuple(metrics.keys())

for key, value in metrics.items():
Expand Down
9 changes: 7 additions & 2 deletions cyto_dl/models/im2im/utils/omnipose.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ class OmniposePreprocessd(Transform):
smooth distance images from an input instance segmentation."""

def __init__(
self, label_keys: Union[Sequence[str], str], dim: int = 3, allow_missing_keys: bool = False
self,
label_keys: Union[Sequence[str], str],
dim: int = 3,
allow_missing_keys: bool = False,
):
"""
Parameters
Expand Down Expand Up @@ -376,7 +379,9 @@ def get_separated_masks(self, flow_crop, mask_crop, dist_crop, device, crop):
return {
"slice": tuple(crop[1:]),
"mask": remove_small_holes(
mask_crop > 0, area_threshold=self.hole_size, connectivity=self.spatial_dim
mask_crop > 0,
area_threshold=self.hole_size,
connectivity=self.spatial_dim,
),
}

Expand Down
7 changes: 4 additions & 3 deletions cyto_dl/models/im2im/utils/skoots.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,9 +376,10 @@ def kd_clustering(self, embed_z, embed_y, embed_x, skel):
skel = find_boundaries(skel, mode="inner") * skel # propagate labels
skel_points = np.stack(skel.nonzero()).T
embed_points = torch.stack((embed_z, embed_y, embed_x)).numpy()
dist_to_closest_skel, closest_skel_point_to_embedding = self._get_point_embeddings(
embed_points.T, skel_points
)
(
dist_to_closest_skel,
closest_skel_point_to_embedding,
) = self._get_point_embeddings(embed_points.T, skel_points)
embedding_labels = skel[
closest_skel_point_to_embedding[0],
closest_skel_point_to_embedding[1],
Expand Down
90 changes: 51 additions & 39 deletions cyto_dl/models/vae/base_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def __init__(
reconstruction_loss: Loss = nn.MSELoss(reduction="none"),
prior: Optional[Sequence[Prior]] = None,
decoder_latent_parts: Optional[Dict[str, Sequence[str]]] = None,
disable_metrics: Optional[bool] = False,
metric_keys: Optional[list] = None,
**base_kwargs,
):
"""Instantiate a basic VAE model.
Expand All @@ -50,45 +52,55 @@ def __init__(
**base_kwargs:
Additional arguments passed to BaseModel
"""

_DEFAULT_METRICS = {
"train/loss": MeanMetric(),
"val/loss": MeanMetric(),
"test/loss": MeanMetric(),
"train/loss/total_reconstruction": MeanMetric(),
"val/loss/total_reconstruction": MeanMetric(),
"test/loss/total_reconstruction": MeanMetric(),
"train/loss/total_kld": MeanMetric(),
"val/loss/total_kld": MeanMetric(),
"test/loss/total_kld": MeanMetric(),
}

if not isinstance(prior, (dict, DictConfig)):
prior = {"embedding": prior}

for part in prior.keys():
_DEFAULT_METRICS.update(
{
f"train/loss/kld_{part}": MeanMetric(),
f"val/loss/kld_{part}": MeanMetric(),
f"test/loss/kld_{part}": MeanMetric(),
}
)

if not isinstance(reconstruction_loss, (dict, DictConfig)):
assert x_label is not None
recon_parts = [x_label]
else:
recon_parts = reconstruction_loss.keys()

for part in recon_parts:
_DEFAULT_METRICS.update(
{
f"train/loss/reconstruction_{part}": MeanMetric(),
f"val/loss/reconstruction_{part}": MeanMetric(),
f"test/loss/reconstruction_{part}": MeanMetric(),
}
)
if not isinstance(prior, (dict, DictConfig)):
prior = {"embedding": prior}
if disable_metrics:
_DEFAULT_METRICS = {
"train/loss": MeanMetric(),
"val/loss": MeanMetric(),
"test/loss": MeanMetric(),
}
elif metric_keys:
_DEFAULT_METRICS = {}
for key in metric_keys:
_DEFAULT_METRICS.update({key: MeanMetric()})

else:
_DEFAULT_METRICS = {
"train/loss": MeanMetric(),
"val/loss": MeanMetric(),
"test/loss": MeanMetric(),
"train/loss/total_reconstruction": MeanMetric(),
"val/loss/total_reconstruction": MeanMetric(),
"test/loss/total_reconstruction": MeanMetric(),
"train/loss/total_kld": MeanMetric(),
"val/loss/total_kld": MeanMetric(),
"test/loss/total_kld": MeanMetric(),
}

for part in prior.keys():
_DEFAULT_METRICS.update(
{
f"train/loss/kld_{part}": MeanMetric(),
f"val/loss/kld_{part}": MeanMetric(),
f"test/loss/kld_{part}": MeanMetric(),
}
)

for part in recon_parts:
_DEFAULT_METRICS.update(
{
f"train/loss/reconstruction_{part}": MeanMetric(),
f"val/loss/reconstruction_{part}": MeanMetric(),
f"test/loss/reconstruction_{part}": MeanMetric(),
}
)

metrics = base_kwargs.pop("metrics", _DEFAULT_METRICS)

Expand All @@ -100,11 +112,11 @@ def __init__(
prior[key] = IsotropicGaussianPrior(dimensionality=latent_dim)
else:
prior[key] = IdentityPrior(dimensionality=latent_dim)
elif not isinstance(prior[key], Prior):
raise ValueError(
f"Expected prior to either be one of ('gaussian', 'identity', None)"
f"or an object of type `Prior`. Got: {type(prior)}"
)
# elif not isinstance(prior[key], Prior):
# raise ValueError(
# f"Expected prior to either be one of ('gaussian', 'identity', None)"
# f"or an object of type `Prior`. Got: {type(prior)}"
# )
self.prior = nn.ModuleDict(prior)

self.reconstruction_loss = reconstruction_loss
Expand Down Expand Up @@ -217,7 +229,7 @@ def decode(self, z):
# for each decoder key, get the latent parts it uses from `self.decoder_latent_keys`
# and pass them as *args to that decoder's forward method
return {
part: decoder(*[z[key] for key in self.decoder_latent_keys[part]])
part: decoder(*[z[key] for key in self.decoder.keys()])
for part, decoder in self.decoder.items()
}

Expand Down
6 changes: 5 additions & 1 deletion cyto_dl/models/vae/image_canon_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,11 @@ def __init__(
self.rotation_module = None

super().__init__(
encoder=encoder, decoder=decoder, latent_dim=latent_dim, prior=prior, **base_kwargs
encoder=encoder,
decoder=decoder,
latent_dim=latent_dim,
prior=prior,
**base_kwargs,
)

self.make_canon_net(maximum_frequency)
Expand Down
13 changes: 10 additions & 3 deletions cyto_dl/models/vae/image_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
first_conv_padding_mode: str = "replicate",
encoder_padding: Optional[Union[int, Sequence[int]]] = None,
eps: float = 1e-8,
**base_kwargs
**base_kwargs,
):
in_channels, *in_shape = in_shape

Expand Down Expand Up @@ -104,6 +104,9 @@ def __init__(
else:
self.mask = None

if encoder_padding is None:
encoder_padding = [None] * len(kernel_sizes)

for k, s, p in zip(kernel_sizes, strides, encoder_padding):
padding = same_padding(k) if p is None else p
self.final_size = calculate_out_shape(self.final_size, k, s, padding)
Expand Down Expand Up @@ -169,7 +172,7 @@ def __init__(
*decode_blocks,
# decoder,
last_act if last_act is not None else nn.Identity(),
_Scale(last_scale)
_Scale(last_scale),
)

if isinstance(prior, (str, type(None))):
Expand Down Expand Up @@ -200,7 +203,11 @@ def __init__(
self.rotation_module = None

super().__init__(
encoder=encoder, decoder=decoder, latent_dim=latent_dim, prior=prior, **base_kwargs
encoder=encoder,
decoder=decoder,
latent_dim=latent_dim,
prior=prior,
**base_kwargs,
)

def encode(self, batch):
Expand Down
1 change: 0 additions & 1 deletion cyto_dl/models/vae/o2_spharm_vae/o2_spharm_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def __init__(
max_hidden_band=8,
grid_size=64,
):

super().__init__()
self.out_dim = out_dim
self.reflections = reflections
Expand Down
1 change: 0 additions & 1 deletion cyto_dl/models/vae/o2_spharm_vae/o2_spharm_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def __init__(
)

def encode(self, batch):

# reorder x's columns to match the encoder
x = batch[self.hparams.x_label][:, self.flat_indices]

Expand Down
Loading

0 comments on commit e1d1a77

Please sign in to comment.