Skip to content

Commit

Permalink
Added ome-tiff inference pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
JLrumberger committed Apr 4, 2024
1 parent 178dccd commit a53a8b1
Show file tree
Hide file tree
Showing 3 changed files with 290 additions and 27 deletions.
27 changes: 15 additions & 12 deletions src/nimbus_inference/nimbus.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
from alpineer import io_utils, misc_utils
from skimage.util.shape import view_as_windows
import nimbus_inference
from nimbus_inference.utils import (
prepare_normalization_dict,
predict_fovs,
predict_ome_fovs,
nimbus_preprocess,
from nimbus_inference.utils import (prepare_normalization_dict, prepare_normalization_dict_ome,
predict_fovs, predict_ome_fovs, nimbus_preprocess,
)
from huggingface_hub import hf_hub_download
from nimbus_inference.unet import UNet
Expand Down Expand Up @@ -161,14 +158,14 @@ def initialize_model(self, padding="reflect"):
self.checkpoint_path = os.path.join(
path,
"assets",
"resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt"
"resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32_cw_0.8.pt"
)
if not os.path.exists(self.checkpoint_path):
local_dir = os.path.join(path, "assets")
print("Downloading weights from Hugging Face Hub...")
self.checkpoint_path = hf_hub_download(
repo_id="JLrumberger/Nimbus-Inference",
filename="resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt",
filename="resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32_cw_0.8.pt",
local_dir=local_dir,
local_dir_use_symlinks=False,
)
Expand All @@ -192,11 +189,17 @@ def prepare_normalization_dict(
if os.path.exists(self.normalization_dict_path) and not overwrite:
self.normalization_dict = json.load(open(self.normalization_dict_path))
else:

n_jobs = os.cpu_count() if multiprocessing else 1
self.normalization_dict = prepare_normalization_dict(
self.fov_paths, self.output_dir, quantile, self.include_channels, n_subset, n_jobs
)
if self.suffix.lower() in [".ome.tif", ".ome.tiff"]:
self.normalization_dict = prepare_normalization_dict_ome(
self.fov_paths, self.output_dir, quantile, self.include_channels, n_subset,
n_jobs
)
else:
self.normalization_dict = prepare_normalization_dict(
self.fov_paths, self.output_dir, quantile, self.include_channels, n_subset,
n_jobs
)
if self.include_channels == []:
self.include_channels = list(self.normalization_dict.keys())

Expand All @@ -223,7 +226,7 @@ def predict_fovs(self):
half_resolution=self.half_resolution, batch_size=self.batch_size,
test_time_augmentation=self.test_time_aug, suffix=self.suffix,
)
elif self.suffix.lower() in [".tiff", ".tif", ".jpg", ".jpeg", ".png"]:
else:
self.cell_table = predict_fovs(
nimbus=self, fov_paths=self.fov_paths, output_dir=self.output_dir,
normalization_dict=self.normalization_dict,
Expand Down
160 changes: 156 additions & 4 deletions src/nimbus_inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from joblib.externals.loky import get_reusable_executor
from skimage.segmentation import find_boundaries
from skimage.measure import regionprops_table
from pyometiff import OMETIFFReader


def calculate_normalization(channel_path, quantile):
Expand Down Expand Up @@ -196,6 +197,8 @@ def predict_fovs(
os.path.normpath(output_dir), os.path.basename(fov_path)
)
df_fov = pd.DataFrame()
instance_path = segmentation_naming_convention(fov_path)
instance_mask = np.squeeze(io.imread(instance_path))
for channel in tqdm(os.listdir(fov_path)):
channel_path = os.path.join(fov_path, channel)
channel_ = channel.split(".")[0]
Expand All @@ -204,8 +207,6 @@ def predict_fovs(
):
continue
mplex_img = np.squeeze(io.imread(channel_path))
instance_path = segmentation_naming_convention(fov_path)
instance_mask = np.squeeze(io.imread(instance_path))
input_data = prepare_input_data(mplex_img, instance_mask)
if half_resolution:
scale = 0.5
Expand Down Expand Up @@ -275,5 +276,156 @@ def nimbus_preprocess(image, **kwargs):
return output


def predict_ome_fovs():
pass
def calculate_normalization_ome(ome_path, quantile, include_channels):
"""Calculates the normalization values for a given ome file
Args:
ome_path (str): path to ome file
quantile (float): quantile to use for normalization
include_channels (list): list of channels to include
Returns:
normalization_values (dict): dict with channel names as keys and norm factors as values
"""
reader = OMETIFFReader(fpath=ome_path)
img_array, metadata, _ = reader.read()
channel_names = list(metadata["Channels"].keys())
if not include_channels:
include_channels = channel_names
# check if include_channels are included in ome file metadata
for channel in include_channels:
if channel not in channel_names:
raise ValueError(f"Channel {channel} not found in ome file metadata.")
normalization_values = {}
for channel in include_channels:
idx = channel_names.index(channel)
mplex_img = img_array[idx]
mplex_img = mplex_img.astype(np.float32)
foreground = mplex_img[mplex_img > 0]
normalization_values[channel] = np.quantile(foreground, quantile)
return normalization_values


def prepare_normalization_dict_ome(
fov_paths, output_dir, quantile=0.999, include_channels=[], n_subset=10, n_jobs=1,
output_name="normalization_dict.json"
):
"""Prepares the normalization dict for a list of ome.tif fovs
Args:
fov_paths (list): list of paths to fovs
output_dir (str): path to output directory
quantile (float): quantile to use for normalization
exclude_channels (list): list of channels to exclude
n_subset (int): number of fovs to use for normalization
n_jobs (int): number of jobs to use for joblib multiprocessing
output_name (str): name of output file
Returns:
normalization_dict (dict): dict with channel names as keys and norm factors as values
"""
normalization_dict = {}
if n_subset is not None:
random.shuffle(fov_paths)
fov_paths = fov_paths[:n_subset]
print("Iterate over fovs...")
if n_jobs > 1:
normalization_values = Parallel(n_jobs=n_jobs)(
delayed(calculate_normalization_ome)(ome_path, quantile, include_channels)
for ome_path in fov_paths
)
else:
normalization_values = [
calculate_normalization_ome(ome_path, quantile, include_channels)
for ome_path in fov_paths
]
for norm_dict in normalization_values:
for channel, normalization_value in norm_dict.items():
if channel not in normalization_dict:
normalization_dict[channel] = []
normalization_dict[channel].append(normalization_value)
if n_jobs > 1:
get_reusable_executor().shutdown(wait=True)
for channel in normalization_dict.keys():
normalization_dict[channel] = np.mean(normalization_dict[channel])
# save normalization dict
with open(os.path.join(output_dir, output_name), 'w') as f:
json.dump(normalization_dict, f)
return normalization_dict


def predict_ome_fovs(
nimbus, fov_paths, normalization_dict, segmentation_naming_convention, output_dir,
suffix, include_channels=[], save_predictions=True, half_resolution=False, batch_size=4,
test_time_augmentation=True
):
"""Predicts the segmentation map for each mplex channel in each ome.tif fov
Args:
nimbus (Nimbus): nimbus object
fov_paths (list): list of fov paths
normalization_dict (dict): dict with channel names as keys and norm factors as values
segmentation_naming_convention (function): function to get instance mask path from fov path
output_dir (str): path to output dir
suffix (str): suffix of mplex images
include_channels (list): list of channels to include
save_predictions (bool): whether to save predictions
half_resolution (bool): whether to use half resolution
batch_size (int): batch size
test_time_augmentation (bool): whether to use test time augmentation
Returns:
cell_table (pd.DataFrame): cell table with predicted confidence scores per fov and cell
"""
fov_dict_list = []
for fov_path in fov_paths:
print(f"Predicting {fov_path}...")
out_fov_path = os.path.join(
os.path.normpath(output_dir), os.path.basename(fov_path).split(".")[0]
)
df_fov = pd.DataFrame()
reader = OMETIFFReader(fpath=fov_path)
img_array, metadata, _ = reader.read()
channel_names = list(metadata["Channels"].keys())
instance_path = segmentation_naming_convention(fov_path)
instance_mask = np.squeeze(io.imread(instance_path))
if not include_channels:
include_channels = channel_names
for channel in tqdm(include_channels):
idx = channel_names.index(channel)
mplex_img = np.squeeze(img_array[idx])
input_data = prepare_input_data(mplex_img, instance_mask)
if half_resolution:
scale = 0.5
input_data = np.squeeze(input_data)
_, h,w = input_data.shape
img = cv2.resize(input_data[0], [int(h*scale), int(w*scale)])
binary_mask = cv2.resize(
input_data[1], [int(h*scale), int(w*scale)], interpolation=0
)
input_data = np.stack([img, binary_mask], axis=0)[np.newaxis,...]
if test_time_augmentation:
prediction = test_time_aug(
input_data, channel, nimbus, normalization_dict, batch_size=batch_size
)

else:
prediction = nimbus.predict_segmentation(
input_data,
preprocess_kwargs={
"normalize": True, "marker": channel,
"normalization_dict": normalization_dict
},
)
prediction = np.squeeze(prediction)
if half_resolution:
prediction = cv2.resize(prediction, (h, w))
df = pd.DataFrame(segment_mean(instance_mask, prediction))
if df_fov.empty:
df_fov["label"] = df["label"]
df_fov["fov"] = os.path.basename(fov_path)
df_fov[channel] = df["intensity_mean"]
if save_predictions:
os.makedirs(out_fov_path, exist_ok=True)
pred_int = (prediction*255.0).astype(np.uint8)
io.imwrite(
os.path.join(out_fov_path, channel+".tiff"), pred_int, photometric="minisblack",
# compress=0,
)
fov_dict_list.append(df_fov)
cell_table = pd.concat(fov_dict_list, ignore_index=True)
return cell_table
Loading

0 comments on commit a53a8b1

Please sign in to comment.