Cl fig s6 #38

Open · wants to merge 8 commits into base: main

14 changes: 14 additions & 0 deletions .gitignore
@@ -1,3 +1,17 @@
# stuff I added
example_scripts/demo_output/
example_scripts/sample_data/
example_scripts/simple_example.ipynb
data/
results/
science_paper_repo/
data
results
slurm_logs
inference/results
inference/viz
tmp*

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
25 changes: 25 additions & 0 deletions README.md
@@ -1,3 +1,28 @@
# about
This repo contains the inference code for the image-based orphan-protein analysis in the paper [Global organelle profiling reveals subcellular localization and remodeling at proteome scale](https://www.biorxiv.org/content/10.1101/2023.12.18.572249v1). The analysis is by [James Burgess](https://jmhb0.github.io/) and [Chad Liu](https://www.linkedin.com/in/chad-liu-14a77749/).

It is a copy of the [cytoself repo](https://github.com/royerlab/cytoself) for training a Cytoself representation-learning model. We modify the training slightly (see the next section) and add inference code that compares representations of orphan proteins to representations from [OpenCell](https://opencell.czbiohub.org/). In this README, everything after the section called `cytoself` is from the original project's README.


# changes from the original cytoself (the training part)
- changed the torch and torchvision pins in `requirements.txt`. Previously it was `torch>=1.11`, but I was getting an error in the torch `Upsample` layer (I think it was a torch 2.0 issue), so I pinned `torch==1.13.1` and `torchvision==0.12`.
- added scripts to train cytoself from scratch in `scripts/`
- added the 10 `.npy` files and 10 `.csv` files from `https://github.com/royerlab/cytoself/tree/main` and put them in `data/opencell_crops/` (which is not committed to the repo).
- in the run scripts, I save the `train_args` and `model_args` so that the model can be reconstructed more easily in the inference scripts (see the sketch after this list).
- in `cytoself/datamanager/opencell.py` I add a step to the end of data loading that saves the test dataset (its indices and its crops) to separate files. This is useful for analysis afterwards: you can load the embeddings together with these labels or images. It is also necessary for doing inference with a new dataset. The files are written to `data/test_dataset_metadata/`.
- TODO: include the `data/cz_infectedcell_finalwellmapping.csv` in the eventual repo.
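
Below is a minimal sketch of how the saved `train_args`/`model_args` could be reloaded to rebuild the model at inference time. The file names, the `torch.save`/`torch.load` format, and the `CytoselfModel` constructor name are assumptions for illustration; the run scripts in this repo define the actual format.

```python
# Minimal sketch (assumed file names and format) of reusing the saved args at inference time.
from pathlib import Path
import torch

run_dir = Path("results/20240129_train_all")         # hypothetical training-run directory
model_args = torch.load(run_dir / "model_args.pt")   # kwargs used to construct the model (assumed name)
train_args = torch.load(run_dir / "train_args.pt")   # training settings, kept for bookkeeping (assumed name)

# Rebuild the model with the same hyperparameters it was trained with, then load weights.
# `CytoselfModel` stands in for whichever cytoself model class the run actually used.
# model = CytoselfModel(**model_args)
# model.load_state_dict(torch.load(run_dir / "checkpoint.pt"))
print("model_args keys:", sorted(model_args))
```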

# inference pipeline
The new inference code is in `inference/`.
- `python inference/load_inf_data.py` saves the z-stacks as max-intensity projections and does some FOV-level normalization. Output goes to `results/load_inf_data`.
- `python inference/nuclear_segmentation.py` saves nuclear masks to `inference/nuclear_segmentation/all_masks.pt`. One caveat: if you run it on only a subset of the images, the `.pt` file is overwritten with just that subset (the other steps do not have this problem).
- `python inference/crop.py` takes a crop around each segmented nucleus in the FOV. A `VERSION` parameter controls how normalization is done. If `VERSION==0` (recommended), do `[0,1]` normalization within each crop, treating the nucleus and target channels independently. If `VERSION==1`, normalize the whole FOV before cropping. IMO this is worse: a few very bright cells (e.g. due to mitosis) can make the rest of the FOV very dim after normalization, whereas crop-level normalization only affects the abnormally bright crops. Results are saved in `inference/results/crop/` as `crops_v0` for `VERSION==0` or `crops_v1` for `VERSION==1`. The crop metadata is saved to `crops_meta.csv`, which has the FOV filename, the centroid coordinates (in FOV space) of the nucleus centered in each crop, and some other fields.
- `python inference/get_crop_features.py` loads the pretrained cytoself model. For a model name like `20240129_train_all`, features are saved to `inference/results/get_crop_features/results/20240129_train_all/ckpt_None/`. The pretrained models need to be defined at the bottom of the script. There is also an option to compute features for rotated versions of the crops, which should give better robustness overall according to [this paper](https://www.nature.com/articles/s41467-024-45362-4). (You have to make sure `VERSION` matches what was used in the cropping; sorry, this should have been handled automatically.)
- `python inference/compare_opencell_targets.py` gets the 'consensus embedding' for each protein by averaging the crop representations of that protein. It does this for both OpenCell and orphan proteins, and makes a distance matrix over all proteins (a minimal sketch of this step is shown below).
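
The sketch below illustrates that last step (consensus embeddings and the distance matrix), mirroring what `analysis/image_processing/20231010_validate_embeddings.py` does with the OpenCell test set. The input names `embed` and `names`, and the split into OpenCell vs. orphan arrays, are illustrative assumptions rather than the exact `compare_opencell_targets.py` code.

```python
# Minimal sketch: average per-crop embeddings into one consensus vector per protein,
# then compare proteins with an L2 distance matrix.
#   embed : torch.Tensor of shape (n_crops, d), one cytoself embedding per crop (assumed input)
#   names : np.ndarray of shape (n_crops,), protein label of each crop (assumed input)
import numpy as np
import torch

def consensus_embeddings(embed: torch.Tensor, names: np.ndarray):
    prots = sorted(set(names))
    idxs_per_prot = [np.where(names == p)[0] for p in prots]
    consensus = torch.stack([embed[idxs].mean(0) for idxs in idxs_per_prot])
    return prots, consensus  # (n_proteins,), (n_proteins, d)

# With OpenCell and orphan crops featurized by the same model (hypothetical inputs):
# prots_oc, z_oc = consensus_embeddings(embed_opencell, names_opencell)
# prots_orph, z_orph = consensus_embeddings(embed_orphan, names_orphan)
# dist = torch.cdist(z_orph, z_oc)   # (n_orphan, n_opencell) distance matrix
```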




# cytoself

[![Python 3.9](https://img.shields.io/badge/python-3.9-blue.svg)](https://www.python.org/downloads/release/python-397/)
103 changes: 103 additions & 0 deletions analysis/image_processing/20231010_generate_crops.py
@@ -0,0 +1,103 @@
from pathlib import Path
import ipdb
import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
import os

current_filename = Path(os.path.basename(__file__))
results_dir = Path("analysis/results") / current_filename.stem
results_dir.mkdir(exist_ok=True, parents=True)  # figures below are saved here

fname_mask = Path('analysis/results/20231013_segment_nuclei/all_segmasks.pt')
masks, fnames_nuc = torch.load(fname_mask)
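# masks: list of integer label maps from cellpose (0 = background, 1..n = nuclei);
# fnames_nuc: nucleus image paths saved alongside the masks, aligned index-by-index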

width, height = 200, 200
crops_all = []
df_meta_all = []

def norm_0_1(x):
l, u = x.min(), x.max()
if u == l:
raise ValueError("constant image pixel values")
return (x-l) / (u-l)

for i, mask in enumerate(masks):
crops = []
df_meta = pd.DataFrame(columns=["fname_pro", "fname_nuc", "cell_centroid", "mask_idx"])

n_instances = len(np.unique(mask))-1
f_nuc = fnames_nuc[i]
f_pro = str(f_nuc)[:-7] + "pro.png"
img_nuc = np.array(Image.open(f_nuc))
img_pro = np.array(Image.open(f_pro))

for j in range(1, n_instances+1):  # mask labels run 1..n_instances (0 is background)
target_indices = np.argwhere(mask == j)
centroid = target_indices.mean(axis=0).astype(int)

# do image crop, normalize each channel independently.
# and check that it's not hitting the image border
# ipdb.set_trace()
slc_0 = slice(centroid[0] - height//2, centroid[0] + height//2)
slc_1 = slice(centroid[1] - width//2, centroid[1] + width//2)
img_pro_crop = img_pro[slc_0, slc_1]
img_nuc_crop = img_nuc[slc_0, slc_1]
if img_pro_crop.shape != (height, width):
continue  # skip the crop because it hits the image border
img_crop = np.stack((
norm_0_1(img_pro_crop),
norm_0_1(img_nuc_crop),
))

# add the crop and metadata
crops.append(img_crop)
row_meta = pd.DataFrame(dict(
fname_pro=[f_pro],
fname_nuc=[f_nuc],
cell_centroid=[centroid],
mask_idx=[j],
))
df_meta = pd.concat([df_meta, row_meta])


# optionally visualize crops in a grid
if True:
if len(crops) < 1:
continue
from torchvision.utils import make_grid
crops_arr = torch.from_numpy(np.stack(crops)) # (N,2,H,W)

grid_nuc = make_grid(crops_arr[:,[1]], nrow=8, pad_value=0.5)
f, axs = plt.subplots()
axs.imshow(grid_nuc.permute(1,2,0))
f.savefig(results_dir / (f"{f_nuc.stem}" + "nucleus_only.png"))
plt.close()

grid_pro = make_grid(crops_arr[:,[0]], nrow=8, pad_value=0.5)
f, axs = plt.subplots()
axs.imshow(grid_pro.permute(1,2,0))
f.savefig(results_dir / (f"{f_nuc.stem}" + "protein_only.png"))
plt.close()


# # for visualization, put protein in green, the nucleus in red for RGBs
# imgs_rgb_tmp = torch.from_numpy(np.stack(crops)).clone()
# imgs_rgb = torch.zeros((len(imgs_rgb_tmp), 3, *imgs_rgb_tmp.shape[2:]))
# imgs_rgb[:,1] = imgs_rgb_tmp[:,0]
# imgs_rgb[:,2] = imgs_rgb_tmp[:,1]

# grid = make_grid(imgs_rgb, nrow=5)
# f, axs = plt.subplots()
# axs.imshow(grid.permute(1,2,0))
# fname_save = results_dir / f""
# f.savefig()
# # gri
crops_all.extend(crops)
df_meta_all.append(df_meta)



ipdb.set_trace()
pass
97 changes: 97 additions & 0 deletions analysis/image_processing/20231010_lookat_imageset48.py
@@ -0,0 +1,97 @@
# should be run in the project dir for cytoself
# for developing the inference pipeline, let's process the images

import ipdb
from PIL import Image
from aicsimageio import AICSImage
from pathlib import Path
import matplotlib.pyplot as plt
import os
import pandas as pd
import numpy as np
from skimage.io import imread, imsave
from skimage import color

def norm_0_1(x):
l, u = x.min(), x.max()
if u == l:
raise ValueError("constant image pixel values")
return (x-l) / (u-l)

# make the results_dir
current_filename = Path(os.path.basename(__file__))
results_dir = Path("analysis/results") / current_filename.stem
results_dir.mkdir(exist_ok=True, parents=True)

results_maxproj_set48 = Path("analysis/results/20231010_maxproj_set48")
results_maxproj_set48.mkdir(exist_ok=True, parents=True)

DATA_DIR = Path("/hpc/instruments/leonetti.dragonfly/infected-cell-microscopy/TICM048-1/raw_data")
fnames = [f for f in DATA_DIR.glob("*.tif")]
print(f"Num files {len(fnames)}")

# get the orphan proteins in `df_exp`; the ones with good signal have `is_green == 1`
fname_exp = Path("data/orphan_protein.csv")
df_exp = pd.read_csv(fname_exp)
df_exp_green = df_exp[df_exp['is_green']==1]

# construct a dataframe for the image metadata
df_imgs_meta = pd.DataFrame(columns=["fname", "well_id", "fov_id", "is_green"])
for fname in fnames:

# metadata structure "raw_data_MMStack_592-B8-16.ome.tif". Well=B8, FOV=16
well_id, fov_id = fname.stem.split(".")[0].split("-")[-2:]

idxs_well_id = np.where(well_id == df_exp['well_id'].values)[0]
if len(idxs_well_id)==0:
print(f"did not find meta for well_id={well_id}")
elif len(idxs_well_id)>1:
print(f"More than one match for well_id={well_id}")
else:
assert len(idxs_well_id) == 1
row = pd.DataFrame(dict(
fname=[fname],
well_id=[well_id],
fov_id=[fov_id],
is_green=[df_exp.iloc[idxs_well_id[0]]['is_green']],
))
df_imgs_meta = pd.concat([df_imgs_meta, row], ignore_index=True)

df_imgs_meta_green = df_imgs_meta[df_imgs_meta['is_green']==1]
print(f"Number of zstacks that are 'green' {len(df_imgs_meta_green)}")
print(f"Number of unique wells that are 'green' {len(df_imgs_meta_green.well_id.unique())}")

# now get the zstacks for the green images
for idx, row in df_imgs_meta_green.iterrows():
fname = row.fname
aics_img = AICSImage(fname)
# ipdb.set_trace()
x = aics_img.data # (1, 2, Z, H, W)
assert x.ndim == 5 and x.shape[:2] == (1,2)

# max intensity projection for the nucleus and protein channel independently
x_pro_map = x[0,1].max(0)
x_nuc_map = x[0,0].max(0)

x_pro_map = norm_0_1(x_pro_map)
x_nuc_map = norm_0_1(x_nuc_map)

# save as 1-channel png
fname_nuc = results_maxproj_set48 / f"max_proj_{row.well_id}_{row.fov_id}_nuc.png"
fname_pro = results_maxproj_set48 / f"max_proj_{row.well_id}_{row.fov_id}_pro.png"
Image.fromarray((x_nuc_map*255).astype(np.uint8), mode="L").save(fname_nuc)
Image.fromarray((x_pro_map*255).astype(np.uint8), mode="L").save(fname_pro)

# viewing
f, axs = plt.subplots(1,2)
axs[0].imshow(x_nuc_map, cmap='gray')
axs[1].imshow(x_pro_map, cmap='gray')
axs[0].set(title="nucleus")
axs[1].set(title="protein")
fname_save = results_dir / (fname.stem + "_max_projs.png")
f.savefig(fname_save, dpi=200)
plt.close()

ipdb.set_trace()
pass
56 changes: 56 additions & 0 deletions analysis/image_processing/20231010_validate_embeddings.py
@@ -0,0 +1,56 @@
# run in the home directory as python analysis/20231010_validate_embeddings.py
import ipdb
import os
from pathlib import Path
import numpy as np
import torch
import pandas as pd

# load embeddings from a results dir
RESULTS_DIR = Path("results/20231009_train_all_no_nucdist")
embed_images = torch.from_numpy(np.load(RESULTS_DIR / "embeddings/vqvec2.npy" ))
embed_images = embed_images.view((len(embed_images), -1)) # flatten the array

# load the labels for the test set that were saved by the datamanager
DIR_TEST_DATASET_META = Path("data/test_dataset_metadata")
labels = np.load(DIR_TEST_DATASET_META / "test_dataset_labels.npy", allow_pickle=True)
df = pd.DataFrame(labels, columns=["ensg", "name", "loc_grade"])

# check that embeddings and labels are at least the same dimension
assert len(embed_images) == len(labels)
prots = sorted(df.name.unique())
assert len(prots) == 1311

# consensus embedding as a mean
embed_consensus = []
img_counts = []
localization_labels = []

for prot in prots:
idxs = np.where(df.name==prot)[0]
z = embed_images[idxs]
embed_consensus.append(z.mean(0))
img_counts.append(len(z))

localization_label = df[df.name==prot].loc_grade.unique()
localization_labels.append(localization_label[0])

embed_consensus = torch.stack(embed_consensus) # (n_prot, 1024)
img_counts = torch.tensor(img_counts)
dist = torch.cdist(embed_consensus, embed_consensus)
argsortdist = torch.argsort(dist, dim=1, descending=False)
assert torch.all(argsortdist[:,0] == torch.arange(len(argsortdist))) # nearest neighbor should be itself
argsortdist = argsortdist[:,1:]

ipdb.set_trace()
pass
# let's do the most basic possible test based on L2 distance, which, remember, is not the method actually proposed by Kobayashi et al.
for i in range(20):
print(localization_labels[i], end= " ")
print(localization_labels[argsortdist[i,0].item()], "\t\t\t", localization_labels[argsortdist[i,1].item()])




if __name__=="__main__":
pass
58 changes: 58 additions & 0 deletions analysis/image_processing/20231013_segment_nuclei.py
@@ -0,0 +1,58 @@
# options: choose a nucleus diameter (in pixels) for cellpose, or None to let cellpose estimate it
# diameter = 120
diameter = None
from pathlib import Path
from PIL import Image
import numpy as np
import ipdb
import torch
import os

def norm_0_1(x):
l, u = x.min(), x.max()
if u == l:
raise ValueError("constant image pixel values")
return (x-l) / (u-l)

current_filename = Path(os.path.basename(__file__))
results_dir = Path("analysis/results") / current_filename.stem
results_dir.mkdir(exist_ok=True, parents=True)
results_dir_cellpose = results_dir / "cellpose_segs"
results_dir_cellpose.mkdir(exist_ok=True, parents=True)

results_dir_tmp = Path("analysis/tmp")
results_dir_tmp.mkdir(exist_ok=True)

# baseline code from the Colab https://colab.research.google.com/github/MouseLand/cellpose/blob/main/notebooks/run_cellpose_GPU.ipynb

### our data
data_maxproj_set48 = Path("analysis/results/20231010_maxproj_set48")
fnames = [l for l in data_maxproj_set48.iterdir()]
fnames_pro = sorted([l for l in fnames if "_pro.png" in str(l)])
fnames_nuc = sorted([l for l in fnames if "_nuc.png" in str(l)])


# imgs_pro = [np.array(Image.open(f)) for f in fnames_pro]
imgs_nuc = [np.array(Image.open(f)) for f in fnames_nuc]

from cellpose import models
model = models.Cellpose(gpu=True, model_type='cyto')
print("Running cellpose")
masks, flows, styles, diams = model.eval(imgs_nuc, diameter=diameter,
flow_threshold=None, channels=None)
fname_mask = results_dir / f"all_segmasks.pt"
torch.save([masks, fnames_nuc], fname_mask)
ipdb.set_trace()

import matplotlib.pyplot as plt
# save the mask and also visualize them
for i in range(len(masks)):
f, axs = plt.subplots(1,2)
axs[0].imshow(imgs_nuc[i],cmap='gray')
axs[1].imshow(masks[i], cmap='gray')
f.savefig(results_dir / f"seg_sample_{fnames[i].stem}.png", dpi=200)
plt.close()

plt.close()
