From 0241583fd116e8dc22fac3db5b43eaa02ea759ec Mon Sep 17 00:00:00 2001 From: Jiayuan Zhu Date: Mon, 22 Jan 2024 18:33:11 +0000 Subject: [PATCH] add LIDC datasetloader, update environment, modify random_click func --- dataset.py | 95 +++++++++++++++-- environment.yml | 276 +++++++++++++++++++++++++----------------------- train.py | 18 +++- utils.py | 11 +- 4 files changed, 256 insertions(+), 144 deletions(-) diff --git a/dataset.py b/dataset.py index ae794301..ddd4ceff 100644 --- a/dataset.py +++ b/dataset.py @@ -47,7 +47,6 @@ def __getitem__(self, index): # else: # inout = 1 # point_label = 1 - inout = 1 point_label = 1 """Get the images""" @@ -69,7 +68,7 @@ def __getitem__(self, index): mask = mask.resize(newsize) if self.prompt == 'click': - pt = random_click(np.array(mask) / 255, point_label, inout) + point_label, pt = random_click(np.array(mask) / 255, point_label) if self.transform: state = torch.get_rng_state() @@ -110,7 +109,6 @@ def __len__(self): return len(self.subfolders) def __getitem__(self, index): - inout = 1 point_label = 1 """Get the images""" @@ -132,10 +130,10 @@ def __getitem__(self, index): multi_rater_cup_np = [np.array(single_rater.resize(newsize)) for single_rater in multi_rater_cup] multi_rater_disc_np = [np.array(single_rater.resize(newsize)) for single_rater in multi_rater_disc] - # first click is the target agreement among all raters + # first click is the target agreement among most raters if self.prompt == 'click': - pt_cup = random_click(np.array(np.mean(np.stack(multi_rater_cup_np), axis=0)) / 255, point_label, inout) - pt_disc = random_click(np.array(np.mean(np.stack(multi_rater_disc_np), axis=0)) / 255, point_label, inout) + point_label, pt_cup = random_click(np.array(np.mean(np.stack(multi_rater_cup_np), axis=0)) / 255, point_label) + point_label, pt_disc = random_click(np.array(np.mean(np.stack(multi_rater_disc_np), axis=0)) / 255, point_label) if self.transform: state = torch.get_rng_state() @@ -153,16 +151,93 @@ def 
class LIDC(Dataset):
    """Multi-rater dataset loaded from ``*.pickle`` files in a directory.

    Every pickle file is expected to hold a dict that maps a sample name to a
    record with the keys ``'image'``, ``'masks'`` and ``'series_uid'``.  All
    files are merged into one dict (later files override earlier ones on a
    name clash; files are visited in sorted order so the result is
    deterministic).

    Args:
        data_path: Directory containing the pickle files.  A trailing path
            separator is not required.
        transform: Stored on the instance; not applied by this class.
        transform_msk: Stored on the instance; not applied by this class.
        prompt: Prompt type.  With ``'click'`` each item also carries a
            sampled click position under the ``'pt'`` key.

    Raises:
        FileNotFoundError: If ``data_path`` contains no pickle file.
        ValueError: If an image or a mask has values outside ``[0, 1]``.
    """

    def __init__(self, data_path, transform=None, transform_msk = None, prompt = 'click'):
        self.prompt = prompt
        self.transform = transform
        self.transform_msk = transform_msk

        # Per-instance storage.  These used to be class-level lists, which are
        # shared between instances: building the dataset twice appended every
        # sample a second time.
        self.names = []
        self.images = []
        self.labels = []
        self.series_uid = []

        # Sorted so that sample indices (and therefore any seeded index-based
        # train/test split) do not depend on the file system's listing order.
        pickle_files = sorted(
            name
            for name in (os.fsdecode(entry) for entry in os.listdir(data_path))
            if '.pickle' in name
        )
        if not pickle_files:
            raise FileNotFoundError(
                "no '.pickle' file found in {!r}".format(data_path)
            )

        data = {}
        for filename in pickle_files:
            data.update(self._load_pickle(os.path.join(data_path, filename)))

        for key, value in data.items():
            image = value['image'].astype(float)
            masks = value['masks']
            # Explicit checks instead of ``assert`` so they survive ``python -O``.
            if np.max(image) > 1 or np.min(image) < 0:
                raise ValueError(
                    "image {!r} has values outside [0, 1]".format(key)
                )
            if np.max(masks) > 1 or np.min(masks) < 0:
                raise ValueError(
                    "masks of {!r} have values outside [0, 1]".format(key)
                )
            self.names.append(key)
            self.images.append(image)
            self.labels.append(masks)
            self.series_uid.append(value['series_uid'])

    @staticmethod
    def _load_pickle(file_path, max_bytes=2**31 - 1):
        """Read ``file_path`` in chunks of at most ``max_bytes`` and unpickle it."""
        # Chunked reading is kept from the original implementation; presumably
        # it works around single-read size limits on some platforms.
        buffer = bytearray()
        with open(file_path, 'rb') as f_in:
            while True:
                chunk = f_in.read(max_bytes)
                if not chunk:
                    break
                buffer += chunk
        # SECURITY: unpickling executes arbitrary code embedded in the file.
        # Only point this loader at dataset files from a trusted source.
        return pickle.loads(buffer)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        """Return one sample as a dict of tensors and metadata.

        Keys: ``'image'`` (the image repeated to three channels),
        ``'multi_rater'`` (all rater masks stacked, with a singleton channel
        axis), ``'label'`` (mean over raters), ``'p_label'``,
        ``'image_meta_dict'`` and, when ``prompt == 'click'``, ``'pt'``.
        """
        point_label = 1

        # Add a leading channel axis; assumes each stored image is 2-D (H, W)
        # -- TODO confirm against the pickle contents.
        img = np.expand_dims(self.images[index], axis=0)
        name = self.names[index]
        multi_rater = self.labels[index]

        # The click lands where rater agreement is highest; if no rater marked
        # anything, random_click returns a background point with label 0.
        # Masks are already validated to lie in [0, 1], so no /255 rescaling.
        pt = None
        if self.prompt == 'click':
            agreement = np.mean(np.stack(multi_rater), axis=0)
            point_label, pt = random_click(agreement, point_label)

        # Convert image (three channels) and multi-rater labels to tensors.
        img = torch.from_numpy(img).type(torch.float32)
        img = img.repeat(3, 1, 1)
        multi_rater = torch.stack(
            [torch.from_numpy(single_rater).type(torch.float32) for single_rater in multi_rater],
            dim=0,
        )
        multi_rater = multi_rater.unsqueeze(1)
        mask = multi_rater.mean(dim=0)  # average over raters

        sample = {
            'image':img,
            'multi_rater': multi_rater,
            'label': mask,
            'p_label':point_label,
            'image_meta_dict':{'filename_or_obj':name},
        }
        # 'pt' only exists for click prompts; previously any other prompt
        # failed here with an unbound local variable.
        if pt is not None:
            sample['pt'] = pt
        return sample
dependencies: - libev=4.33=h7f8727e_1 - libevent=2.1.12=h8f2d780_0 - libffi=3.4.2=h6a678d5_6 + - libgcc=7.2.0=h69d50b8_2 - libgcc-ng=11.2.0=h1234567_1 - libgfortran-ng=11.2.0=h00389a5_1 - libgfortran5=11.2.0=h1234567_1 @@ -85,6 +93,7 @@ dependencies: - libpng=1.6.39=h5eee18b_0 - libpq=12.9=h16c4e8d_3 - libprotobuf=3.20.3=he621ea3_0 + - libsodium=1.0.18=h36c2ea0_1 - libssh2=1.10.0=h8f2d780_0 - libstdcxx-ng=11.2.0=h1234567_1 - libtasn1=4.19.0=h5eee18b_0 @@ -104,13 +113,16 @@ dependencies: - markupsafe=2.1.1=py310h7f8727e_0 - matplotlib=3.7.1=py310h06a4308_1 - matplotlib-base=3.7.1=py310h1128e8f_1 + - matplotlib-inline=0.1.6=pyhd8ed1ab_0 - mkl=2021.4.0=h06a4308_640 - mkl-service=2.4.0=py310h7f8727e_0 - mkl_fft=1.3.1=py310hd6ae3a3_0 - mkl_random=1.2.2=py310h00e6091_0 - monai=1.1.0=pyhd8ed1ab_0 - multidict=6.0.2=py310h5eee18b_0 + - munkres=1.1.4=pyh9f0ad1d_0 - ncurses=6.4=h6a678d5_0 + - nest-asyncio=1.5.8=pyhd8ed1ab_0 - nettle=3.7.3=hbbd107a_1 - networkx=2.8.4=py310h06a4308_1 - nspr=4.33=h295c915_0 @@ -121,24 +133,35 @@ dependencies: - oauthlib=3.2.2=py310h06a4308_0 - openh264=2.1.1=h4ff587b_0 - openjpeg=2.4.0=h3ad879b_0 - - openssl=1.1.1t=h7f8727e_0 + - openssl=1.1.1w=h7f8727e_0 - packaging=23.0=py310h06a4308_0 - pandas=1.5.3=py310h1128e8f_0 + - parso=0.8.3=pyhd8ed1ab_0 - pcre=8.45=h295c915_0 + - pexpect=4.8.0=pyh1a96a4e_2 + - pickleshare=0.7.5=py_1003 - pillow=9.4.0=py310h6a678d5_0 - pip=23.0.1=py310h06a4308_0 + - platformdirs=4.1.0=pyhd8ed1ab_0 - ply=3.11=py310h06a4308_0 - protobuf=3.20.3=py310h6a678d5_0 + - ptyprocess=0.7.0=pyhd3deb0d_0 + - pure_eval=0.2.2=pyhd8ed1ab_0 + - pycparser=2.21=pyhd8ed1ab_0 - pyjwt=2.4.0=py310h06a4308_0 - pyopenssl=23.0.0=py310h06a4308_0 - pyparsing=3.0.9=py310h06a4308_0 - pyqt=5.15.7=py310h6a678d5_1 - pysocks=1.7.1=py310h06a4308_0 - python=3.10.11=h7a1cb2a_2 + - python-dateutil=2.8.2=pyhd8ed1ab_0 + - python_abi=3.10=2_cp310 - pytorch-mutex=1.0=cpu - pytz=2022.7=py310h06a4308_0 + - pyu2f=0.1.5=pyhd8ed1ab_0 - 
pywavelets=1.4.1=py310h5eee18b_0 - pyyaml=6.0=py310h5eee18b_1 + - pyzmq=23.0.0=py310h330234f_0 - qt-main=5.15.2=h8373d8f_8 - qt-webengine=5.15.9=hbbf29b9_6 - qtwebkit=5.212=h3fafdc1_5 @@ -151,153 +174,146 @@ dependencies: - seaborn=0.12.2=py310h06a4308_0 - setuptools=66.0.0=py310h06a4308_0 - sip=6.6.2=py310h6a678d5_0 + - six=1.16.0=pyh6c4a22f_0 - snappy=1.1.9=h295c915_0 - sqlite=3.41.2=h5eee18b_0 + - stack_data=0.6.2=pyhd8ed1ab_0 - tensorboard=2.11.0=py310h06a4308_0 - tensorboard-data-server=0.6.1=py310h52d8a92_0 - tensorboard-plugin-wit=1.8.1=py310h06a4308_0 - tk=8.6.12=h1ccaba5_0 + - toml=0.10.2=pyhd8ed1ab_0 - toolz=0.12.0=py310h06a4308_0 - torchaudio=0.12.1=py310_cpu - tornado=6.2=py310h5eee18b_0 - tqdm=4.65.0=py310h2f386ee_0 - typing_extensions=4.5.0=py310h06a4308_0 - tzdata=2023c=h04d1e81_0 + - unicodedata2=14.0.0=py310h5764c6d_1 - urllib3=1.26.15=py310h06a4308_0 - wheel=0.38.4=py310h06a4308_0 - xz=5.4.2=h5eee18b_0 - yaml=0.2.5=h7b6447c_0 - yarl=1.8.1=py310h5eee18b_0 + - zeromq=4.3.4=h9c3ff4c_1 - zfp=0.5.5=h295c915_6 - zipp=3.11.0=py310h06a4308_0 - zlib=1.2.13=h5eee18b_0 - zstd=1.5.5=hc292b87_0 - pip: - - --extra-index-url https://download.pytorch.org/whl/cu113 - - aiosignal==1.2.0 - - alembic==1.10.4 - - appdirs==1.4.4 - - astor==0.8.1 - - asttokens==2.2.1 - - backcall==0.2.0 - - beautifulsoup4==4.12.2 - - blinker==1.6.2 - - cachetools==4.2.2 - - certifi==2022.12.7 - - charset-normalizer==2.0.4 - - click==8.1.3 - - cmaes==0.9.1 - - colorama==0.4.6 - - colorlog==6.7.0 - - contextlib2==21.6.0 - - coverage==6.5.0 - - coveralls==3.3.1 - - cucim==23.4.1 - - cycler==0.11.0 - - databricks-cli==0.17.7 - - decorator==5.1.1 - - docker==6.1.1 - - docopt==0.6.2 - - einops==0.6.1 - - entrypoints==0.4 - - exceptiongroup==1.1.1 - - executing==1.2.0 - - filelock==3.12.0 - - fire==0.5.0 - - flask==2.3.2 - - fonttools==4.25.0 - - future==0.18.3 - - gdown==4.7.1 - - gitdb==4.0.10 - - gitpython==3.1.31 - - google-auth==2.6.0 - - google-auth-oauthlib==0.4.4 - - greenlet==2.0.2 - 
- gunicorn==20.1.0 - - h5py==3.8.0 - - huggingface-hub==0.14.1 - - iniconfig==2.0.0 - - ipython==8.13.1 - - itk==5.3.0 - - itk-core==5.3.0 - - itk-filtering==5.3.0 - - itk-io==5.3.0 - - itk-numerics==5.3.0 - - itk-registration==5.3.0 - - itk-segmentation==5.3.0 - - itsdangerous==2.1.2 - - jedi==0.18.2 - - jinja2==3.1.2 - - json-tricks==3.16.1 - - jsonschema==4.17.3 - - kornia==0.4.1 - - lmdb==1.4.1 - - lucent==0.1.0 - - mako==1.2.4 - - matplotlib-inline==0.1.6 - - mlflow==2.3.1 - - munkres==1.1.4 - - nibabel==5.1.0 - - ninja==1.11.1 - - nni==2.10 - - nptyping==2.5.0 - - opencv-python==4.7.0.72 - - openslide-python==1.1.2 - - optuna==3.1.1 - - parso==0.8.3 - - partd==1.2.0 - - pexpect==4.8.0 - - pickleshare==0.7.5 - - pluggy==1.0.0 - - pooch==1.4.0 - - prettytable==3.7.0 - - prompt-toolkit==3.0.38 - - psutil==5.9.5 - - ptyprocess==0.7.0 - - pure-eval==0.2.2 - - pyarrow==11.0.0 - - pyasn1==0.4.8 - - pyasn1-modules==0.2.8 - - pycparser==2.21 - - pydicom==2.3.1 - - pygments==2.15.1 - - pynrrd==1.0.0 - - pyqt5-sip==12.11.0 - - pyrsistent==0.19.3 - - pytest==7.3.1 - - pytest-mock==3.10.0 - - python-dateutil==2.8.2 - - pythonwebhdfs==0.2.3 - - pytorch-ignite==0.4.10 - - querystring-parser==1.2.4 - - regex==2023.5.5 - - requests-oauthlib==1.3.0 - - responses==0.23.1 - - rsa==4.7.2 - - schema==0.7.5 - - simplejson==3.19.1 - - six==1.16.0 - - smmap==5.0.0 - - soupsieve==2.4.1 - - sqlalchemy==2.0.12 - - sqlparse==0.4.4 - - stack-data==0.6.2 - - tabulate==0.9.0 - - tensorboardx==2.2 - - termcolor==2.3.0 - - threadpoolctl==2.2.0 - - tifffile==2021.7.2 - - tokenizers==0.12.1 - - toml==0.10.2 - - tomli==2.0.1 - - torch==1.12.1+cu113 - - torch-lucent==0.1.8 - - torchvision==0.13.1+cu113 - - traitlets==5.9.0 - - transformers==4.21.3 - - typeguard==3.0.2 - - types-pyyaml==6.0.12.9 - - wcwidth==0.2.6 - - websocket-client==1.5.1 - - websockets==11.0.3 - - werkzeug==2.3.4 + - aiosignal==1.2.0 + - alembic==1.10.4 + - appdirs==1.4.4 + - astor==0.8.1 + - asttokens==2.2.1 + - 
backcall==0.2.0 + - beautifulsoup4==4.12.2 + - blinker==1.6.2 + - cachetools==4.2.2 + - certifi==2022.12.7 + - charset-normalizer==2.0.4 + - click==8.1.3 + - cmaes==0.9.1 + - colorama==0.4.6 + - colorlog==6.7.0 + - contextlib2==21.6.0 + - coverage==6.5.0 + - coveralls==3.3.1 + - cucim==23.4.1 + - cycler==0.11.0 + - databricks-cli==0.17.7 + - docker==6.1.1 + - docopt==0.6.2 + - einops==0.6.1 + - entrypoints==0.4 + - exceptiongroup==1.1.1 + - executing==1.2.0 + - filelock==3.12.0 + - fire==0.5.0 + - flask==2.3.2 + - fonttools==4.25.0 + - future==0.18.3 + - gdown==4.7.1 + - gitdb==4.0.10 + - gitpython==3.1.31 + - google-auth==2.6.0 + - google-auth-oauthlib==0.4.4 + - greenlet==2.0.2 + - gunicorn==20.1.0 + - h5py==3.8.0 + - huggingface-hub==0.14.1 + - iniconfig==2.0.0 + - ipython==8.13.1 + - itk==5.3.0 + - itk-core==5.3.0 + - itk-filtering==5.3.0 + - itk-io==5.3.0 + - itk-numerics==5.3.0 + - itk-registration==5.3.0 + - itk-segmentation==5.3.0 + - itsdangerous==2.1.2 + - jedi==0.18.2 + - jinja2==3.1.2 + - json-tricks==3.16.1 + - jsonschema==4.17.3 + - kornia==0.4.1 + - lmdb==1.4.1 + - lucent==0.1.0 + - mako==1.2.4 + - mlflow==2.3.1 + - nibabel==5.1.0 + - ninja==1.11.1 + - nni==2.10 + - nptyping==2.5.0 + - opencv-python==4.7.0.72 + - openslide-python==1.1.2 + - optuna==3.1.1 + - partd==1.2.0 + - pluggy==1.0.0 + - pooch==1.4.0 + - prettytable==3.7.0 + - prompt-toolkit==3.0.38 + - psutil==5.9.5 + - pyarrow==11.0.0 + - pyasn1==0.4.8 + - pyasn1-modules==0.2.8 + - pydicom==2.3.1 + - pygments==2.15.1 + - pynrrd==1.0.0 + - pyqt5-sip==12.11.0 + - pyrsistent==0.19.3 + - pytest==7.3.1 + - pytest-mock==3.10.0 + - pythonwebhdfs==0.2.3 + - pytorch-ignite==0.4.10 + - querystring-parser==1.2.4 + - regex==2023.5.5 + - requests-oauthlib==1.3.0 + - responses==0.23.1 + - rsa==4.7.2 + - safetensors==0.4.1 + - schema==0.7.5 + - simplejson==3.19.1 + - smmap==5.0.0 + - soupsieve==2.4.1 + - sqlalchemy==2.0.12 + - sqlparse==0.4.4 + - tabulate==0.9.0 + - tensorboardx==2.2 + - termcolor==2.3.0 + - 
threadpoolctl==2.2.0 + - tifffile==2021.7.2 + - timm==0.9.12 + - tokenizers==0.12.1 + - tomli==2.0.1 + - torch==1.12.1+cu113 + - torch-lucent==0.1.8 + - torchvision==0.13.1+cu113 + - traitlets==5.9.0 + - transformers==4.21.3 + - typeguard==3.0.2 + - types-pyyaml==6.0.12.9 + - wcwidth==0.2.6 + - websocket-client==1.5.1 + - websockets==11.0.3 + - werkzeug==2.3.4 diff --git a/train.py b/train.py index 7cbb591f..289fd005 100644 --- a/train.py +++ b/train.py @@ -25,6 +25,7 @@ #from dataset import * from torch.autograd import Variable from torch.utils.data import DataLoader, random_split +from torch.utils.data.sampler import SubsetRandomSampler from tqdm import tqdm import cfg @@ -108,11 +109,26 @@ '''REFUGE data''' refuge_train_dataset = REFUGE(args, args.data_path, transform = transform_train, transform_msk= transform_train_seg, mode = 'Training') refuge_test_dataset = REFUGE(args, args.data_path, transform = transform_test, transform_msk= transform_test_seg, mode = 'Test') - + nice_train_loader = DataLoader(refuge_train_dataset, batch_size=args.b, shuffle=True, num_workers=8, pin_memory=True) nice_test_loader = DataLoader(refuge_test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True) '''end''' +elif args.dataset == 'LIDC': + '''LIDC data''' + dataset = LIDC(data_path = args.data_path) + + dataset_size = len(dataset) + indices = list(range(dataset_size)) + split = int(np.floor(0.2 * dataset_size)) + np.random.shuffle(indices) + train_sampler = SubsetRandomSampler(indices[split:]) + test_sampler = SubsetRandomSampler(indices[:split]) + + nice_train_loader = DataLoader(dataset, batch_size=args.b, sampler=train_sampler, num_workers=8, pin_memory=True) + nice_test_loader = DataLoader(dataset, batch_size=args.b, sampler=test_sampler, num_workers=8, pin_memory=True) + '''end''' + '''checkpoint path and tensorboard''' # iter_per_epoch = len(Glaucoma_training_loader) diff --git a/utils.py b/utils.py index 406c9820..edb35f2e 100644 --- a/utils.py +++ 
def random_click(mask, point_labels = 1):
    """Pick a random click position on the highest-valued region of ``mask``.

    Args:
        mask: Array-like mask.  For multi-rater data this is typically the
            mean over raters, so its maximum marks the strongest agreement.
        point_labels: Label returned when the mask contains a positive
            region.  Defaults to 1.

    Returns:
        A ``(point_label, point)`` tuple.  ``point`` is one row of
        ``np.argwhere`` (an index into ``mask``), chosen uniformly among the
        positions that hold the maximum value.  If the mask is entirely
        zero, ``point_label`` is the int ``0`` and the point is a random
        background position.

    Raises:
        ValueError: If ``mask`` has no elements.
    """
    mask = np.asarray(mask)
    if mask.size == 0:
        raise ValueError("random_click received an empty mask")

    # Vectorised maximum; the earlier max(set(mask.flatten())) built a Python
    # set out of every pixel on each call.
    max_label = mask.max()

    # All-background mask: report label 0 as a plain int so the label type
    # does not change between samples (it used to become a NumPy float).
    if max_label == 0:
        point_labels = 0

    # Candidate positions are those with maximum agreement.
    indices = np.argwhere(mask == max_label)
    return point_labels, indices[np.random.randint(len(indices))]