Add SemanticSamTrainer #637

Merged: 10 commits (Jun 21, 2024)
1 change: 1 addition & 0 deletions micro_sam/training/__init__.py
@@ -5,4 +5,5 @@
from .util import ConvertToSamInputs, get_trainable_sam_model, identity
from .joint_sam_trainer import JointSamTrainer, JointSamLogger
from .simple_sam_trainer import SimpleSamTrainer, MedSAMTrainer
from .semantic_sam_trainer import SemanticSamTrainer
from .training import train_sam, train_sam_for_configuration, default_sam_loader, default_sam_dataset, CONFIGURATIONS
Empty file.
119 changes: 119 additions & 0 deletions micro_sam/training/models/build_sam.py
@@ -0,0 +1,119 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

#
# NOTE: This code has been adapted from Segment Anything.
# - https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/build_sam.py
# This is done in order to expose some of the model's hard-coded input parameters for:
# - downstream applications (e.g. updating the "num_multimask_outputs" for multi-class semantic segmentation)
#


import torch

from functools import partial

from segment_anything.modeling import Sam, ImageEncoderViT, PromptEncoder, MaskDecoder, TwoWayTransformer


def build_sam_vit_h(checkpoint=None, num_multimask_outputs=3):
return _build_sam(
encoder_embed_dim=1280,
encoder_depth=32,
encoder_num_heads=16,
encoder_global_attn_indexes=[7, 15, 23, 31],
checkpoint=checkpoint,
num_multimask_outputs=num_multimask_outputs,
)


build_sam = build_sam_vit_h


def build_sam_vit_l(checkpoint=None, num_multimask_outputs=3):
return _build_sam(
encoder_embed_dim=1024,
encoder_depth=24,
encoder_num_heads=16,
encoder_global_attn_indexes=[5, 11, 17, 23],
checkpoint=checkpoint,
num_multimask_outputs=num_multimask_outputs,
)


def build_sam_vit_b(checkpoint=None, num_multimask_outputs=3):
return _build_sam(
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
encoder_global_attn_indexes=[2, 5, 8, 11],
checkpoint=checkpoint,
num_multimask_outputs=num_multimask_outputs,
)


sam_model_registry = {
"default": build_sam_vit_h,
"vit_h": build_sam_vit_h,
"vit_l": build_sam_vit_l,
"vit_b": build_sam_vit_b,
}


def _build_sam(
encoder_embed_dim,
encoder_depth,
encoder_num_heads,
encoder_global_attn_indexes,
checkpoint=None,
num_multimask_outputs=3,
):
prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
image_embedding_size = image_size // vit_patch_size
sam = Sam(
image_encoder=ImageEncoderViT(
depth=encoder_depth,
embed_dim=encoder_embed_dim,
img_size=image_size,
mlp_ratio=4,
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
num_heads=encoder_num_heads,
patch_size=vit_patch_size,
qkv_bias=True,
use_rel_pos=True,
global_attn_indexes=encoder_global_attn_indexes,
window_size=14,
out_chans=prompt_embed_dim,
),
prompt_encoder=PromptEncoder(
embed_dim=prompt_embed_dim,
image_embedding_size=(image_embedding_size, image_embedding_size),
input_image_size=(image_size, image_size),
mask_in_chans=16,
),
mask_decoder=MaskDecoder(
num_multimask_outputs=num_multimask_outputs,
transformer=TwoWayTransformer(
depth=2,
embedding_dim=prompt_embed_dim,
mlp_dim=2048,
num_heads=8,
),
transformer_dim=prompt_embed_dim,
iou_head_depth=3,
iou_head_hidden_dim=256,
),
pixel_mean=[123.675, 116.28, 103.53],
pixel_std=[58.395, 57.12, 57.375],
)
sam.eval()
if checkpoint is not None:
with open(checkpoint, "rb") as f:
state_dict = torch.load(f)
sam.load_state_dict(state_dict)
return sam
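The point of exposing these parameters is the mask decoder. A minimal, hypothetical sketch (not part of this PR) of building a ViT-B SAM whose decoder predicts one output per semantic class could look like this:

```python
# Minimal sketch, assuming the class count below; not part of this PR.
from micro_sam.training.models.build_sam import sam_model_registry

num_classes = 4  # e.g. background + 3 foreground classes (illustrative value)
sam = sam_model_registry["vit_b"](checkpoint=None, num_multimask_outputs=num_classes)
# The mask decoder is now configured for `num_classes` multimask outputs instead of
# the default 3, so pretrained checkpoints only load with flexible handling (see below).
```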
92 changes: 92 additions & 0 deletions micro_sam/training/semantic_sam_trainer.py
@@ -0,0 +1,92 @@
import time

import torch
from torch.nn import CrossEntropyLoss

from torch_em.trainer import DefaultTrainer


class SemanticSamTrainer(DefaultTrainer):
"""
"""
def __init__(
self,
convert_inputs,
num_classes: int = 1,
**kwargs
):
super().__init__(**kwargs)
self.convert_inputs = convert_inputs
self.num_classes = num_classes
self.compute_ce_loss = CrossEntropyLoss()
self._kwargs = kwargs

def _compute_loss(self, y, gt_logits, masks, mask_logits):
# Compute dice loss for the predictions
dice_loss = self.loss(masks, y.to(self.device, non_blocking=True))

# Compute cross entropy loss for the logits
ce_loss = self.compute_ce_loss(mask_logits, gt_logits.to(self.device, non_blocking=True))

net_loss = dice_loss + ce_loss
return net_loss

def _get_model_outputs(self, batched_inputs):
image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)
batched_outputs = self.model(batched_inputs, image_embeddings, multimask_output=(self.num_classes > 1))
masks = torch.stack([output["masks"].squeeze(0) for output in batched_outputs])
mask_logits = torch.stack([output["low_res_masks"].squeeze(0) for output in batched_outputs])
return masks, mask_logits

def _train_epoch_impl(self, progress, forward_context, backprop):
self.model.train()

t_per_iter = time.time()
for x, y in self.train_loader:
self.optimizer.zero_grad()

batched_inputs, gt_logits = self.convert_inputs(x, y)

with forward_context():
masks, mask_logits = self._get_model_outputs(batched_inputs)
net_loss = self._compute_loss(y, gt_logits, masks, mask_logits)

backprop(net_loss)

if self.logger is not None:
lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
self.logger.log_train(self._iteration, net_loss, lr, x, y, masks, log_gradients=True)

self._iteration += 1
if self._iteration >= self.max_iteration:
break
progress.update(1)

t_per_iter = (time.time() - t_per_iter)
return t_per_iter

def _validate_impl(self, forward_context):
self.model.eval()

metric_val, loss_val = 0.0, 0.0

with torch.no_grad():
for x, y in self.val_loader:
batched_inputs, gt_logits = self.convert_inputs(x, y)

with forward_context():
masks, mask_logits = self._get_model_outputs(batched_inputs)
net_loss = self._compute_loss(y, gt_logits, masks, mask_logits)

loss_val += net_loss.item()
metric_val += net_loss.item()

loss_val /= len(self.val_loader)
metric_val /= len(self.val_loader)
print()
print(f"The Average Validation Metric Score for the Current Epoch is {1 - metric_val}")

if self.logger is not None:
self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, masks)

return metric_val
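As a hedged sketch of how the new trainer could be wired up: everything besides `convert_inputs` and `num_classes` is forwarded to torch_em's `DefaultTrainer`, so the remaining keyword arguments and the data loaders below are assumptions based on that base class, not prescribed by this PR.

```python
# Hedged sketch, not part of this PR. `train_loader` / `val_loader` are assumed to be
# torch DataLoaders yielding (image, label) batches, defined elsewhere.
import torch
from torch_em.loss import DiceLoss

from micro_sam.training import SemanticSamTrainer, get_trainable_sam_model
from micro_sam.training.util import ConvertToSemanticSamInputs

num_classes = 4
model = get_trainable_sam_model(
    model_type="vit_b",
    num_multimask_outputs=num_classes,  # forwarded as model_kwargs to the adapted build_sam
    flexible_load_checkpoint=True,      # the decoder no longer matches the pretrained checkpoint
)

trainer = SemanticSamTrainer(
    name="semantic-sam-vitb",
    model=model,
    convert_inputs=ConvertToSemanticSamInputs(),
    num_classes=num_classes,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=torch.optim.AdamW(model.parameters(), lr=1e-5),
    loss=DiceLoss(),    # used as `self.loss` for the dice term in _compute_loss
    metric=DiceLoss(),
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    logger=None,        # skip tensorboard logging in this sketch
)
trainer.fit(iterations=int(1e4))
```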
39 changes: 38 additions & 1 deletion micro_sam/training/util.py
@@ -5,6 +5,8 @@
import numpy as np
import torch

from skimage.transform import resize

from segment_anything.utils.transforms import ResizeLongestSide

from ..prompt_generators import PointAndBoxPromptGenerator
@@ -45,6 +47,8 @@ def get_trainable_sam_model(
return_state: bool = False,
use_lora: bool = False,
rank: Optional[int] = None,
flexible_load_checkpoint: bool = False,
**model_kwargs
) -> TrainableSAM:
"""Get the trainable sam model.

@@ -59,14 +63,23 @@
return_state: Whether to return the full checkpoint state.
use_lora: Whether to use the low rank adaptation method for finetuning.
rank: The rank of the decomposition matrices for updating weights in each attention layer.
flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.

Returns:
The trainable segment anything model.
"""
# set the device here so that the correct one is passed to TrainableSAM below
device = get_device(device)
_, sam, state = get_sam_model(
model_type=model_type, device=device, checkpoint_path=checkpoint_path, return_sam=True, return_state=True
model_type=model_type,
device=device,
checkpoint_path=checkpoint_path,
return_sam=True,
return_state=True,
use_lora=use_lora,
rank=rank,
flexible_load_checkpoint=flexible_load_checkpoint,
**model_kwargs
)

# freeze components of the model if freeze was passed
@@ -85,6 +98,7 @@
if name.startswith(f"{freeze}"):
param.requires_grad = False

# Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything
if use_lora: # overwrites the SAM model by freezing the backbone and allowing low rank adaptation of the attention layers
if rank is None:
rank = 4 # HACK: in case the user does not pass the rank, we fall back to a default rank of 4
@@ -210,6 +224,29 @@ def __call__(self, x, y, n_pos, n_neg, get_boxes=False, n_samples=None):
return batched_inputs, batched_sampled_cell_ids_list


class ConvertToSemanticSamInputs:
"""
"""
def __call__(self, x, y):
"""Convert the outputs of dataloader to the batched format of inputs expected by SAM.
"""
batched_inputs, gt_logits = [], []
for image, gt in zip(x, y):
batched_input = {"image": image, "original_size": image.shape[1:]}
batched_inputs.append(batched_input)

# downsize the labels to match the spatial shape of the low-res mask logits (256 x 256)
gt_shape = (gt.shape[0], 256, 256)
gt_logits.append(
resize(image=gt, output_shape=gt_shape, preserve_range=True, order=0, anti_aliasing=False)
)

gt_logits = np.stack(gt_logits)
gt_logits = torch.from_numpy(gt_logits)

return batched_inputs, gt_logits
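For clarity, a small sketch of what this conversion produces; the input shapes are illustrative assumptions, not fixed by this PR.

```python
# Illustrative sketch: the images stay a list of per-sample dicts, while the labels
# are resized to the 256 x 256 resolution of SAM's low-res mask logits.
import torch

from micro_sam.training.util import ConvertToSemanticSamInputs

convert_inputs = ConvertToSemanticSamInputs()
x = torch.rand(2, 3, 512, 512)                     # assumed image batch (B, C, H, W)
y = torch.randint(0, 4, (2, 1, 512, 512)).float()  # assumed label batch (B, 1, H, W)

batched_inputs, gt_logits = convert_inputs(x, y)
# batched_inputs: list of {"image": ..., "original_size": (512, 512)} per sample
# gt_logits: tensor of shape (2, 1, 256, 256), matching the low-res mask logits
```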


#
# Raw and Label Transformations for the Generalist and Specialist finetuning
#
53 changes: 51 additions & 2 deletions micro_sam/util.py
@@ -272,6 +272,8 @@ def get_sam_model(
return_state: bool = False,
use_lora: bool = False,
rank: Optional[int] = None,
flexible_load_checkpoint: bool = False,
**model_kwargs,
) -> SamPredictor:
r"""Get the SegmentAnything Predictor.

@@ -306,6 +308,7 @@
return_state: Return the unpickled checkpoint state.
use_lora: Whether to use the low rank adaptation method for finetuning.
rank: The rank of the decomposition matrices for updating weights in each attention layer.
flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.

Returns:
The segment anything predictor.
@@ -350,15 +353,29 @@
)

state, model_state = _load_checkpoint(checkpoint_path)
sam = sam_model_registry[abbreviated_model_type]()

# Update the parameters used to initialize the model, if any were provided.
if bool(model_kwargs): # checks whether any model kwargs have been provided
if abbreviated_model_type == "vit_t":
raise ValueError("'micro-sam' does not allow changing the model parameters for 'mobile-sam'.")

from micro_sam.training.models.build_sam import sam_model_registry # noqa

sam = sam_model_registry[abbreviated_model_type](**model_kwargs)

# Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything
if use_lora: # overwrites the SAM model by freezing the backbone and allowing low rank adaptation of the attention layers
from micro_sam.training.peft_sam import PEFT_Sam
if rank is None:
rank = 4 # HACK: in case the user does not pass the rank, we fall back to a default rank of 4
sam = PEFT_Sam(sam, rank=rank).sam

sam.load_state_dict(model_state)
# Handle the case where the checkpoint does not fully match the model, e.g. because it was initialized with non-default parameters.
if flexible_load_checkpoint:
sam = _handle_checkpoint_loading(sam, model_state)
else:
sam.load_state_dict(model_state)

sam.to(device=device)

predictor = SamPredictor(sam)
@@ -379,6 +396,38 @@
return predictor


def _handle_checkpoint_loading(sam, model_state):
# Handle mismatches between the checkpoint and the model in a more flexible way.
# e.g. while training for multi-class semantic segmentation, parameters in the mask decoder
# are updated - leading to "size mismatch" errors when loading pretrained checkpoints.

new_state_dict = {} # for loading matching parameters
mismatched_layers = [] # for tracking mismatching parameters

reference_state = sam.state_dict()

for k, v in model_state.items():
if reference_state[k].size() == v.size():
new_state_dict[k] = v
else:
mismatched_layers.append(k)

reference_state.update(new_state_dict)

if len(mismatched_layers) > 0:
print(f"The layers with size mismatch: {mismatched_layers}")

for mlayer in mismatched_layers:
if 'weight' in mlayer:
torch.nn.init.kaiming_uniform_(reference_state[mlayer])
elif 'bias' in mlayer:
reference_state[mlayer].zero_()

sam.load_state_dict(reference_state)

return sam
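A hedged sketch of how this is intended to be used through `get_sam_model`; the checkpoint path is a placeholder. With a non-default number of decoder outputs the mask decoder weights no longer match the checkpoint, so the helper above re-initializes just those layers.

```python
# Hedged sketch, not part of this PR; the checkpoint path is a placeholder.
from micro_sam.util import get_sam_model

predictor = get_sam_model(
    model_type="vit_b",
    checkpoint_path="/path/to/sam_vit_b_checkpoint.pt",  # placeholder
    num_multimask_outputs=4,        # forwarded to the adapted build_sam registry
    flexible_load_checkpoint=True,  # mismatched decoder layers are re-initialized
)
```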


def export_custom_sam_model(
checkpoint_path: Union[str, os.PathLike],
model_type: str,