VAL_CENTER Indice in camelyon17_dataset.py #165

Open
David-Drexlin opened this issue Nov 19, 2024 · 0 comments

Hi everyone,

I hope this is the right place to ask a question about the Camelyon17 dataset. It concerns the center metadata indices for TEST_CENTER and VAL_CENTER as defined in camelyon17_dataset.py. According to that file, the centers are 0-indexed, with TEST_CENTER = 2 and VAL_CENTER = 1. My understanding is that these should correspond to the images shown in columns 5 and 4 of the figure in the paper (see the first image for reference). Is that correct?
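
To make my assumption concrete, here is a minimal sketch of how I understand the split assignment: every patch whose `center` equals VAL_CENTER or TEST_CENTER is moved to the corresponding OOD split. The two constants are the values quoted above; the split codes, the helper name `assign_ood_splits`, and the toy DataFrame are illustrative assumptions rather than code copied from the library.

```python
import pandas as pd

# Values as quoted above from camelyon17_dataset.py (0-indexed centers).
TEST_CENTER = 2
VAL_CENTER = 1

# Illustrative split codes (assumption); the library's actual codes may differ.
SPLIT_CODES = {"train": 0, "id_val": 1, "test": 2, "val": 3}


def assign_ood_splits(metadata: pd.DataFrame) -> pd.DataFrame:
    """Return a copy whose 'split' column is overridden by center index."""
    out = metadata.copy()
    out.loc[out["center"] == VAL_CENTER, "split"] = SPLIT_CODES["val"]
    out.loc[out["center"] == TEST_CENTER, "split"] = SPLIT_CODES["test"]
    return out


# Toy metadata: one patch per center, all initially marked as train.
toy = pd.DataFrame({"center": [0, 1, 2, 3, 4], "split": [0, 0, 0, 0, 0]})
assigned = assign_ood_splits(toy)
print(assigned)
```

Under this reading, patches from center 4 would never land in the val split, so it should not look like the validation (OOD) column of the paper figure.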

When I naively plot the images by center label, one center per row (see the second image), I would expect rows 2 and 1 (zero-indexed) to show the test and validation (OOD) slides, respectively. Instead, the validation center index appears to be switched: the test images do correspond to index 2, but the validation (OOD) images appear under index 4 rather than 1. Inspecting the images directly in the data/patches directory shows the same thing. For example, patient 96 from center 4 looks like Val (OOD), while patient 34 from center 1 looks like part of train, at least visually to a layman.
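
The patient-to-center lookup behind those two examples can be sketched as follows. The column names `patient` and `center` are the ones used in my plotting code; the helper `centers_of_patient`, the toy rows, and the zero-padded patient IDs are assumptions for illustration, and the real table would come from metadata.csv.

```python
import pandas as pd


def centers_of_patient(metadata: pd.DataFrame, patient: str) -> list:
    """Return the sorted list of center indices found for one patient ID."""
    centers = metadata.loc[metadata["patient"] == patient, "center"].unique()
    return sorted(int(c) for c in centers)


# Toy rows standing in for metadata.csv (one row per patch in the real file).
# Patient IDs are kept as strings, matching dtype={'patient': 'str'}.
toy = pd.DataFrame(
    {
        "patient": ["034", "034", "096", "096"],
        "center": [1, 1, 4, 4],
    }
)

for pid in ("034", "096"):
    print(pid, centers_of_patient(toy, pid))
```

A patient that maps to more than one center, or to none, would point to a problem in the metadata rather than in the plotting.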

Did I misunderstand something about the indexing, or do you have any idea what could be wrong? Below are the images for reference and the code I used to generate them:

WILDS slides: [image: camelyon_dataset]

Slides as per my code: [image: slides_per_domain_class]

Thanks in advance for any clarification!

Code:

import os
import pandas as pd
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from collections import defaultdict

# constants
DATA_DIR = '/data/camelyon17_v1.0'
PATCHES_DIR = os.path.join(DATA_DIR, 'patches')
METADATA_CSV = os.path.join(DATA_DIR, 'metadata.csv')
MAX_IMAGES_PER_COMBINATION = 5
NUM_DOMAINS = 5  
NUM_CLASSES = 2  

# Load the metadata
metadata_df = pd.read_csv(
    METADATA_CSV,
    index_col=0,
    dtype={'patient': 'str'}
)

# Get labels
y_array = torch.LongTensor(metadata_df['tumor'].values)

# Get input image paths
input_paths = [
    os.path.join(
        PATCHES_DIR,
        f'patient_{patient}_node_{node}',
        f'patch_patient_{patient}_node_{node}_x_{x}_y_{y}.png'
    )
    for patient, node, x, y in metadata_df[['patient', 'node', 'x_coord', 'y_coord']].values
]

# Get domains (centers)
centers = metadata_df['center'].astype(int).values

# Organize images into a dictionary keyed by (domain, class)
images_dict = defaultdict(list)

for img_path, label, domain in zip(input_paths, y_array, centers):
    key = (domain, label.item())
    if len(images_dict[key]) < MAX_IMAGES_PER_COMBINATION:
        try:
            img = Image.open(img_path).convert('RGB')
            images_dict[key].append(img)
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")

# plot 
fig, axes = plt.subplots(nrows=NUM_DOMAINS, ncols=NUM_CLASSES * MAX_IMAGES_PER_COMBINATION, figsize=(24, 12))
plt.subplots_adjust(wspace=0.05, hspace=0.05)

for domain_idx in range(NUM_DOMAINS):
    for class_idx in range(NUM_CLASSES):
        key = (domain_idx, class_idx)
        images = images_dict.get(key, [])
        for img_idx in range(MAX_IMAGES_PER_COMBINATION):
            col_idx = class_idx * MAX_IMAGES_PER_COMBINATION + img_idx
            ax = axes[domain_idx, col_idx]
            if img_idx < len(images):
                ax.imshow(images[img_idx])
            ax.axis('off')

            if domain_idx == 0 and img_idx == 0:
                ax.set_title(f"Class {class_idx}")

        # Add domain labels to the first image in each row
        if class_idx == 0:
            ax = axes[domain_idx, 0]
            ax.text(-30, 32, f"Domain {domain_idx}", rotation=90, va='center', fontsize=12)
            #ax.text(-150, images[0].size[1] // 2, f"Domain {domain_idx}", rotation=90, va='center', fontsize=12)

plt.tight_layout()
plt.savefig("slides_per_domain_class.png")

Or, more simply, save a few validation images through the WILDS loader and inspect them:

import os
from wilds import get_dataset

def save_images():
    # Create the 'images' directory if it doesn't exist
    os.makedirs('images', exist_ok=True)

    # Load the camelyon17 dataset
    dataset = get_dataset(dataset='camelyon17', download=True)
    
    # Get the validation (OOD) subset
    val_data = dataset.get_subset('val')

    # Save the first 10 images from the validation set
    for i in range(10):
        x, y, metadata = val_data[i]
        x.save('images/val{}.png'.format(i+1))

if __name__ == '__main__':
    save_images()
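
Since "visually to a layman" is subjective, a rough numeric comparison may help: average the RGB values of a handful of patches per center and compare the results. This is only a sketch. The helper `mean_rgb` and the synthetic patches are hypothetical stand-ins, and real patches could be loaded with `np.asarray(Image.open(path).convert('RGB'))`.

```python
import numpy as np


def mean_rgb(patches):
    """Average RGB value over a list of HxWx3 image arrays."""
    stacked = np.stack([np.asarray(p, dtype=np.float64) for p in patches])
    # Average over the patch, height and width axes, keeping the channel axis.
    return stacked.mean(axis=(0, 1, 2))


# Synthetic stand-ins for patches from two centers (hypothetical colors).
center_a = [np.full((4, 4, 3), (200, 120, 180), dtype=np.uint8) for _ in range(3)]
center_b = [np.full((4, 4, 3), (120, 80, 160), dtype=np.uint8) for _ in range(3)]

stats = {"center A": mean_rgb(center_a), "center B": mean_rgb(center_b)}
for name, rgb in stats.items():
    print(name, np.round(rgb, 1))
```

Centers whose mean colors cluster together would then be candidates for having been swapped or mislabeled, though stain color alone cannot prove which center a patch came from.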

Cheers, David

Originally posted by @David-Drexlin in #163
