Skip to content

Commit

Permalink
Fixed small bugs in the training pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
JLrumberger committed Nov 20, 2024
1 parent d7b5c8b commit 02d937b
Show file tree
Hide file tree
Showing 8 changed files with 501 additions and 190 deletions.
57 changes: 8 additions & 49 deletions src/nimbus_inference/nimbus.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from skimage.util.shape import view_as_windows
import nimbus_inference
from nimbus_inference.utils import (prepare_normalization_dict,
predict_fovs, nimbus_preprocess, MultiplexDataset
predict_fovs, MultiplexDataset
)
from huggingface_hub import hf_hub_download, list_repo_files
import re
Expand All @@ -18,47 +18,6 @@
import re


def nimbus_preprocess(image, **kwargs):
    """Preprocess input data for Nimbus model.

    Works on a copy of ``image``; the array passed in is never modified.

    Args:
        image (np.array): 4D array to be processed. Only index 0 of the second axis
            (the marker channel) is normalized and clipped; the binary mask in
            index 1 is passed through unchanged.
        **kwargs: keyword arguments for preprocessing:
            {normalize (bool): whether to normalize the image (default True),
            marker (str): name of marker, optionally ending in an image file extension,
            normalization_dict (dict): normalization dictionary (marker -> norm factor),
            clip_values (tuple): min/max values to clip the image to after normalization}
    Returns:
        np.array: processed image array. Non-float input is returned as float32 when
            normalization is applied, because the division cannot be done in place
            on integer data.
    Raises:
        ValueError: if ``image`` is not 4D.
    """
    # Validate before copying so a malformed input does not cost a full array copy.
    if len(image.shape) != 4:
        raise ValueError("Image data must be 4D, got image of shape {}".format(image.shape))
    output = np.copy(image)

    if kwargs.get("normalize", True):
        marker = kwargs.get("marker", None)
        # Only treat the name as a file name when it really ends in an image extension.
        # The pattern is escaped and anchored so that e.g. "Stiff.v2" is left alone, and
        # a missing marker (None) no longer crashes the regex search.
        if marker is not None and re.search(
            r"\.(tiff?|png|jpe?g)$", marker, re.IGNORECASE
        ):
            # NOTE(review): split(".")[0] also truncates dotted marker names such as
            # "HLA.DR.tiff" -> "HLA"; kept as-is so keys stay compatible with however
            # normalization_dict was built -- confirm against the caller.
            marker = marker.split(".")[0]
        normalization_dict = kwargs.get("normalization_dict") or {}
        if marker in normalization_dict:
            norm_factor = normalization_dict[marker]
        else:
            print(
                "Norm_factor not found for marker {}, calculating directly from the image. "
                .format(marker)
            )
            norm_factor = np.quantile(output[:, 0, ...], 0.999)
        # In-place true division is not defined for integer arrays, so promote first.
        if not np.issubdtype(output.dtype, np.floating):
            output = output.astype(np.float32)
        # A zero factor (e.g. a blank image) would fill the channel with NaN/inf;
        # leave the channel unscaled in that case.
        if norm_factor != 0:
            # normalize only marker channel in chan 0 not binary mask in chan 1
            output[:, 0, ...] /= norm_factor
    clip_values = kwargs.get("clip_values", False)
    if clip_values:
        output[:, 0, ...] = np.clip(output[:, 0, ...], clip_values[0], clip_values[1])
    return output


def prep_naming_convention(deepcell_output_dir):
"""Prepares the naming convention for the segmentation data produced with the DeepCell library.
Expand Down Expand Up @@ -128,6 +87,7 @@ def __init__(
else:
misc_utils.verify_in_list(device=[device], valid_devices=["cpu", "cuda", "mps"])
self.device = torch.device(device)
self.load_checkpoint(padding="reflect")

def check_inputs(self):
"""check inputs for Nimbus model"""
Expand All @@ -149,8 +109,8 @@ def list_checkpoints(self, padding="reflect"):
path = Path(path).resolve()
local_dir = os.path.join(path, "assets")
os.makedirs(local_dir, exist_ok=True)
version_pattern = re.compile(r'V(\d+)\.pt')
local_checkpoints = [f for f in os.listdir(local_dir) if version_pattern.search(f)]
pattern = re.compile(r'.*\.pt$')
local_checkpoints = [f for f in os.listdir(local_dir) if pattern.search(f)]
return local_checkpoints

def load_local_checkpoint(self, checkpoint, padding="reflect"):
Expand All @@ -164,8 +124,8 @@ def load_local_checkpoint(self, checkpoint, padding="reflect"):
path = Path(path).resolve()
local_dir = os.path.join(path, "assets")
os.makedirs(local_dir, exist_ok=True)
version_pattern = re.compile(r'V(\d+)\.pt')
local_checkpoints = [f for f in os.listdir(local_dir) if version_pattern.search(f)]
pattern = re.compile(r'.*\.pt$')
local_checkpoints = [f for f in os.listdir(local_dir) if pattern.search(f)]
if checkpoint not in local_checkpoints:
raise ValueError(
f"Checkpoint {checkpoint} not found in local checkpoints {local_checkpoints}"
Expand Down Expand Up @@ -275,17 +235,16 @@ def predict_fovs(self):
self.cell_table.to_csv(os.path.join(self.output_dir, "nimbus_cell_table.csv"), index=False)
return self.cell_table

def predict_segmentation(self, input_data, preprocess_kwargs):
def predict_segmentation(self, input_data):
"""Predicts segmentation for input data.
Args:
input_data (np.array): Input data to predict segmentation for.
input_data (np.array): Normalized and clipped input data to predict segmentation for.
preprocess_kwargs (dict): Keyword arguments for preprocessing.
batch_size (int): Batch size for prediction.
Returns:
np.array: Predicted segmentation.
"""
input_data = nimbus_preprocess(input_data, **preprocess_kwargs)
if np.all(np.greater_equal(self.input_shape, input_data.shape[-2:])):
if not hasattr(self, "model") or self.model.padding != "reflect":
self.load_checkpoint(padding="reflect")
Expand Down
105 changes: 56 additions & 49 deletions src/nimbus_inference/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,34 @@ def __call__(self, input_tensor, mask_tensor, label_tensor):


class SmoothBinaryCELoss(torch.nn.Module):
"""Smooth binary cross entropy loss with label smoothing.
Args:
label_smoothing (float): Label smoothing factor.
"""
def __init__(self, label_smoothing=0.05):
super(SmoothBinaryCELoss, self).__init__()
if not 0 <= label_smoothing < 1:
raise ValueError("Label smoothing must be in [0,1)")
self.label_smoothing = label_smoothing
self.eps = 1e-12 # For numerical stability

def forward(self, inputs, targets):
def forward(self, inputs, targets):
""" Compute binary cross entropy loss with label smoothing.
Args:
inputs (torch.Tensor): Model predictions.
targets (torch.Tensor): Target labels 0: negative, 1: positive, 2: ignore.
Returns:
torch.Tensor: Binary cross entropy loss.
"""
# Clamp for numerical stability
inputs = torch.clamp(inputs, self.eps, 1.0 - self.eps)
# mask of pixels labeled 2 (ignore), i.e. targets that are neither 0 nor 1
mask = torch.clip(targets, 0, 2) == 2
# Apply label smoothing
targets = torch.clip(targets, 0, 1)
targets = targets * (1 - self.label_smoothing) + 0.5 * self.label_smoothing
# Binary cross entropy
loss = torch.nn.functional.binary_cross_entropy(inputs, targets, reduction='none')
Expand Down Expand Up @@ -119,7 +134,7 @@ def __init__(
self.validation_dataset = validation_dataset
self.batch_size = batch_size
self.model = nimbus.model
self.device = self.model.device
self.device = nimbus.device
self.initial_regularization = initial_regularization
if self.initial_regularization:
self.initial_checkpoint = deepcopy(self.model)
Expand Down Expand Up @@ -159,7 +174,8 @@ def run_validation(self):
self.model.eval()
loss_ = []
df_list = []
for inputs, labels, inst_mask, key in self.validation_loader:
print("Running validation...")
for inputs, labels, inst_mask, key in tqdm(self.validation_loader):
inputs = inputs.to(self.device)
labels = labels.to(self.device)
with torch.no_grad():
Expand All @@ -169,6 +185,8 @@ def run_validation(self):
# calculate mean per instance prediction
outputs = outputs.cpu().numpy()
inst_mask = inst_mask.cpu().numpy()
binary_mask = inputs[:,1,...].unsqueeze(1).cpu().numpy()
inst_mask = inst_mask * binary_mask
labels = labels.cpu().numpy()
# split batch to get individual samples
for i in range(outputs.shape[0]):
Expand All @@ -186,37 +204,32 @@ def run_validation(self):
merged_df = pd.merge(gt_df, pred_df, on=["fov", "channel", "label"], how="inner")
df_list.append(merged_df)
df = pd.concat(df_list)
metrics = {}
# Calculate metrics per channel
for channel in df['channel'].unique():
channel_df = df[df['channel'] == channel]

# Convert to binary predictions using 0.5 threshold
y_true = (channel_df['gt'] > 0.5).astype(int)
y_pred = (channel_df['pred'] > 0.5).astype(int)

# Calculate basic counts
tp = ((y_true == 1) & (y_pred == 1)).sum()
fp = ((y_true == 0) & (y_pred == 1)).sum()
tn = ((y_true == 0) & (y_pred == 0)).sum()
fn = ((y_true == 1) & (y_pred == 0)).sum()

# Calculate metrics
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

metrics[channel] = {
'precision': precision,
'recall': recall,
'specificity': specificity,
'f1': f1
}
df["gt"] = df["gt"].round(0).astype(int)
metrics = {}
# mask out ambiguous cells
df = df[df["gt"] != 2]
# Calculate metrics
y_true = (df['gt'] > 0.5).astype(int)
y_pred = (df['pred'] > 0.5).astype(int)
tp = ((y_true == 1) & (y_pred == 1)).sum()
fp = ((y_true == 0) & (y_pred == 1)).sum()
tn = ((y_true == 0) & (y_pred == 0)).sum()
fn = ((y_true == 1) & (y_pred == 0)).sum()

precision = tp / (tp + fp) if (tp + fp) > 0 else 0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

metrics = {
'precision': precision,
'recall': recall,
'specificity': specificity,
'mean_f1': f1
}

# Add mean loss
metrics['loss'] = np.mean(loss_)
metrics['mean_f1'] = np.mean([m['f1'] for m in metrics.values() if isinstance(m, dict)])
return metrics

def initial_checkpoint_regularizer(self):
Expand All @@ -233,48 +246,42 @@ def initial_checkpoint_regularizer(self):

def save_checkpoint(self):
"""Save the model checkpoint to a file."""
# Set up paths
# Set up paths
print("Saving checkpoint...")
path = os.path.dirname(nimbus_inference.__file__)
path = Path(path).resolve()
local_dir = os.path.join(path, "assets")
os.makedirs(local_dir, exist_ok=True)
self.checkpoint_path = os.path.join(local_dir, self.checkpoint_name + ".pt")
torch.save({
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict(),
'best_f1': self.best_f1,
'history': self.history
}, self.checkpoint_path)
self.checkpoint_path = os.path.join(local_dir, self.checkpoint_name)
torch.save(self.model.state_dict(), self.checkpoint_path)

def train(self, epochs: int):
"""Train the model for specified number of epochs.
Args:
epochs (int): Number of epochs to train the model.
"""
print(f"Found device: {self.device}")
val_metrics = self.run_validation()
self._print_epoch_summary(epoch=0, train_losses=[], val_metrics=val_metrics)

for epoch in range(epochs):
# Training phase
self.model.train()
train_losses = []

for inputs, labels, inst_mask, key in tqdm(self.train_loader):
self.optimizer.zero_grad()
inputs, inst_mask, labels = self.augmenter(inputs, inst_mask, labels)
inputs = inputs.to(self.device)
labels = labels.to(self.device)
inputs, inst_mask, labels = self.augmenter(inputs, inst_mask, labels)
self.optimizer.zero_grad()
inst_mask = inst_mask.to(self.device)
outputs = self.model(inputs)
loss = self.loss_function(outputs, labels)
loss = loss.mean()
if self.initial_regularization:
loss += self.initial_checkpoint_regularizer()
loss.backward()

# Gradient clipping
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.gradient_clip
)

self.optimizer.step()
train_losses.append(loss.item())

Expand Down
Loading

0 comments on commit 02d937b

Please sign in to comment.