Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixed the stem and forward pass #115

Merged
merged 7 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,74 @@
import torch
from viscy.representation.contrastive import ContrastiveEncoder
import torchview
import timm

# uncomment if you are using jupyter and want to autoreload the updated code.
# %load_ext autoreload
# %autoreload 2

%load_ext autoreload
%autoreload 2
# %% Initialize the model and log the graph.
contra_model = ContrastiveEncoder(backbone = "convnext_tiny") # other options: convnext_tiny resnet50
print(contra_model)
# %% Explore model graphs returned by timm

convnextv1 = timm.create_model(
"convnext_tiny", pretrained=False, features_only=False, num_classes=200
)
print(convnextv1)
output = convnextv1(torch.randn(1, 3, 256, 256))
print(output.shape)
# %% Initialize the model and log the graph: convnext.
in_channels = 1
in_stack_depth = 15

contrastive_convnext1 = ContrastiveEncoder(
backbone="convnext_tiny", in_channels=in_channels, in_stack_depth=in_stack_depth
)
print(contrastive_convnext1)


projections, embedding = contrastive_convnext1(
torch.randn(1, in_channels, in_stack_depth, 256, 256)
)
print(
f"shape of projections:{projections.shape}, shape of embedding: {embedding.shape}"
)
# %%

in_channels = 3
in_stack_depth = 18

contrastive_convnext2 = ContrastiveEncoder(
backbone="convnextv2_tiny", in_channels=in_channels, in_stack_depth=in_stack_depth
)
print(contrastive_convnext2)
embedding, projections = contrastive_convnext2(
torch.randn(1, in_channels, in_stack_depth, 256, 256)
)
print(
f"shape of projections:{projections.shape}, shape of embedding: {embedding.shape}"
)

# %%
in_channels = 10
in_stack_depth = 12
contrastive_resnet = ContrastiveEncoder(
backbone="resnet50",
in_channels=in_channels,
in_stack_depth=in_stack_depth,
embedding_len=256,
)
print(contrastive_resnet)
embedding, projections = contrastive_resnet(
torch.randn(1, in_channels, in_stack_depth, 256, 256)
)
print(
f"shape of projections:{projections.shape}, shape of embedding: {embedding.shape}"
)

# %%
plot_model = contrastive_resnet
model_graph = torchview.draw_graph(
contra_model,
torch.randn(1, 2, 15, 224, 224),
plot_model,
input_size=(20, in_channels, in_stack_depth, 224, 224),
depth=3, # adjust depth to zoom in.
device="cpu",
)
Expand All @@ -21,7 +79,9 @@
model_graph.visual_graph

# %% Initialize a resnet50 model and log the graph.
contra_model = ContrastiveEncoder(backbone = "resnet50", in_stack_depth = 16, stem_kernel_size = (4, 3, 3)) # note that the resnet first layer takes 64 channels (so we can't have multiples of 3)
contra_model = ContrastiveEncoder(
backbone="resnet50", in_stack_depth=16, stem_kernel_size=(4, 3, 3)
) # note that the resnet first layer takes 64 channels (so we can't have multiples of 3)
print(contra_model)
model_graph = torchview.draw_graph(
contra_model,
Expand All @@ -41,15 +101,14 @@
# %%
model_graph = torchview.draw_graph(
contrastive_module.encoder,
torch.randn(1, 2, 15, 200, 200),
torch.randn(1, in_channels, in_stack_depth, 200, 200),
depth=3, # adjust depth to zoom in.
device="cpu",
)
# Print the image of the model.
model_graph.visual_graph

# %% Playground
import timm

available_models = timm.list_models(pretrained=True)

Expand Down
39 changes: 30 additions & 9 deletions viscy/applications/contrastive_phenotyping/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from lightning.pytorch.strategies import DDPStrategy
from viscy.data.hcs import ContrastiveDataModule
from viscy.light.engine import ContrastiveModule
import os
import os


def main(hparams):
# Set paths
Expand Down Expand Up @@ -35,7 +36,7 @@ def main(hparams):
batch_size=batch_size,
z_range=z_range,
predict_base_path=predict_base_path,
analysis=True, # for self-supervised results
analysis=True, # for self-supervised results
)

data_module.setup(stage="predict")
Expand All @@ -54,7 +55,7 @@ def main(hparams):

# Run prediction
predictions = trainer.predict(model, datamodule=data_module)

# Collect features and projections
features_list = []
projections_list = []
Expand All @@ -66,13 +67,33 @@ def main(hparams):
all_features = np.concatenate(features_list, axis=0)
all_projections = np.concatenate(projections_list, axis=0)

# for saving visualizations embeddings
# for saving visualizations embeddings
base_dir = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/5-finaltrack/test_visualizations"
features_path = os.path.join(base_dir, 'B', '4', '2', 'before_projected_embeddings', 'test_epoch88_predicted_features.npy')
projections_path = os.path.join(base_dir, 'B', '4', '2', 'projected_embeddings', 'test_epoch88_predicted_projections.npy')
features_path = os.path.join(
base_dir,
"B",
"4",
"2",
"before_projected_embeddings",
"test_epoch88_predicted_features.npy",
)
projections_path = os.path.join(
base_dir,
"B",
"4",
"2",
"projected_embeddings",
"test_epoch88_predicted_projections.npy",
)

np.save("/hpc/mydata/alishba.imran/VisCy/viscy/applications/contrastive_phenotyping/ss1_epoch97_predicted_features.npy", all_features)
np.save("/hpc/mydata/alishba.imran/VisCy/viscy/applications/contrastive_phenotyping/ss1_epoch97_predicted_projections.npy", all_projections)
np.save(
"/hpc/mydata/alishba.imran/VisCy/viscy/applications/contrastive_phenotyping/ss1_epoch97_predicted_features.npy",
all_features,
)
np.save(
"/hpc/mydata/alishba.imran/VisCy/viscy/applications/contrastive_phenotyping/ss1_epoch97_predicted_projections.npy",
all_projections,
)


if __name__ == "__main__":
Expand All @@ -89,4 +110,4 @@ def main(hparams):
parser.add_argument("--num_nodes", type=int, default=2)
parser.add_argument("--log_every_n_steps", type=int, default=1)
args = parser.parse_args()
main(args)
main(args)
Loading
Loading