From 15c4e16a0e7f760d485e9cd3db886c295371266d Mon Sep 17 00:00:00 2001 From: ShrihanSolo Date: Fri, 21 Jun 2024 01:51:01 +0000 Subject: [PATCH] restructuring and preparing 3band --- data | 1 + {src/sim => sim}/.DS_Store | Bin {src/sim => sim}/configs/delve_galgal.yaml | 0 {src/sim => sim}/configs/delve_shrihan.yaml | 0 .../configs/fid_source.yaml | 2 +- .../configs/fid_target.yaml | 2 +- .../sim => sim}/configs/multiband_source.yaml | 2 +- .../sim => sim}/configs/multiband_target.yaml | 2 +- {src/sim => sim}/configs/pax_source.yaml | 0 {src/sim => sim}/configs/pax_source_des.yaml | 0 {src/sim => sim}/configs/pax_target.yaml | 0 {src/sim => sim}/configs/sky1.yaml | 0 {src/sim => sim}/configs/sky100.yaml | 0 {src/sim => sim}/configs/sky10000.yaml | 0 {src/sim => sim}/configs/test.yaml | 0 {notebooks => sim/notebooks}/gen_sim.ipynb | 60 +- .../notebooks}/old/dlsim_Shrihan.ipynb | 0 .../notebooks}/old/example.ipynb | 0 .../notebooks}/old/explore.ipynb | 0 .../notebooks}/old/paxsim_Shrihan.ipynb | 0 .../old/sim_training_set_cleaned.ipynb | 0 .../old/sky_brightness_testing.ipynb | 0 .../notebooks}/old/testsim_Shrihan.ipynb | 0 {notebooks => sim/notebooks}/renorm.ipynb | 0 {src => sim}/scripts/__init__.py | 0 {src => sim}/scripts/__version__.py | 0 {src => sim}/scripts/evaluate.py | 0 {src => sim}/scripts/paths.py | 0 {src => sim}/scripts/train.py | 0 {src => sim}/static/.gitignore | 0 {src => sim}/tex/.gitignore | 0 {src => sim}/tex/bib.bib | 0 {src => sim}/tex/figures/.gitignore | 0 {src => sim}/tex/ms.tex | 0 {src => sim}/tex/output/.gitignore | 0 {src => sim}/tex/showyourwork.sty | 0 src/sim/data/.DS_Store | Bin 6148 -> 0 bytes src/sim/data/.gitignore | 2 - src/sim/data/mb_target/.DS_Store | Bin 6148 -> 0 bytes test/test_example.py | 23 - {src => training}/.DS_Store | Bin .../MMD_old}/Shrihan_MMD_practice.ipynb | 0 .../MMD_old}/Shrihan_Norm2_MMD_practice.ipynb | 0 .../MMD_old}/Shrihan_Norm_MMD_practice.ipynb | 0 .../MMD_paper/fiducial}/ShrihanPaperMMD.ipynb | 0 
.../fiducial}/ShrihanPaperMMD_fidcheck.ipynb | 0 .../fiducial/ShrihanPaperMMD_fidcheck2.ipynb | 1380 +++++++++++++++++ .../multiband/ShrihanPaperMMD_mb.ipynb | 1005 ++++++++++++ .../ShrihanPaperMMD_MinMaxNorm.ipynb | 0 .../normalization}/ShrihanPaperMMD_Norm.ipynb | 0 .../ShrihanPaperMMD_Norm2.ipynb | 0 .../original_mmdpaper_notebook.ipynb | 0 training/scripts/__init__.py | 0 training/scripts/__version__.py | 0 training/scripts/evaluate.py | 46 + training/scripts/paths.py | 29 + training/scripts/train.py | 39 + training/static/.gitignore | 2 + training/tex/.gitignore | 2 + training/tex/bib.bib | 37 + training/tex/figures/.gitignore | 5 + training/tex/ms.tex | 254 +++ training/tex/output/.gitignore | 5 + training/tex/showyourwork.sty | 13 + 64 files changed, 2869 insertions(+), 42 deletions(-) create mode 120000 data rename {src/sim => sim}/.DS_Store (100%) rename {src/sim => sim}/configs/delve_galgal.yaml (100%) rename {src/sim => sim}/configs/delve_shrihan.yaml (100%) rename src/sim/configs/pax_original_source.yaml => sim/configs/fid_source.yaml (97%) rename src/sim/configs/pax_original_target.yaml => sim/configs/fid_target.yaml (99%) rename {src/sim => sim}/configs/multiband_source.yaml (97%) rename {src/sim => sim}/configs/multiband_target.yaml (99%) rename {src/sim => sim}/configs/pax_source.yaml (100%) rename {src/sim => sim}/configs/pax_source_des.yaml (100%) rename {src/sim => sim}/configs/pax_target.yaml (100%) rename {src/sim => sim}/configs/sky1.yaml (100%) rename {src/sim => sim}/configs/sky100.yaml (100%) rename {src/sim => sim}/configs/sky10000.yaml (100%) rename {src/sim => sim}/configs/test.yaml (100%) rename {notebooks => sim/notebooks}/gen_sim.ipynb (51%) rename {notebooks => sim/notebooks}/old/dlsim_Shrihan.ipynb (100%) rename {notebooks => sim/notebooks}/old/example.ipynb (100%) rename {notebooks => sim/notebooks}/old/explore.ipynb (100%) rename {notebooks => sim/notebooks}/old/paxsim_Shrihan.ipynb (100%) rename {notebooks => 
sim/notebooks}/old/sim_training_set_cleaned.ipynb (100%) rename {notebooks => sim/notebooks}/old/sky_brightness_testing.ipynb (100%) rename {notebooks => sim/notebooks}/old/testsim_Shrihan.ipynb (100%) rename {notebooks => sim/notebooks}/renorm.ipynb (100%) rename {src => sim}/scripts/__init__.py (100%) rename {src => sim}/scripts/__version__.py (100%) rename {src => sim}/scripts/evaluate.py (100%) rename {src => sim}/scripts/paths.py (100%) rename {src => sim}/scripts/train.py (100%) rename {src => sim}/static/.gitignore (100%) rename {src => sim}/tex/.gitignore (100%) rename {src => sim}/tex/bib.bib (100%) rename {src => sim}/tex/figures/.gitignore (100%) rename {src => sim}/tex/ms.tex (100%) rename {src => sim}/tex/output/.gitignore (100%) rename {src => sim}/tex/showyourwork.sty (100%) delete mode 100644 src/sim/data/.DS_Store delete mode 100644 src/sim/data/.gitignore delete mode 100644 src/sim/data/mb_target/.DS_Store delete mode 100644 test/test_example.py rename {src => training}/.DS_Store (100%) rename {notebooks => training/notebooks/MMD_old}/Shrihan_MMD_practice.ipynb (100%) rename {notebooks => training/notebooks/MMD_old}/Shrihan_Norm2_MMD_practice.ipynb (100%) rename {notebooks => training/notebooks/MMD_old}/Shrihan_Norm_MMD_practice.ipynb (100%) rename {notebooks => training/notebooks/MMD_paper/fiducial}/ShrihanPaperMMD.ipynb (100%) rename {notebooks => training/notebooks/MMD_paper/fiducial}/ShrihanPaperMMD_fidcheck.ipynb (100%) create mode 100644 training/notebooks/MMD_paper/fiducial/ShrihanPaperMMD_fidcheck2.ipynb create mode 100644 training/notebooks/MMD_paper/multiband/ShrihanPaperMMD_mb.ipynb rename {notebooks => training/notebooks/MMD_paper/normalization}/ShrihanPaperMMD_MinMaxNorm.ipynb (100%) rename {notebooks => training/notebooks/MMD_paper/normalization}/ShrihanPaperMMD_Norm.ipynb (100%) rename {notebooks => training/notebooks/MMD_paper/normalization}/ShrihanPaperMMD_Norm2.ipynb (100%) rename notebooks/mmd_to_send.ipynb => 
training/notebooks/MMD_paper/original_mmdpaper_notebook.ipynb (100%) create mode 100644 training/scripts/__init__.py create mode 100644 training/scripts/__version__.py create mode 100644 training/scripts/evaluate.py create mode 100644 training/scripts/paths.py create mode 100644 training/scripts/train.py create mode 100644 training/static/.gitignore create mode 100644 training/tex/.gitignore create mode 100644 training/tex/bib.bib create mode 100644 training/tex/figures/.gitignore create mode 100644 training/tex/ms.tex create mode 100644 training/tex/output/.gitignore create mode 100644 training/tex/showyourwork.sty diff --git a/data b/data new file mode 120000 index 0000000..b0fe5e7 --- /dev/null +++ b/data @@ -0,0 +1 @@ +/deepskieslab/agarwal/data \ No newline at end of file diff --git a/src/sim/.DS_Store b/sim/.DS_Store similarity index 100% rename from src/sim/.DS_Store rename to sim/.DS_Store diff --git a/src/sim/configs/delve_galgal.yaml b/sim/configs/delve_galgal.yaml similarity index 100% rename from src/sim/configs/delve_galgal.yaml rename to sim/configs/delve_galgal.yaml diff --git a/src/sim/configs/delve_shrihan.yaml b/sim/configs/delve_shrihan.yaml similarity index 100% rename from src/sim/configs/delve_shrihan.yaml rename to sim/configs/delve_shrihan.yaml diff --git a/src/sim/configs/pax_original_source.yaml b/sim/configs/fid_source.yaml similarity index 97% rename from src/sim/configs/pax_original_source.yaml rename to sim/configs/fid_source.yaml index 9bdaae0..7cf0a94 100644 --- a/src/sim/configs/pax_original_source.yaml +++ b/sim/configs/fid_source.yaml @@ -2,7 +2,7 @@ DATASET: NAME: SourceData # set a name, this value is only used if you request the h5 file format PARAMETERS: SIZE: 50000 # number of images in the full datase. 
- OUTDIR: ../src/sim/data/pax_orig_source # will be created on your system if your request to save images + OUTDIR: ../../data/fid_source # will be created on your system if your request to save images SEED: 10 COSMOLOGY: diff --git a/src/sim/configs/pax_original_target.yaml b/sim/configs/fid_target.yaml similarity index 99% rename from src/sim/configs/pax_original_target.yaml rename to sim/configs/fid_target.yaml index d563752..ae7bda0 100644 --- a/src/sim/configs/pax_original_target.yaml +++ b/sim/configs/fid_target.yaml @@ -2,7 +2,7 @@ DATASET: NAME: TargetData # set a name, this value is only used if you request the h5 file format PARAMETERS: SIZE: 50000 # number of images in the full datase. - OUTDIR: pax_orig_target + OUTDIR: ../../data/fid_target SEED: 10 COSMOLOGY: diff --git a/src/sim/configs/multiband_source.yaml b/sim/configs/multiband_source.yaml similarity index 97% rename from src/sim/configs/multiband_source.yaml rename to sim/configs/multiband_source.yaml index 4a59345..9fadd7f 100644 --- a/src/sim/configs/multiband_source.yaml +++ b/sim/configs/multiband_source.yaml @@ -2,7 +2,7 @@ DATASET: NAME: SourceData # set a name, this value is only used if you request the h5 file format PARAMETERS: SIZE: 50000 # number of images in the full datase. - OUTDIR: ../src/sim/data/mb_source # will be created on your system if your request to save images + OUTDIR: ../../data/mb_source # will be created on your system if your request to save images SEED: 10 COSMOLOGY: diff --git a/src/sim/configs/multiband_target.yaml b/sim/configs/multiband_target.yaml similarity index 99% rename from src/sim/configs/multiband_target.yaml rename to sim/configs/multiband_target.yaml index 21617b7..2da4eed 100644 --- a/src/sim/configs/multiband_target.yaml +++ b/sim/configs/multiband_target.yaml @@ -2,7 +2,7 @@ DATASET: NAME: TargetData # set a name, this value is only used if you request the h5 file format PARAMETERS: SIZE: 50000 # number of images in the full datase. 
- OUTDIR: ../src/sim/data/mb_target + OUTDIR: ../../data/mb_target SEED: 10 COSMOLOGY: diff --git a/src/sim/configs/pax_source.yaml b/sim/configs/pax_source.yaml similarity index 100% rename from src/sim/configs/pax_source.yaml rename to sim/configs/pax_source.yaml diff --git a/src/sim/configs/pax_source_des.yaml b/sim/configs/pax_source_des.yaml similarity index 100% rename from src/sim/configs/pax_source_des.yaml rename to sim/configs/pax_source_des.yaml diff --git a/src/sim/configs/pax_target.yaml b/sim/configs/pax_target.yaml similarity index 100% rename from src/sim/configs/pax_target.yaml rename to sim/configs/pax_target.yaml diff --git a/src/sim/configs/sky1.yaml b/sim/configs/sky1.yaml similarity index 100% rename from src/sim/configs/sky1.yaml rename to sim/configs/sky1.yaml diff --git a/src/sim/configs/sky100.yaml b/sim/configs/sky100.yaml similarity index 100% rename from src/sim/configs/sky100.yaml rename to sim/configs/sky100.yaml diff --git a/src/sim/configs/sky10000.yaml b/sim/configs/sky10000.yaml similarity index 100% rename from src/sim/configs/sky10000.yaml rename to sim/configs/sky10000.yaml diff --git a/src/sim/configs/test.yaml b/sim/configs/test.yaml similarity index 100% rename from src/sim/configs/test.yaml rename to sim/configs/test.yaml diff --git a/notebooks/gen_sim.ipynb b/sim/notebooks/gen_sim.ipynb similarity index 51% rename from notebooks/gen_sim.ipynb rename to sim/notebooks/gen_sim.ipynb index f37fe03..5bf9e1b 100644 --- a/notebooks/gen_sim.ipynb +++ b/sim/notebooks/gen_sim.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 6, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -12,7 +12,23 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "def rename_result(name, head):\n", + " datapath = head / f'data/{name}'\n", + " imgpath = datapath / 'CONFIGURATION_1_images.npy'\n", + " mdpath = datapath / 
'CONFIGURATION_1_metadata.csv'\n", + " if imgpath.exists():\n", + " imgpath.rename(datapath / (datapath.name + '.npy'))\n", + " if mdpath.exists():\n", + " mdpath.rename(datapath / (datapath.name + '_metadata.csv'))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -22,19 +38,28 @@ "Entering main organization loop\n", "Organizing CONFIGURATION_1\n", "Generating images for CONFIGURATION_1\n", - "\tProgress: 100.0 % --- Elapsed Time: 0 H 8 M 46 S \n" + "\tProgress: 100.0 % --- Elapsed Time: 0 H 1 M 58 S \n" ] } ], "source": [ - "head = Path.cwd().parent\n", - "config_file = head / 'src/sim/configs/multiband_source.yaml'\n", + "head = Path.cwd().parent.parent\n", + "config_file = head / 'sim/configs/fid_source.yaml'\n", "dataset = dl.make_dataset(config_file, verbose=True, save_to_disk=True)" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "rename_result('fid_source', head)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -44,16 +69,25 @@ "Entering main organization loop\n", "Organizing CONFIGURATION_1\n", "Generating images for CONFIGURATION_1\n", - "\tProgress: 100.0 % --- Elapsed Time: 0 H 8 M 20 S \n" + "\tProgress: 100.0 % --- Elapsed Time: 0 H 1 M 58 S \n" ] } ], "source": [ - "head = Path.cwd().parent\n", - "config_file = head / 'src/sim/configs/multiband_target.yaml'\n", + "head = Path.cwd().parent.parent\n", + "config_file = head / 'sim/configs/fid_target.yaml'\n", "dataset = dl.make_dataset(config_file, verbose=True, save_to_disk=True)" ] }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "rename_result('fid_target', head)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -64,9 +98,9 @@ ], "metadata": { "kernelspec": { - "display_name": "deeplens", + "display_name": "Python [conda env:.conda-deeplens]", "language": "python", - 
"name": "python3" + "name": "conda-env-.conda-deeplens-py" }, "language_info": { "codemirror_mode": { @@ -78,9 +112,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.16" + "version": "3.7.12" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/notebooks/old/dlsim_Shrihan.ipynb b/sim/notebooks/old/dlsim_Shrihan.ipynb similarity index 100% rename from notebooks/old/dlsim_Shrihan.ipynb rename to sim/notebooks/old/dlsim_Shrihan.ipynb diff --git a/notebooks/old/example.ipynb b/sim/notebooks/old/example.ipynb similarity index 100% rename from notebooks/old/example.ipynb rename to sim/notebooks/old/example.ipynb diff --git a/notebooks/old/explore.ipynb b/sim/notebooks/old/explore.ipynb similarity index 100% rename from notebooks/old/explore.ipynb rename to sim/notebooks/old/explore.ipynb diff --git a/notebooks/old/paxsim_Shrihan.ipynb b/sim/notebooks/old/paxsim_Shrihan.ipynb similarity index 100% rename from notebooks/old/paxsim_Shrihan.ipynb rename to sim/notebooks/old/paxsim_Shrihan.ipynb diff --git a/notebooks/old/sim_training_set_cleaned.ipynb b/sim/notebooks/old/sim_training_set_cleaned.ipynb similarity index 100% rename from notebooks/old/sim_training_set_cleaned.ipynb rename to sim/notebooks/old/sim_training_set_cleaned.ipynb diff --git a/notebooks/old/sky_brightness_testing.ipynb b/sim/notebooks/old/sky_brightness_testing.ipynb similarity index 100% rename from notebooks/old/sky_brightness_testing.ipynb rename to sim/notebooks/old/sky_brightness_testing.ipynb diff --git a/notebooks/old/testsim_Shrihan.ipynb b/sim/notebooks/old/testsim_Shrihan.ipynb similarity index 100% rename from notebooks/old/testsim_Shrihan.ipynb rename to sim/notebooks/old/testsim_Shrihan.ipynb diff --git a/notebooks/renorm.ipynb b/sim/notebooks/renorm.ipynb similarity index 100% rename from notebooks/renorm.ipynb rename to sim/notebooks/renorm.ipynb diff --git a/src/scripts/__init__.py b/sim/scripts/__init__.py 
similarity index 100% rename from src/scripts/__init__.py rename to sim/scripts/__init__.py diff --git a/src/scripts/__version__.py b/sim/scripts/__version__.py similarity index 100% rename from src/scripts/__version__.py rename to sim/scripts/__version__.py diff --git a/src/scripts/evaluate.py b/sim/scripts/evaluate.py similarity index 100% rename from src/scripts/evaluate.py rename to sim/scripts/evaluate.py diff --git a/src/scripts/paths.py b/sim/scripts/paths.py similarity index 100% rename from src/scripts/paths.py rename to sim/scripts/paths.py diff --git a/src/scripts/train.py b/sim/scripts/train.py similarity index 100% rename from src/scripts/train.py rename to sim/scripts/train.py diff --git a/src/static/.gitignore b/sim/static/.gitignore similarity index 100% rename from src/static/.gitignore rename to sim/static/.gitignore diff --git a/src/tex/.gitignore b/sim/tex/.gitignore similarity index 100% rename from src/tex/.gitignore rename to sim/tex/.gitignore diff --git a/src/tex/bib.bib b/sim/tex/bib.bib similarity index 100% rename from src/tex/bib.bib rename to sim/tex/bib.bib diff --git a/src/tex/figures/.gitignore b/sim/tex/figures/.gitignore similarity index 100% rename from src/tex/figures/.gitignore rename to sim/tex/figures/.gitignore diff --git a/src/tex/ms.tex b/sim/tex/ms.tex similarity index 100% rename from src/tex/ms.tex rename to sim/tex/ms.tex diff --git a/src/tex/output/.gitignore b/sim/tex/output/.gitignore similarity index 100% rename from src/tex/output/.gitignore rename to sim/tex/output/.gitignore diff --git a/src/tex/showyourwork.sty b/sim/tex/showyourwork.sty similarity index 100% rename from src/tex/showyourwork.sty rename to sim/tex/showyourwork.sty diff --git a/src/sim/data/.DS_Store b/src/sim/data/.DS_Store deleted file mode 100644 index 23809ba40045923db0899fd9fc9866ac7de1d97e..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 
zcmeHK%TB{E5S)b`3P`9&j{XJyAgan2@Bu)R!lfvPv=T>uIx}7aiO3a%(5_^UAG7Ot zq}W~nwmx>(zzV>UMQCcwn1*MkPAYjpv=ZY1uXw@^hsSOv&|e(Vx1Zq|TioG*(f1GS zw(s}c7S8d)i8m4k^f+>hEq2||?nTg^F6RiOcT{2+6Tc#jmKqZjEi4oW1ww&PAQbpn z1$4R5mKTmuhXSEMDDbI(&WFSzY$|5Qa&)jNDF9K==wfUumyl1Y*i_7poS~V!65Z8m ziec`~@l<(D#q8+ru$p{WUHKzH|KISJnN9MiF3}4G zLV9C8@78bI(_Ncb?pb8wWi@DQPaXm6=p4B!lRlr+CaH1@V-^m;4Wg<&0T*E43hX&L&p$$qDprKhvt+--jT7}7np#A3 zem<@ulZcFPQ@L2!n>{z**++&mCkOWA81W14cNZlEfg7;MkzE(HCqgga^y>{tEnwC%0;vJ&^%eQ zLs35+`xjp>T0" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Visualize source data\n", + "visualize_data(source_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "6d6e4147-ce23-4fca-b1aa-42122b0e2501", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 673 + }, + "executionInfo": { + "elapsed": 665, + "status": "ok", + "timestamp": 1718868750796, + "user": { + "displayName": "Shrihan Agarwal", + "userId": "00018416289398983661" + }, + "user_tz": 300 + }, + "id": "6d6e4147-ce23-4fca-b1aa-42122b0e2501", + "outputId": "eccb0d95-4566-445f-a058-b1d5b87765b0" + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Visualize target data\n", + "visualize_data(target_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "7b706147-6d5c-4319-a7b0-87decc1e6a7f", + "metadata": { + "executionInfo": { + "elapsed": 6, + "status": "ok", + "timestamp": 1718868750796, + "user": { + "displayName": "Shrihan Agarwal", + "userId": "00018416289398983661" + }, + "user_tz": 300 + }, + "id": "7b706147-6d5c-4319-a7b0-87decc1e6a7f" + }, + "outputs": [], + "source": [ + "# Define and initialize model\n", + "class NeuralNetwork(nn.Module):\n", + " def __init__(self):\n", + " super(NeuralNetwork, self).__init__()\n", + " self.feature = nn.Sequential()\n", + " self.feature.add_module('f_conv1', nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding='same'))\n", + " self.feature.add_module('f_relu1', nn.ReLU(True))\n", + " self.feature.add_module('f_bn1', nn.BatchNorm2d(8))\n", + " self.feature.add_module('f_pool1', nn.MaxPool2d(kernel_size=2, stride=2))\n", + " self.feature.add_module('f_conv2', nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding='same'))\n", + " self.feature.add_module('f_relu2', nn.ReLU(True))\n", + " self.feature.add_module('f_bn2', nn.BatchNorm2d(16))\n", + " self.feature.add_module('f_pool2', nn.MaxPool2d(kernel_size=2, stride=2))\n", + " self.feature.add_module('f_conv3', nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding='same'))\n", + " self.feature.add_module('f_relu3', nn.ReLU(True))\n", + " self.feature.add_module('f_bn3', nn.BatchNorm2d(32))\n", + " self.feature.add_module('f_pool3', nn.MaxPool2d(kernel_size=2, stride=2))\n", + "\n", + " self.regressor = nn.Sequential()\n", + " self.regressor.add_module('r_fc1', nn.Linear(in_features=32*5*5, out_features=128))\n", + " self.regressor.add_module('r_relu1', nn.ReLU(True))\n", + " #self.regressor.add_module('r_fc2', nn.Linear(in_features=128, out_features=64))\n", + " 
#self.regressor.add_module('r_relu2', nn.ReLU(True))\n", + " self.regressor.add_module('r_fc3', nn.Linear(in_features=128, out_features=1))\n", + "\n", + " def forward(self, x):\n", + " x = x.view(-1, 1, 40, 40)\n", + "\n", + " features = self.feature(x)\n", + " features = features.view(-1, 32*5*5)\n", + " estimate = self.regressor(features)\n", + " estimate = F.relu(estimate)\n", + " estimate = estimate.view(-1)\n", + "\n", + " return estimate, features\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "cfd79aed-d467-4d59-a44d-df05177dfd58", + "metadata": { + "executionInfo": { + "elapsed": 6, + "status": "ok", + "timestamp": 1718868750796, + "user": { + "displayName": "Shrihan Agarwal", + "userId": "00018416289398983661" + }, + "user_tz": 300 + }, + "id": "cfd79aed-d467-4d59-a44d-df05177dfd58" + }, + "outputs": [], + "source": [ + "# code from https://github.com/ZongxianLee/MMD_Loss.Pytorch\n", + "\n", + "class MMD_loss(nn.Module):\n", + " def __init__(self, kernel_mul = 2.0, kernel_num = 5):\n", + " super(MMD_loss, self).__init__()\n", + " self.kernel_num = kernel_num\n", + " self.kernel_mul = kernel_mul\n", + " self.fix_sigma = None\n", + " return\n", + " def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):\n", + " n_samples = int(source.size()[0])+int(target.size()[0])\n", + " total = torch.cat([source, target], dim=0)\n", + "\n", + " total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))\n", + " total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))\n", + " L2_distance = ((total0-total1)**2).sum(2)\n", + " if fix_sigma:\n", + " bandwidth = fix_sigma\n", + " else:\n", + " bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)\n", + " bandwidth /= kernel_mul ** (kernel_num // 2)\n", + " bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]\n", + " kernel_val = [torch.exp(-L2_distance / 
bandwidth_temp) for bandwidth_temp in bandwidth_list]\n", + " return sum(kernel_val)\n", + "\n", + " def forward(self, source, target):\n", + " batch_size = int(source.size()[0])\n", + " kernels = self.guassian_kernel(source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)\n", + " XX = kernels[:batch_size, :batch_size]\n", + " YY = kernels[batch_size:, batch_size:]\n", + " XY = kernels[:batch_size, batch_size:]\n", + " YX = kernels[batch_size:, :batch_size]\n", + " loss = torch.mean(XX + YY - XY -YX)\n", + " return loss" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "ccac040a-7d18-45a4-b390-40e3dfa51756", + "metadata": { + "executionInfo": { + "elapsed": 6, + "status": "ok", + "timestamp": 1718868750797, + "user": { + "displayName": "Shrihan Agarwal", + "userId": "00018416289398983661" + }, + "user_tz": 300 + }, + "id": "ccac040a-7d18-45a4-b390-40e3dfa51756" + }, + "outputs": [], + "source": [ + "# Define training loop\n", + "def train_loop(source_dataloader, target_dataloader, model, regressor_loss_fn, da_loss, optimizer, n_epoch, epoch):\n", + "\n", + " domain_error = 0\n", + " domain_classifier_accuracy = 0\n", + " estimator_error = 0\n", + " score_list = np.array([])\n", + "\n", + " len_dataloader = min(len(source_dataloader), len(target_dataloader))\n", + " data_source_iter = iter(source_dataloader)\n", + " data_target_iter = iter(target_dataloader)\n", + "\n", + " i = 0\n", + " while i < len_dataloader:\n", + "\n", + " p = float(i + epoch * len_dataloader) / n_epoch / len_dataloader\n", + " alpha = 2. / (1. 
+ np.exp(-10 * p)) - 1\n", + "\n", + " # Source Training\n", + "\n", + " data_source = next(data_source_iter)\n", + " X, y = data_source\n", + " X = X.float()\n", + " X = X.cuda()\n", + " y = y.cuda()\n", + "\n", + " model.zero_grad()\n", + " batch_size = len(y)\n", + "\n", + " domain_label = torch.zeros(batch_size)\n", + " domain_label = domain_label.long()\n", + " domain_label = domain_label.cuda()\n", + "\n", + " estimate_output, domain_output_source = model(X)\n", + "\n", + " estimate_loss = regressor_loss_fn(estimate_output, y)\n", + "\n", + " # Target Training\n", + "\n", + " data_target = next(data_target_iter)\n", + " X_target, _ = data_target\n", + " X_target = X_target.float()\n", + " X_target = X_target.cuda()\n", + "\n", + " batch_size = len(X_target)\n", + "\n", + " _, domain_output_target = model(X_target)\n", + " domain_loss = da_loss(domain_output_source, domain_output_target)\n", + "\n", + " loss = estimate_loss + domain_loss*1.4\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # Update values\n", + "\n", + " domain_error += domain_loss.item()\n", + " #domain_classifier_accuracy +=\n", + " estimator_error += estimate_loss.item()\n", + " score = r2_score(y.cpu().detach().numpy(), estimate_output.cpu().detach().numpy())\n", + " score_list = np.append(score_list, score)\n", + "\n", + " i += 1\n", + "\n", + " score = np.mean(score_list)\n", + " domain_error = domain_error / (len_dataloader)\n", + " estimator_error /= len_dataloader\n", + "\n", + " return [domain_error, estimator_error, score]" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "98583af6-1fbb-4091-bc22-b1ce362e8f21", + "metadata": { + "executionInfo": { + "elapsed": 6, + "status": "ok", + "timestamp": 1718868750797, + "user": { + "displayName": "Shrihan Agarwal", + "userId": "00018416289398983661" + }, + "user_tz": 300 + }, + "id": "98583af6-1fbb-4091-bc22-b1ce362e8f21" + }, + "outputs": [], + "source": [ + "# Define testing loop\n", + "\n", + "def 
test_loop(source_dataloader, target_dataloader, model, regressor_loss_fn, da_loss, n_epoch, epoch):\n", + "\n", + " with torch.no_grad():\n", + "\n", + " len_dataloader = min(len(source_dataloader), len(target_dataloader))\n", + " data_source_iter = iter(source_dataloader)\n", + " data_target_iter = iter(target_dataloader)\n", + "\n", + " domain_classifier_error = 0\n", + " domain_classifier_accuracy = 0\n", + " estimator_error = 0\n", + " estimator_error_target = 0\n", + " score_list = np.array([])\n", + " score_list_target = np.array([])\n", + "\n", + " i = 0\n", + " while i < len_dataloader:\n", + "\n", + " p = float(i + epoch * len_dataloader) / n_epoch / len_dataloader\n", + " alpha = 2. / (1. + np.exp(-10 * p)) - 1\n", + "\n", + " # Source Testing\n", + "\n", + " data_source = next(data_source_iter)\n", + " X, y = data_source\n", + " X = X.float()\n", + " X = X.cuda()\n", + " y = y.cuda()\n", + "\n", + " batch_size = len(y)\n", + "\n", + " #domain_label = torch.zeros(batch_size)\n", + " #domain_label = domain_label.long()\n", + " #domain_label = domain_label.cuda()\n", + "\n", + " estimate_output, domain_output = model(X)\n", + "\n", + " estimate_loss = regressor_loss_fn(estimate_output, y)\n", + " #domain_loss_source = classifier_loss_fn(domain_output, domain_label)\n", + "\n", + " # Target Testing\n", + "\n", + " data_target = next(data_target_iter)\n", + " X_target, y_target = data_target\n", + " X_target = X_target.float()\n", + " X_target = X_target.cuda()\n", + " y_target = y_target.cuda()\n", + "\n", + " batch_size = len(X_target)\n", + "\n", + " #domain_label = torch.ones(batch_size)\n", + " #domain_label = domain_label.long()\n", + " #domain_label = domain_label.cuda()\n", + "\n", + " estimate_output_target, domain_output = model(X_target)\n", + "\n", + " estimate_loss_target = regressor_loss_fn(estimate_output_target, y_target)\n", + " #domain_loss_target = classifier_loss_fn(domain_output, domain_label)\n", + "\n", + " # Update values\n", + "\n", + 
" # domain_classifier_error += domain_loss_source.item()\n", + " #domain_classifier_error += domain_loss_target.item()\n", + " #domain_classifier_accuracy +=\n", + " estimator_error += estimate_loss.item()\n", + " estimator_error_target += estimate_loss_target.item()\n", + " score = r2_score(y.cpu(), estimate_output.cpu())\n", + " score_list = np.append(score_list, score)\n", + " score_target = r2_score(y_target.cpu(), estimate_output_target.cpu())\n", + " score_list_target = np.append(score_list_target, score_target)\n", + "\n", + " i += 1\n", + "\n", + " score = np.mean(score_list)\n", + " score_target = np.mean(score_list_target)\n", + " #classifier_error = domain_classifier_error / (len_dataloader * 2)\n", + " estimator_error /= len_dataloader\n", + " estimator_error_target /= len_dataloader\n", + " classifier_error = 1\n", + " return [classifier_error, estimator_error, estimator_error_target, score, score_target]" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "1dfe3810-672c-4a28-b606-b3079a40fca4", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "executionInfo": { + "elapsed": 293833, + "status": "ok", + "timestamp": 1718869045423, + "user": { + "displayName": "Shrihan Agarwal", + "userId": "00018416289398983661" + }, + "user_tz": 300 + }, + "id": "1dfe3810-672c-4a28-b606-b3079a40fca4", + "outputId": "45493f2a-ea42-401e-f88b-b0ad39b969ed" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1\n", + "-------------------------------\n", + "12.33421277999878\n", + "Train Estimator Error = 0.16444933820188973\n", + "Train Estimator R2 Score = 0.6710\n", + "Train Domain Classifier Error = 0.197300594592879\n", + "Validation Source Estimator Error = 0.03957607594739859\n", + "Validation Source R2 Score = 0.9181\n", + "Validation Target Estimator Error = 0.17865040874595095\n", + "Validation Target R2 Score = 0.6406\n", + "Validation Domain Classifier Error = 1\n", + "\n", + 
"Epoch 2\n", + "-------------------------------\n", + "10.286649942398071\n", + "Train Estimator Error = 0.033987110668803534\n", + "Train Estimator R2 Score = 0.9313\n", + "Train Domain Classifier Error = 0.10603604664246277\n", + "Validation Source Estimator Error = 0.026627989835847334\n", + "Validation Source R2 Score = 0.9447\n", + "Validation Target Estimator Error = 0.12391905738100124\n", + "Validation Target R2 Score = 0.7497\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 3\n", + "-------------------------------\n", + "10.679370164871216\n", + "Train Estimator Error = 0.025708429421718748\n", + "Train Estimator R2 Score = 0.9480\n", + "Train Domain Classifier Error = 0.09875815365143406\n", + "Validation Source Estimator Error = 0.025580009335806224\n", + "Validation Source R2 Score = 0.9470\n", + "Validation Target Estimator Error = 0.11177382997836277\n", + "Validation Target R2 Score = 0.7764\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 4\n", + "-------------------------------\n", + "9.528148651123047\n", + "Train Estimator Error = 0.021674147663191916\n", + "Train Estimator R2 Score = 0.9560\n", + "Train Domain Classifier Error = 0.09356177005732953\n", + "Validation Source Estimator Error = 0.023202258696079635\n", + "Validation Source R2 Score = 0.9526\n", + "Validation Target Estimator Error = 0.09558532137874585\n", + "Validation Target R2 Score = 0.8068\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 5\n", + "-------------------------------\n", + "9.20451831817627\n", + "Train Estimator Error = 0.018606798048258863\n", + "Train Estimator R2 Score = 0.9622\n", + "Train Domain Classifier Error = 0.09366841838989659\n", + "Validation Source Estimator Error = 0.016288266745603578\n", + "Validation Source R2 Score = 0.9664\n", + "Validation Target Estimator Error = 0.06763043769510688\n", + "Validation Target R2 Score = 0.8619\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 
6\n", + "-------------------------------\n", + "9.798243761062622\n", + "Train Estimator Error = 0.016928718104180444\n", + "Train Estimator R2 Score = 0.9657\n", + "Train Domain Classifier Error = 0.0902507189198157\n", + "Validation Source Estimator Error = 0.014676664193653188\n", + "Validation Source R2 Score = 0.9693\n", + "Validation Target Estimator Error = 0.06337754338220426\n", + "Validation Target R2 Score = 0.8730\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 7\n", + "-------------------------------\n", + "11.475250482559204\n", + "Train Estimator Error = 0.01520067899678604\n", + "Train Estimator R2 Score = 0.9690\n", + "Train Domain Classifier Error = 0.08746750692971446\n", + "Validation Source Estimator Error = 0.015763865929144392\n", + "Validation Source R2 Score = 0.9671\n", + "Validation Target Estimator Error = 0.07552005605665361\n", + "Validation Target R2 Score = 0.8486\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 8\n", + "-------------------------------\n", + "9.42522406578064\n", + "Train Estimator Error = 0.014275324373621787\n", + "Train Estimator R2 Score = 0.9710\n", + "Train Domain Classifier Error = 0.08944729766323209\n", + "Validation Source Estimator Error = 0.013076007443296301\n", + "Validation Source R2 Score = 0.9731\n", + "Validation Target Estimator Error = 0.0584320479375162\n", + "Validation Target R2 Score = 0.8811\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 9\n", + "-------------------------------\n", + "12.132616519927979\n", + "Train Estimator Error = 0.013697150923138045\n", + "Train Estimator R2 Score = 0.9721\n", + "Train Domain Classifier Error = 0.0871505693820266\n", + "Validation Source Estimator Error = 0.015199173455405387\n", + "Validation Source R2 Score = 0.9685\n", + "Validation Target Estimator Error = 0.06418811832406339\n", + "Validation Target R2 Score = 0.8695\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 10\n", + 
"-------------------------------\n", + "10.557303428649902\n", + "Train Estimator Error = 0.012717660697401796\n", + "Train Estimator R2 Score = 0.9741\n", + "Train Domain Classifier Error = 0.08522086595806551\n", + "Validation Source Estimator Error = 0.011813055145536449\n", + "Validation Source R2 Score = 0.9757\n", + "Validation Target Estimator Error = 0.04445989502914202\n", + "Validation Target R2 Score = 0.9107\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 11\n", + "-------------------------------\n", + "9.70582914352417\n", + "Train Estimator Error = 0.01214050365462082\n", + "Train Estimator R2 Score = 0.9753\n", + "Train Domain Classifier Error = 0.0820563712908156\n", + "Validation Source Estimator Error = 0.011426687608384023\n", + "Validation Source R2 Score = 0.9760\n", + "Validation Target Estimator Error = 0.04615271602798799\n", + "Validation Target R2 Score = 0.9082\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 12\n", + "-------------------------------\n", + "9.581052541732788\n", + "Train Estimator Error = 0.011919633876123692\n", + "Train Estimator R2 Score = 0.9758\n", + "Train Domain Classifier Error = 0.08348346469750188\n", + "Validation Source Estimator Error = 0.010784041379714848\n", + "Validation Source R2 Score = 0.9775\n", + "Validation Target Estimator Error = 0.04491105257195367\n", + "Validation Target R2 Score = 0.9105\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 13\n", + "-------------------------------\n", + "9.942560195922852\n", + "Train Estimator Error = 0.011645885268685967\n", + "Train Estimator R2 Score = 0.9764\n", + "Train Domain Classifier Error = 0.08307299445940002\n", + "Validation Source Estimator Error = 0.010429152624183305\n", + "Validation Source R2 Score = 0.9783\n", + "Validation Target Estimator Error = 0.04398210141451875\n", + "Validation Target R2 Score = 0.9117\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 14\n", + 
"-------------------------------\n", + "9.535521030426025\n", + "Train Estimator Error = 0.010956779571413531\n", + "Train Estimator R2 Score = 0.9777\n", + "Train Domain Classifier Error = 0.08009117036220197\n", + "Validation Source Estimator Error = 0.01252956654591735\n", + "Validation Source R2 Score = 0.9742\n", + "Validation Target Estimator Error = 0.04393934647724697\n", + "Validation Target R2 Score = 0.9115\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 15\n", + "-------------------------------\n", + "10.049909353256226\n", + "Train Estimator Error = 0.011191575388063146\n", + "Train Estimator R2 Score = 0.9773\n", + "Train Domain Classifier Error = 0.08100781960232384\n", + "Validation Source Estimator Error = 0.010393610967109633\n", + "Validation Source R2 Score = 0.9787\n", + "Validation Target Estimator Error = 0.034813025829850866\n", + "Validation Target R2 Score = 0.9314\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 16\n", + "-------------------------------\n", + "9.887317895889282\n", + "Train Estimator Error = 0.010990695381099511\n", + "Train Estimator R2 Score = 0.9777\n", + "Train Domain Classifier Error = 0.07481679410402964\n", + "Validation Source Estimator Error = 0.010688897747261698\n", + "Validation Source R2 Score = 0.9780\n", + "Validation Target Estimator Error = 0.03671162581711913\n", + "Validation Target R2 Score = 0.9267\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 17\n", + "-------------------------------\n", + "9.46590256690979\n", + "Train Estimator Error = 0.01122740320363329\n", + "Train Estimator R2 Score = 0.9771\n", + "Train Domain Classifier Error = 0.07215353730602882\n", + "Validation Source Estimator Error = 0.01086703451516427\n", + "Validation Source R2 Score = 0.9776\n", + "Validation Target Estimator Error = 0.03929886768815244\n", + "Validation Target R2 Score = 0.9220\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 18\n", + 
"-------------------------------\n", + "9.881409645080566\n", + "Train Estimator Error = 0.012160148401361578\n", + "Train Estimator R2 Score = 0.9753\n", + "Train Domain Classifier Error = 0.06331457490854006\n", + "Validation Source Estimator Error = 0.011688765506171117\n", + "Validation Source R2 Score = 0.9757\n", + "Validation Target Estimator Error = 0.04073066228799\n", + "Validation Target R2 Score = 0.9182\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 19\n", + "-------------------------------\n", + "10.647077083587646\n", + "Train Estimator Error = 0.012665477483635479\n", + "Train Estimator R2 Score = 0.9743\n", + "Train Domain Classifier Error = 0.0531838871351353\n", + "Validation Source Estimator Error = 0.012146283566928024\n", + "Validation Source R2 Score = 0.9747\n", + "Validation Target Estimator Error = 0.039233959867221536\n", + "Validation Target R2 Score = 0.9208\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 20\n", + "-------------------------------\n", + "11.093253135681152\n", + "Train Estimator Error = 0.01234748291227987\n", + "Train Estimator R2 Score = 0.9749\n", + "Train Domain Classifier Error = 0.04573969768234265\n", + "Validation Source Estimator Error = 0.011225358962680504\n", + "Validation Source R2 Score = 0.9770\n", + "Validation Target Estimator Error = 0.037646287743737746\n", + "Validation Target R2 Score = 0.9244\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 21\n", + "-------------------------------\n", + "10.098066806793213\n", + "Train Estimator Error = 0.011807732323654526\n", + "Train Estimator R2 Score = 0.9760\n", + "Train Domain Classifier Error = 0.04173214546484801\n", + "Validation Source Estimator Error = 0.011837220317713774\n", + "Validation Source R2 Score = 0.9757\n", + "Validation Target Estimator Error = 0.035724040536079436\n", + "Validation Target R2 Score = 0.9286\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 22\n", + 
"-------------------------------\n", + "10.087324380874634\n", + "Train Estimator Error = 0.01155979186509288\n", + "Train Estimator R2 Score = 0.9765\n", + "Train Domain Classifier Error = 0.04175722094548958\n", + "Validation Source Estimator Error = 0.010796774510934854\n", + "Validation Source R2 Score = 0.9776\n", + "Validation Target Estimator Error = 0.029455781208386846\n", + "Validation Target R2 Score = 0.9411\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 23\n", + "-------------------------------\n", + "9.403812408447266\n", + "Train Estimator Error = 0.01096212370018943\n", + "Train Estimator R2 Score = 0.9779\n", + "Train Domain Classifier Error = 0.03727273999200879\n", + "Validation Source Estimator Error = 0.01076946327771256\n", + "Validation Source R2 Score = 0.9777\n", + "Validation Target Estimator Error = 0.034017571562509626\n", + "Validation Target R2 Score = 0.9327\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 24\n", + "-------------------------------\n", + "10.204989194869995\n", + "Train Estimator Error = 0.010513965448218192\n", + "Train Estimator R2 Score = 0.9787\n", + "Train Domain Classifier Error = 0.03472416911281005\n", + "Validation Source Estimator Error = 0.010430672994939385\n", + "Validation Source R2 Score = 0.9785\n", + "Validation Target Estimator Error = 0.033311633096568906\n", + "Validation Target R2 Score = 0.9334\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 25\n", + "-------------------------------\n", + "10.29259705543518\n", + "Train Estimator Error = 0.010646682369252036\n", + "Train Estimator R2 Score = 0.9785\n", + "Train Domain Classifier Error = 0.035981600340523875\n", + "Validation Source Estimator Error = 0.010258104230995012\n", + "Validation Source R2 Score = 0.9788\n", + "Validation Target Estimator Error = 0.03641296210728443\n", + "Validation Target R2 Score = 0.9272\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 26\n", + 
"-------------------------------\n", + "10.207979679107666\n", + "Train Estimator Error = 0.010566631928375723\n", + "Train Estimator R2 Score = 0.9785\n", + "Train Domain Classifier Error = 0.035830488049644824\n", + "Validation Source Estimator Error = 0.010909623131274608\n", + "Validation Source R2 Score = 0.9775\n", + "Validation Target Estimator Error = 0.03616505972210579\n", + "Validation Target R2 Score = 0.9278\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 27\n", + "-------------------------------\n", + "10.255443572998047\n", + "Train Estimator Error = 0.010196887276780671\n", + "Train Estimator R2 Score = 0.9793\n", + "Train Domain Classifier Error = 0.03093918035882891\n", + "Validation Source Estimator Error = 0.011571518453965141\n", + "Validation Source R2 Score = 0.9764\n", + "Validation Target Estimator Error = 0.03323280297599401\n", + "Validation Target R2 Score = 0.9339\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 28\n", + "-------------------------------\n", + "10.197303295135498\n", + "Train Estimator Error = 0.010003823992775811\n", + "Train Estimator R2 Score = 0.9797\n", + "Train Domain Classifier Error = 0.02865755154838074\n", + "Validation Source Estimator Error = 0.00992215172814763\n", + "Validation Source R2 Score = 0.9785\n", + "Validation Target Estimator Error = 0.03619385037310184\n", + "Validation Target R2 Score = 0.9273\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 29\n", + "-------------------------------\n", + "11.239193439483643\n", + "Train Estimator Error = 0.010029396919705573\n", + "Train Estimator R2 Score = 0.9797\n", + "Train Domain Classifier Error = 0.028030373441256324\n", + "Validation Source Estimator Error = 0.011323591759487701\n", + "Validation Source R2 Score = 0.9744\n", + "Validation Target Estimator Error = 0.038434194347518644\n", + "Validation Target R2 Score = 0.9225\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 30\n", 
+ "-------------------------------\n", + "9.549391746520996\n", + "Train Estimator Error = 0.01042491172493982\n", + "Train Estimator R2 Score = 0.9789\n", + "Train Domain Classifier Error = 0.027439469181280412\n", + "Validation Source Estimator Error = 0.01315591714172891\n", + "Validation Source R2 Score = 0.9733\n", + "Validation Target Estimator Error = 0.03496130949752346\n", + "Validation Target R2 Score = 0.9303\n", + "Validation Domain Classifier Error = 1\n", + "\n" + ] + } + ], + "source": [ + "# Initialize dictionary for training stats\n", + "import time\n", + "model = NeuralNetwork().cuda()\n", + "# Hyper parameter presets\n", + "learning_rate = 6e-5\n", + "epochs = 30\n", + "# Define loss functions and optimizer\n", + "regressor_loss_fn = nn.MSELoss().cuda()\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n", + "da_loss = MMD_loss()\n", + "\n", + "stats = {'train_domain_classifier_error':[],\n", + " 'train_estimator_error':[],\n", + " 'train_score':[],\n", + " 'val_domain_classifier_error':[],\n", + " 'val_estimator_error':[],\n", + " 'val_estimator_error_target':[],\n", + " 'val_score':[],\n", + " 'val_score_target':[]}\n", + "\n", + "# Train\n", + "for i in range(epochs):\n", + " start_time = time.time()\n", + " print(f\"Epoch {i+1}\\n-------------------------------\")\n", + " vals = train_loop(source_train_dataloader, target_train_dataloader, model,\n", + " regressor_loss_fn, da_loss, optimizer, epochs, i)\n", + "\n", + " vals_validate = test_loop(source_val_dataloader, target_val_dataloader,\n", + " model, regressor_loss_fn, da_loss, epochs, i)\n", + " print(time.time() - start_time)\n", + "\n", + " stats['train_domain_classifier_error'].append(vals[0])\n", + " stats['train_estimator_error'].append(vals[1])\n", + " stats['train_score'].append(vals[2])\n", + " stats['val_domain_classifier_error'].append(vals_validate[0])\n", + " stats['val_estimator_error'].append(vals_validate[1])\n", + " 
stats['val_estimator_error_target'].append(vals_validate[2])\n", + " stats['val_score'].append(vals_validate[3])\n", + " stats['val_score_target'].append(vals_validate[4])\n", + "\n", + " to_print = (\n", + " f'Train Estimator Error = {vals[1]}\\n'\n", + " f'Train Estimator R2 Score = {vals[2]:.4f}\\n'\n", + " f'Train Domain Classifier Error = {vals[0]}\\n'\n", + " f'Validation Source Estimator Error = {vals_validate[1]}\\n'\n", + " f'Validation Source R2 Score = {vals_validate[3]:.4f}\\n'\n", + " f'Validation Target Estimator Error = {vals_validate[2]}\\n'\n", + " f'Validation Target R2 Score = {vals_validate[4]:.4f}\\n'\n", + " f'Validation Domain Classifier Error = {vals_validate[0]}\\n'\n", + " )\n", + "\n", + " print(to_print)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "YfplCDIb-UU_", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 490 + }, + "executionInfo": { + "elapsed": 649, + "status": "ok", + "timestamp": 1718869045736, + "user": { + "displayName": "Shrihan Agarwal", + "userId": "00018416289398983661" + }, + "user_tz": 300 + }, + "id": "YfplCDIb-UU_", + "outputId": "dbb362ec-4af5-4cb9-c4f9-a0a2766c26c5" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Classifier\n", + "eps = np.arange(epochs)\n", + "plt.title(\"Classifier Error\")\n", + "plt.plot(eps, stats['train_domain_classifier_error'])\n", + "plt.plot(eps, stats['val_domain_classifier_error'])" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "eYG_P_iQ_5Bv", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 490 + }, + "executionInfo": { + "elapsed": 169, + "status": "ok", + "timestamp": 1718869045739, + "user": { + "displayName": "Shrihan Agarwal", + "userId": "00018416289398983661" + }, + "user_tz": 300 + }, + "id": "eYG_P_iQ_5Bv", + "outputId": "be450f92-eda7-4e4f-81fe-008c55b2b112" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Estimator\n", + "plt.title(\"Estimator Error\")\n", + "plt.plot(eps, stats['train_estimator_error'])\n", + "plt.plot(eps, stats['val_estimator_error'])\n", + "plt.plot(eps, stats['val_estimator_error_target'])" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "xS9rtS-T_neg", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 490 + }, + "executionInfo": { + "elapsed": 237, + "status": "ok", + "timestamp": 1718869045904, + "user": { + "displayName": "Shrihan Agarwal", + "userId": "00018416289398983661" + }, + "user_tz": 300 + }, + "id": "xS9rtS-T_neg", + "outputId": "d32f40ef-6042-4154-e9ee-1f4e2f90064d" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# R2 Scores\n", + "plt.title(\"R2 Scores\")\n", + "plt.plot(eps, stats['train_score'])\n", + "plt.plot(eps, stats['val_score'])\n", + "plt.plot(eps, stats['val_score_target'])" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "ed0a8206-7520-4a60-8e17-965a91133b92", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 428 + }, + "executionInfo": { + "elapsed": 969, + "status": "ok", + "timestamp": 1718869046858, + "user": { + "displayName": "Shrihan Agarwal", + "userId": "00018416289398983661" + }, + "user_tz": 300 + }, + "id": "ed0a8206-7520-4a60-8e17-965a91133b92", + "outputId": "7df8c563-5826-4e43-d9e6-5e686463551d" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Source R2 Score is 0.9742\n" + ] + }, + { + "data": { + "text/plain": [ + "Text(0.5, 1.0, 'MMD - Source')" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Test Source\n", + "preds = np.array([])\n", + "true = np.array([])\n", + "score_list = np.array([])\n", + "\n", + "with torch.no_grad():\n", + " for X, y in source_test_dataloader:\n", + " X = X.float()\n", + " pred, _ = model(X.cuda())\n", + " preds = np.append(preds, pred.cpu())\n", + " true = np.append(true, y.cpu())\n", + " score = r2_score(y.cpu(), pred.cpu())\n", + " score_list = np.append(score_list, score)\n", + "\n", + "score = np.mean(score_list)\n", + "print(f'Source R2 Score is {score:.4f}')\n", + "\n", + "plt.figure(figsize=(8,8),dpi=50)\n", + "plt.scatter(true, preds, color='black', alpha = 0.05)\n", + "line = np.linspace(0, 4, 100)\n", + "plt.plot(line, line)\n", + "plt.rc('font', size=12)\n", + "plt.xlabel('True Theta E')\n", + "plt.ylabel('Predicted Theta E');\n", + "plt.rc('font', size=20)\n", + "plt.title('MMD - Source')" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "fc047cd7-bc92-4a30-9beb-7af607da141f", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 444 + }, + "executionInfo": { + "elapsed": 1283, + "status": "ok", + "timestamp": 1718869048133, + "user": { + "displayName": "Shrihan Agarwal", + "userId": "00018416289398983661" + }, + "user_tz": 300 + }, + "id": "fc047cd7-bc92-4a30-9beb-7af607da141f", + "outputId": "b6347093-56d9-4a8b-b515-c4c4717cdab4" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Target R2 Score is 0.9299\n" + ] + }, + { + "data": { + "text/plain": [ + "Text(0.5, 1.0, 'MMD - Target')" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Test target\n", + "preds = np.array([])\n", + "true = np.array([])\n", + "score_list = np.array([])\n", + "\n", + "with torch.no_grad():\n", + " for X, y in target_test_dataloader:\n", + " X = X.float()\n", + " pred, _ = model(X.cuda())\n", + " preds = np.append(preds, pred.cpu())\n", + " true = np.append(true, y.cpu())\n", + " score = r2_score(y.cpu(), pred.cpu())\n", + " score_list = np.append(score_list, score)\n", + "\n", + "score = np.mean(score_list)\n", + "print(f'Target R2 Score is {score:.4f}')\n", + "\n", + "plt.figure(figsize=(8,8),dpi=50)\n", + "plt.scatter(true, preds, color='black', alpha = 0.05)\n", + "line = np.linspace(0, 4, 100)\n", + "plt.plot(line, line)\n", + "plt.rc('font', size=12)\n", + "plt.xlabel('True Theta E')\n", + "plt.ylabel('Predicted Theta E');\n", + "plt.rc('font', size=20)\n", + "plt.title('MMD - Target')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14a94f1e-758e-4a64-b0c7-0f3a5781f7c2", + "metadata": { + "id": "14a94f1e-758e-4a64-b0c7-0f3a5781f7c2" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [ + { + "file_id": "1MFScb-3Sbugn4RNiDaeocicJUIHlh_j2", + "timestamp": 1717430435817 + }, + { + "file_id": "1wlKaSdLzleueYrwljtOcqsiOfzEy1dxP", + "timestamp": 1717429638462 + } + ] + }, + "kernelspec": { + "display_name": "Python 3 (Safe Mode)", + "language": "python", + "name": "py3-safemode" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.15" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/training/notebooks/MMD_paper/multiband/ShrihanPaperMMD_mb.ipynb b/training/notebooks/MMD_paper/multiband/ShrihanPaperMMD_mb.ipynb new 
file mode 100644 index 0000000..30bbb4f --- /dev/null +++ b/training/notebooks/MMD_paper/multiband/ShrihanPaperMMD_mb.ipynb @@ -0,0 +1,1005 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "a8aa3fe5-4277-47fc-b26d-baa137256f17", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "executionInfo": { + "elapsed": 10375, + "status": "ok", + "timestamp": 1718868666013, + "user": { + "displayName": "Shrihan Agarwal", + "userId": "00018416289398983661" + }, + "user_tz": 300 + }, + "id": "a8aa3fe5-4277-47fc-b26d-baa137256f17", + "outputId": "9ad89b68-4fd0-4146-a087-24cd367fb09f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using cuda device\n" + ] + } + ], + "source": [ + "# Imports we will use\n", + "import torch\n", + "from torch import nn\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import DataLoader, TensorDataset\n", + "from torch.autograd import Function\n", + "%matplotlib inline\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", + "import random\n", + "from pathlib import Path\n", + "from sklearn.metrics import r2_score\n", + "from astropy.visualization import make_lupton_rgb\n", + "\n", + "# For matplotlib\n", + "import os\n", + "os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'\n", + "\n", + "# Set Seed\n", + "torch.manual_seed(22)\n", + "\n", + "# Find if cuda is available\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "print(f\"Using {device} device\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "7cc92062-1846-4850-8f8e-206a7c35c171", + "metadata": { + "executionInfo": { + "elapsed": 189, + "status": "ok", + "timestamp": 1718868679894, + "user": { + "displayName": "Shrihan Agarwal", + "userId": "00018416289398983661" + }, + "user_tz": 300 + }, + "id": "7cc92062-1846-4850-8f8e-206a7c35c171" + }, + "outputs": [], + "source": [ + "# Load data function\n", + 
"def create_dataloader(img_path, metadata_path, batch_size):\n", + " '''\n", + " Creates dataloader for training, reserving the last 10% images for validation/testing\n", + " '''\n", + " data = np.load(img_path).squeeze()\n", + " length = len(data)\n", + " data_train = torch.tensor(data[:int(.7*length)]) # 70% train\n", + " data_test = torch.tensor(data[int(.7*length):int(.9*length)]) # 20% test\n", + " data_val = torch.tensor(data[int(.9*length):]) # 10% validation\n", + "\n", + " metadata = pd.read_csv(metadata_path)\n", + " labels = metadata['PLANE_1-OBJECT_1-MASS_PROFILE_1-theta_E-g'].tolist()\n", + " labels_train = torch.tensor(labels[:int(.7*length)])\n", + " labels_test = torch.tensor(labels[int(.7*length):int(.9*length)])\n", + " labels_val = torch.tensor(labels[int(.9*length):])\n", + "\n", + " data_train.cuda()\n", + " data_test.cuda()\n", + " data_val.cuda()\n", + " labels_train.cuda()\n", + " labels_test.cuda()\n", + " labels_val.cuda()\n", + "\n", + " train_dataset = TensorDataset(data_train, labels_train)\n", + " test_dataset = TensorDataset(data_test, labels_test)\n", + " val_dataset = TensorDataset(data_val, labels_val)\n", + "\n", + " train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)\n", + " test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)\n", + " val_dataloader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=True)\n", + "\n", + " return train_dataloader, test_dataloader, val_dataloader, data" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "3efc6755-daeb-48ca-bbc7-c5a3b539c5b7", + "metadata": { + "executionInfo": { + "elapsed": 19938, + "status": "ok", + "timestamp": 1718868749575, + "user": { + "displayName": "Shrihan Agarwal", + "userId": "00018416289398983661" + }, + "user_tz": 300 + }, + "id": "3efc6755-daeb-48ca-bbc7-c5a3b539c5b7" + }, + "outputs": [], + "source": [ + "# Load in data\n", + "head = Path.cwd().parents[3]\n", + 
"source_img_path = head / 'data/mb_source/mb_source.npy'\n", + "target_img_path = head / 'data/mb_target/mb_target.npy'\n", + "source_meta = head / 'data/mb_source/mb_source_metadata.csv'\n", + "target_meta = head / 'data/mb_target/mb_target_metadata.csv'\n", + "batch_size = 32\n", + "source_train_dataloader, source_test_dataloader, source_val_dataloader, source_data = create_dataloader(source_img_path, source_meta, batch_size)\n", + "target_train_dataloader, target_test_dataloader, target_val_dataloader, target_data = create_dataloader(target_img_path, target_meta, batch_size)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "cc2641b2-6b2f-4cd7-9b29-a8ed7a595103", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "source_train_dataloader" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "a3045daa-2e71-4335-8259-662a5c7e41a8", + "metadata": { + "executionInfo": { + "elapsed": 3, + "status": "ok", + "timestamp": 1718868749576, + "user": { + "displayName": "Shrihan Agarwal", + "userId": "00018416289398983661" + }, + "user_tz": 300 + }, + "id": "a3045daa-2e71-4335-8259-662a5c7e41a8" + }, + "outputs": [], + "source": [ + "# Define data visualization function\n", + "def visualize_data(data):\n", + " '''\n", + " visualizes 16 random images from dataset\n", + " '''\n", + " \n", + " data_length = len(data)\n", + " num_indices = 16\n", + " \n", + " # Generate 15 unique random indices using numpy\n", + " random_indices = np.random.choice(data_length, size=num_indices, replace=False)\n", + "\n", + " #plot the examples for source\n", + " fig1=plt.figure(figsize=(8,8))\n", + "\n", + " for i in range(16):\n", + " plt.subplot(4, 4, i + 1)\n", + " plt.axis(\"off\")\n", + "\n", + " img = data[random_indices[i]]\n", + " example_image = make_lupton_rgb(img[0], img[1], img[2]) #change band by switching 0:1 to 1:2 or 
2:3\n", + "\n", + " plt.imshow(example_image, aspect='auto', cmap='viridis')" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b72c4588-acb2-478c-96e9-cb09a0380ecd", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 673 + }, + "executionInfo": { + "elapsed": 559, + "status": "ok", + "timestamp": 1718868750133, + "user": { + "displayName": "Shrihan Agarwal", + "userId": "00018416289398983661" + }, + "user_tz": 300 + }, + "id": "b72c4588-acb2-478c-96e9-cb09a0380ecd", + "outputId": "651cb9ac-efea-4f14-b3a0-f03648a4081a" + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Visualize source data\n", + "visualize_data(source_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6d6e4147-ce23-4fca-b1aa-42122b0e2501", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 673 + }, + "executionInfo": { + "elapsed": 665, + "status": "ok", + "timestamp": 1718868750796, + "user": { + "displayName": "Shrihan Agarwal", + "userId": "00018416289398983661" + }, + "user_tz": 300 + }, + "id": "6d6e4147-ce23-4fca-b1aa-42122b0e2501", + "outputId": "eccb0d95-4566-445f-a058-b1d5b87765b0" + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Visualize target data\n", + "visualize_data(target_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "7b706147-6d5c-4319-a7b0-87decc1e6a7f", + "metadata": { + "executionInfo": { + "elapsed": 6, + "status": "ok", + "timestamp": 1718868750796, + "user": { + "displayName": "Shrihan Agarwal", + "userId": "00018416289398983661" + }, + "user_tz": 300 + }, + "id": "7b706147-6d5c-4319-a7b0-87decc1e6a7f" + }, + "outputs": [], + "source": [ + "# Define and initialize model\n", + "class NeuralNetwork(nn.Module):\n", + " def __init__(self):\n", + " super(NeuralNetwork, self).__init__()\n", + " self.feature = nn.Sequential()\n", + " self.feature.add_module('f_conv1', nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding='same'))\n", + " self.feature.add_module('f_relu1', nn.ReLU(True))\n", + " self.feature.add_module('f_bn1', nn.BatchNorm2d(8))\n", + " self.feature.add_module('f_pool1', nn.MaxPool2d(kernel_size=2, stride=2))\n", + " self.feature.add_module('f_conv2', nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding='same'))\n", + " self.feature.add_module('f_relu2', nn.ReLU(True))\n", + " self.feature.add_module('f_bn2', nn.BatchNorm2d(16))\n", + " self.feature.add_module('f_pool2', nn.MaxPool2d(kernel_size=2, stride=2))\n", + " self.feature.add_module('f_conv3', nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding='same'))\n", + " self.feature.add_module('f_relu3', nn.ReLU(True))\n", + " self.feature.add_module('f_bn3', nn.BatchNorm2d(32))\n", + " self.feature.add_module('f_pool3', nn.MaxPool2d(kernel_size=2, stride=2))\n", + "\n", + " self.regressor = nn.Sequential()\n", + " self.regressor.add_module('r_fc1', nn.Linear(in_features=32*5*5, out_features=128))\n", + " self.regressor.add_module('r_relu1', nn.ReLU(True))\n", + " #self.regressor.add_module('r_fc2', nn.Linear(in_features=128, out_features=64))\n", + " 
#self.regressor.add_module('r_relu2', nn.ReLU(True))\n", + " self.regressor.add_module('r_fc3', nn.Linear(in_features=128, out_features=1))\n", + "\n", + " def forward(self, x):\n", + " x = x.view(-1, 1, 40, 40)\n", + "\n", + " features = self.feature(x)\n", + " features = features.view(-1, 32*5*5)\n", + " estimate = self.regressor(features)\n", + " estimate = F.relu(estimate)\n", + " estimate = estimate.view(-1)\n", + "\n", + " return estimate, features\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "cfd79aed-d467-4d59-a44d-df05177dfd58", + "metadata": { + "executionInfo": { + "elapsed": 6, + "status": "ok", + "timestamp": 1718868750796, + "user": { + "displayName": "Shrihan Agarwal", + "userId": "00018416289398983661" + }, + "user_tz": 300 + }, + "id": "cfd79aed-d467-4d59-a44d-df05177dfd58" + }, + "outputs": [], + "source": [ + "# code from https://github.com/ZongxianLee/MMD_Loss.Pytorch\n", + "\n", + "class MMD_loss(nn.Module):\n", + " def __init__(self, kernel_mul = 2.0, kernel_num = 5):\n", + " super(MMD_loss, self).__init__()\n", + " self.kernel_num = kernel_num\n", + " self.kernel_mul = kernel_mul\n", + " self.fix_sigma = None\n", + " return\n", + " def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):\n", + " n_samples = int(source.size()[0])+int(target.size()[0])\n", + " total = torch.cat([source, target], dim=0)\n", + "\n", + " total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))\n", + " total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))\n", + " L2_distance = ((total0-total1)**2).sum(2)\n", + " if fix_sigma:\n", + " bandwidth = fix_sigma\n", + " else:\n", + " bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)\n", + " bandwidth /= kernel_mul ** (kernel_num // 2)\n", + " bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]\n", + " kernel_val = [torch.exp(-L2_distance / 
bandwidth_temp) for bandwidth_temp in bandwidth_list]\n", + " return sum(kernel_val)\n", + "\n", + " def forward(self, source, target):\n", + " batch_size = int(source.size()[0])\n", + " kernels = self.guassian_kernel(source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)\n", + " XX = kernels[:batch_size, :batch_size]\n", + " YY = kernels[batch_size:, batch_size:]\n", + " XY = kernels[:batch_size, batch_size:]\n", + " YX = kernels[batch_size:, :batch_size]\n", + " loss = torch.mean(XX + YY - XY -YX)\n", + " return loss" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "ccac040a-7d18-45a4-b390-40e3dfa51756", + "metadata": { + "executionInfo": { + "elapsed": 6, + "status": "ok", + "timestamp": 1718868750797, + "user": { + "displayName": "Shrihan Agarwal", + "userId": "00018416289398983661" + }, + "user_tz": 300 + }, + "id": "ccac040a-7d18-45a4-b390-40e3dfa51756" + }, + "outputs": [], + "source": [ + "# Define training loop\n", + "def train_loop(source_dataloader, target_dataloader, model, regressor_loss_fn, da_loss, optimizer, n_epoch, epoch):\n", + "\n", + " domain_error = 0\n", + " domain_classifier_accuracy = 0\n", + " estimator_error = 0\n", + " score_list = np.array([])\n", + "\n", + " len_dataloader = min(len(source_dataloader), len(target_dataloader))\n", + " data_source_iter = iter(source_dataloader)\n", + " data_target_iter = iter(target_dataloader)\n", + "\n", + " i = 0\n", + " while i < len_dataloader:\n", + "\n", + " p = float(i + epoch * len_dataloader) / n_epoch / len_dataloader\n", + " alpha = 2. / (1. 
+ np.exp(-10 * p)) - 1\n", + "\n", + " # Source Training\n", + "\n", + " data_source = next(data_source_iter)\n", + " X, y = data_source\n", + " X = X.float()\n", + " X = X.cuda()\n", + " y = y.cuda()\n", + "\n", + " model.zero_grad()\n", + " batch_size = len(y)\n", + "\n", + " domain_label = torch.zeros(batch_size)\n", + " domain_label = domain_label.long()\n", + " domain_label = domain_label.cuda()\n", + "\n", + " estimate_output, domain_output_source = model(X)\n", + "\n", + " estimate_loss = regressor_loss_fn(estimate_output, y)\n", + "\n", + " # Target Training\n", + "\n", + " data_target = next(data_target_iter)\n", + " X_target, _ = data_target\n", + " X_target = X_target.float()\n", + " X_target = X_target.cuda()\n", + "\n", + " batch_size = len(X_target)\n", + "\n", + " _, domain_output_target = model(X_target)\n", + " domain_loss = da_loss(domain_output_source, domain_output_target)\n", + "\n", + " loss = estimate_loss + domain_loss*1.4\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # Update values\n", + "\n", + " domain_error += domain_loss.item()\n", + " #domain_classifier_accuracy +=\n", + " estimator_error += estimate_loss.item()\n", + " score = r2_score(y.cpu().detach().numpy(), estimate_output.cpu().detach().numpy())\n", + " score_list = np.append(score_list, score)\n", + "\n", + " i += 1\n", + "\n", + " score = np.mean(score_list)\n", + " domain_error = domain_error / (len_dataloader)\n", + " estimator_error /= len_dataloader\n", + "\n", + " return [domain_error, estimator_error, score]" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "98583af6-1fbb-4091-bc22-b1ce362e8f21", + "metadata": { + "executionInfo": { + "elapsed": 6, + "status": "ok", + "timestamp": 1718868750797, + "user": { + "displayName": "Shrihan Agarwal", + "userId": "00018416289398983661" + }, + "user_tz": 300 + }, + "id": "98583af6-1fbb-4091-bc22-b1ce362e8f21" + }, + "outputs": [], + "source": [ + "# Define testing loop\n", + "\n", + "def 
test_loop(source_dataloader, target_dataloader, model, regressor_loss_fn, da_loss, n_epoch, epoch):\n", + "\n", + " with torch.no_grad():\n", + "\n", + " len_dataloader = min(len(source_dataloader), len(target_dataloader))\n", + " data_source_iter = iter(source_dataloader)\n", + " data_target_iter = iter(target_dataloader)\n", + "\n", + " domain_classifier_error = 0\n", + " domain_classifier_accuracy = 0\n", + " estimator_error = 0\n", + " estimator_error_target = 0\n", + " score_list = np.array([])\n", + " score_list_target = np.array([])\n", + "\n", + " i = 0\n", + " while i < len_dataloader:\n", + "\n", + " p = float(i + epoch * len_dataloader) / n_epoch / len_dataloader\n", + " alpha = 2. / (1. + np.exp(-10 * p)) - 1\n", + "\n", + " # Source Testing\n", + "\n", + " data_source = next(data_source_iter)\n", + " X, y = data_source\n", + " X = X.float()\n", + " X = X.cuda()\n", + " y = y.cuda()\n", + "\n", + " batch_size = len(y)\n", + "\n", + " #domain_label = torch.zeros(batch_size)\n", + " #domain_label = domain_label.long()\n", + " #domain_label = domain_label.cuda()\n", + "\n", + " estimate_output, domain_output = model(X)\n", + "\n", + " estimate_loss = regressor_loss_fn(estimate_output, y)\n", + " #domain_loss_source = classifier_loss_fn(domain_output, domain_label)\n", + "\n", + " # Target Testing\n", + "\n", + " data_target = next(data_target_iter)\n", + " X_target, y_target = data_target\n", + " X_target = X_target.float()\n", + " X_target = X_target.cuda()\n", + " y_target = y_target.cuda()\n", + "\n", + " batch_size = len(X_target)\n", + "\n", + " #domain_label = torch.ones(batch_size)\n", + " #domain_label = domain_label.long()\n", + " #domain_label = domain_label.cuda()\n", + "\n", + " estimate_output_target, domain_output = model(X_target)\n", + "\n", + " estimate_loss_target = regressor_loss_fn(estimate_output_target, y_target)\n", + " #domain_loss_target = classifier_loss_fn(domain_output, domain_label)\n", + "\n", + " # Update values\n", + "\n", + 
" # domain_classifier_error += domain_loss_source.item()\n", + " #domain_classifier_error += domain_loss_target.item()\n", + " #domain_classifier_accuracy +=\n", + " estimator_error += estimate_loss.item()\n", + " estimator_error_target += estimate_loss_target.item()\n", + " score = r2_score(y.cpu(), estimate_output.cpu())\n", + " score_list = np.append(score_list, score)\n", + " score_target = r2_score(y_target.cpu(), estimate_output_target.cpu())\n", + " score_list_target = np.append(score_list_target, score_target)\n", + "\n", + " i += 1\n", + "\n", + " score = np.mean(score_list)\n", + " score_target = np.mean(score_list_target)\n", + " #classifier_error = domain_classifier_error / (len_dataloader * 2)\n", + " estimator_error /= len_dataloader\n", + " estimator_error_target /= len_dataloader\n", + " classifier_error = 1\n", + " return [classifier_error, estimator_error, estimator_error_target, score, score_target]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1dfe3810-672c-4a28-b606-b3079a40fca4", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "executionInfo": { + "elapsed": 293833, + "status": "ok", + "timestamp": 1718869045423, + "user": { + "displayName": "Shrihan Agarwal", + "userId": "00018416289398983661" + }, + "user_tz": 300 + }, + "id": "1dfe3810-672c-4a28-b606-b3079a40fca4", + "outputId": "45493f2a-ea42-401e-f88b-b0ad39b969ed" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1\n", + "-------------------------------\n", + "12.33421277999878\n", + "Train Estimator Error = 0.16444933820188973\n", + "Train Estimator R2 Score = 0.6710\n", + "Train Domain Classifier Error = 0.197300594592879\n", + "Validation Source Estimator Error = 0.03957607594739859\n", + "Validation Source R2 Score = 0.9181\n", + "Validation Target Estimator Error = 0.17865040874595095\n", + "Validation Target R2 Score = 0.6406\n", + "Validation Domain Classifier Error = 1\n", + "\n", + 
"Epoch 2\n", + "-------------------------------\n", + "10.286649942398071\n", + "Train Estimator Error = 0.033987110668803534\n", + "Train Estimator R2 Score = 0.9313\n", + "Train Domain Classifier Error = 0.10603604664246277\n", + "Validation Source Estimator Error = 0.026627989835847334\n", + "Validation Source R2 Score = 0.9447\n", + "Validation Target Estimator Error = 0.12391905738100124\n", + "Validation Target R2 Score = 0.7497\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 3\n", + "-------------------------------\n", + "10.679370164871216\n", + "Train Estimator Error = 0.025708429421718748\n", + "Train Estimator R2 Score = 0.9480\n", + "Train Domain Classifier Error = 0.09875815365143406\n", + "Validation Source Estimator Error = 0.025580009335806224\n", + "Validation Source R2 Score = 0.9470\n", + "Validation Target Estimator Error = 0.11177382997836277\n", + "Validation Target R2 Score = 0.7764\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 4\n", + "-------------------------------\n", + "9.528148651123047\n", + "Train Estimator Error = 0.021674147663191916\n", + "Train Estimator R2 Score = 0.9560\n", + "Train Domain Classifier Error = 0.09356177005732953\n", + "Validation Source Estimator Error = 0.023202258696079635\n", + "Validation Source R2 Score = 0.9526\n", + "Validation Target Estimator Error = 0.09558532137874585\n", + "Validation Target R2 Score = 0.8068\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 5\n", + "-------------------------------\n", + "9.20451831817627\n", + "Train Estimator Error = 0.018606798048258863\n", + "Train Estimator R2 Score = 0.9622\n", + "Train Domain Classifier Error = 0.09366841838989659\n", + "Validation Source Estimator Error = 0.016288266745603578\n", + "Validation Source R2 Score = 0.9664\n", + "Validation Target Estimator Error = 0.06763043769510688\n", + "Validation Target R2 Score = 0.8619\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 
6\n", + "-------------------------------\n", + "9.798243761062622\n", + "Train Estimator Error = 0.016928718104180444\n", + "Train Estimator R2 Score = 0.9657\n", + "Train Domain Classifier Error = 0.0902507189198157\n", + "Validation Source Estimator Error = 0.014676664193653188\n", + "Validation Source R2 Score = 0.9693\n", + "Validation Target Estimator Error = 0.06337754338220426\n", + "Validation Target R2 Score = 0.8730\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 7\n", + "-------------------------------\n", + "11.475250482559204\n", + "Train Estimator Error = 0.01520067899678604\n", + "Train Estimator R2 Score = 0.9690\n", + "Train Domain Classifier Error = 0.08746750692971446\n", + "Validation Source Estimator Error = 0.015763865929144392\n", + "Validation Source R2 Score = 0.9671\n", + "Validation Target Estimator Error = 0.07552005605665361\n", + "Validation Target R2 Score = 0.8486\n", + "Validation Domain Classifier Error = 1\n", + "\n", + "Epoch 8\n", + "-------------------------------\n" + ] + } + ], + "source": [ + "# Initialize dictionary for training stats\n", + "import time\n", + "model = NeuralNetwork().cuda()\n", + "# Hyper parameter presets\n", + "learning_rate = 6e-5\n", + "epochs = 30\n", + "# Define loss functions and optimizer\n", + "regressor_loss_fn = nn.MSELoss().cuda()\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n", + "da_loss = MMD_loss()\n", + "\n", + "stats = {'train_domain_classifier_error':[],\n", + " 'train_estimator_error':[],\n", + " 'train_score':[],\n", + " 'val_domain_classifier_error':[],\n", + " 'val_estimator_error':[],\n", + " 'val_estimator_error_target':[],\n", + " 'val_score':[],\n", + " 'val_score_target':[]}\n", + "\n", + "# Train\n", + "for i in range(epochs):\n", + " start_time = time.time()\n", + " print(f\"Epoch {i+1}\\n-------------------------------\")\n", + " vals = train_loop(source_train_dataloader, target_train_dataloader, model,\n", + " regressor_loss_fn, 
da_loss, optimizer, epochs, i)\n", + "\n", + " vals_validate = test_loop(source_val_dataloader, target_val_dataloader,\n", + " model, regressor_loss_fn, da_loss, epochs, i)\n", + " print(time.time() - start_time)\n", + "\n", + " stats['train_domain_classifier_error'].append(vals[0])\n", + " stats['train_estimator_error'].append(vals[1])\n", + " stats['train_score'].append(vals[2])\n", + " stats['val_domain_classifier_error'].append(vals_validate[0])\n", + " stats['val_estimator_error'].append(vals_validate[1])\n", + " stats['val_estimator_error_target'].append(vals_validate[2])\n", + " stats['val_score'].append(vals_validate[3])\n", + " stats['val_score_target'].append(vals_validate[4])\n", + "\n", + " to_print = (\n", + " f'Train Estimator Error = {vals[1]}\\n'\n", + " f'Train Estimator R2 Score = {vals[2]:.4f}\\n'\n", + " f'Train Domain Classifier Error = {vals[0]}\\n'\n", + " f'Validation Source Estimator Error = {vals_validate[1]}\\n'\n", + " f'Validation Source R2 Score = {vals_validate[3]:.4f}\\n'\n", + " f'Validation Target Estimator Error = {vals_validate[2]}\\n'\n", + " f'Validation Target R2 Score = {vals_validate[4]:.4f}\\n'\n", + " f'Validation Domain Classifier Error = {vals_validate[0]}\\n'\n", + " )\n", + "\n", + " print(to_print)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "YfplCDIb-UU_", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 490 + }, + "executionInfo": { + "elapsed": 649, + "status": "ok", + "timestamp": 1718869045736, + "user": { + "displayName": "Shrihan Agarwal", + "userId": "00018416289398983661" + }, + "user_tz": 300 + }, + "id": "YfplCDIb-UU_", + "outputId": "dbb362ec-4af5-4cb9-c4f9-a0a2766c26c5" + }, + "outputs": [], + "source": [ + "# Classifier\n", + "eps = np.arange(epochs)\n", + "plt.title(\"Classifier Error\")\n", + "plt.plot(eps, stats['train_domain_classifier_error'])\n", + "plt.plot(eps, stats['val_domain_classifier_error'])" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "eYG_P_iQ_5Bv", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 490 + }, + "executionInfo": { + "elapsed": 169, + "status": "ok", + "timestamp": 1718869045739, + "user": { + "displayName": "Shrihan Agarwal", + "userId": "00018416289398983661" + }, + "user_tz": 300 + }, + "id": "eYG_P_iQ_5Bv", + "outputId": "be450f92-eda7-4e4f-81fe-008c55b2b112" + }, + "outputs": [], + "source": [ + "# Estimator\n", + "plt.title(\"Estimator Error\")\n", + "plt.plot(eps, stats['train_estimator_error'])\n", + "plt.plot(eps, stats['val_estimator_error'])\n", + "plt.plot(eps, stats['val_estimator_error_target'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "xS9rtS-T_neg", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 490 + }, + "executionInfo": { + "elapsed": 237, + "status": "ok", + "timestamp": 1718869045904, + "user": { + "displayName": "Shrihan Agarwal", + "userId": "00018416289398983661" + }, + "user_tz": 300 + }, + "id": "xS9rtS-T_neg", + "outputId": "d32f40ef-6042-4154-e9ee-1f4e2f90064d" + }, + "outputs": [], + "source": [ + "# R2 Scores\n", + "plt.title(\"R2 Scores\")\n", + "plt.plot(eps, stats['train_score'])\n", + "plt.plot(eps, stats['val_score'])\n", + "plt.plot(eps, stats['val_score_target'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ed0a8206-7520-4a60-8e17-965a91133b92", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 428 + }, + "executionInfo": { + "elapsed": 969, + "status": "ok", + "timestamp": 1718869046858, + "user": { + "displayName": "Shrihan Agarwal", + "userId": "00018416289398983661" + }, + "user_tz": 300 + }, + "id": "ed0a8206-7520-4a60-8e17-965a91133b92", + "outputId": "7df8c563-5826-4e43-d9e6-5e686463551d" + }, + "outputs": [], + "source": [ + "# Test Source\n", + "preds = np.array([])\n", + "true = np.array([])\n", + "score_list = np.array([])\n", + "\n", + "with 
torch.no_grad():\n", + " for X, y in source_test_dataloader:\n", + " X = X.float()\n", + " pred, _ = model(X.cuda())\n", + " preds = np.append(preds, pred.cpu())\n", + " true = np.append(true, y.cpu())\n", + " score = r2_score(y.cpu(), pred.cpu())\n", + " score_list = np.append(score_list, score)\n", + "\n", + "score = np.mean(score_list)\n", + "print(f'Source R2 Score is {score:.4f}')\n", + "\n", + "plt.figure(figsize=(8,8),dpi=50)\n", + "plt.scatter(true, preds, color='black')\n", + "line = np.linspace(0, 4, 100)\n", + "plt.plot(line, line)\n", + "plt.rc('font', size=12)\n", + "plt.xlabel('True Theta E')\n", + "plt.ylabel('Predicted Theta E');\n", + "plt.rc('font', size=20)\n", + "plt.title('MMD - Source')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fc047cd7-bc92-4a30-9beb-7af607da141f", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 444 + }, + "executionInfo": { + "elapsed": 1283, + "status": "ok", + "timestamp": 1718869048133, + "user": { + "displayName": "Shrihan Agarwal", + "userId": "00018416289398983661" + }, + "user_tz": 300 + }, + "id": "fc047cd7-bc92-4a30-9beb-7af607da141f", + "outputId": "b6347093-56d9-4a8b-b515-c4c4717cdab4" + }, + "outputs": [], + "source": [ + "# Test target\n", + "preds = np.array([])\n", + "true = np.array([])\n", + "score_list = np.array([])\n", + "\n", + "with torch.no_grad():\n", + " for X, y in target_test_dataloader:\n", + " X = X.float()\n", + " pred, _ = model(X.cuda())\n", + " preds = np.append(preds, pred.cpu())\n", + " true = np.append(true, y.cpu())\n", + " score = r2_score(y.cpu(), pred.cpu())\n", + " score_list = np.append(score_list, score)\n", + "\n", + "score = np.mean(score_list)\n", + "print(f'Target R2 Score is {score:.4f}')\n", + "\n", + "plt.figure(figsize=(8,8),dpi=50)\n", + "plt.scatter(true, preds, color='black')\n", + "line = np.linspace(0, 4, 100)\n", + "plt.plot(line, line)\n", + "plt.rc('font', size=12)\n", + "plt.xlabel('True Theta E')\n", 
+ "plt.ylabel('Predicted Theta E');\n", + "plt.rc('font', size=20)\n", + "plt.title('MMD - Target')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14a94f1e-758e-4a64-b0c7-0f3a5781f7c2", + "metadata": { + "id": "14a94f1e-758e-4a64-b0c7-0f3a5781f7c2" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [ + { + "file_id": "1MFScb-3Sbugn4RNiDaeocicJUIHlh_j2", + "timestamp": 1717430435817 + }, + { + "file_id": "1wlKaSdLzleueYrwljtOcqsiOfzEy1dxP", + "timestamp": 1717429638462 + } + ] + }, + "kernelspec": { + "display_name": "Python [conda env:.conda-neural]", + "language": "python", + "name": "conda-env-.conda-neural-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.15" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/ShrihanPaperMMD_MinMaxNorm.ipynb b/training/notebooks/MMD_paper/normalization/ShrihanPaperMMD_MinMaxNorm.ipynb similarity index 100% rename from notebooks/ShrihanPaperMMD_MinMaxNorm.ipynb rename to training/notebooks/MMD_paper/normalization/ShrihanPaperMMD_MinMaxNorm.ipynb diff --git a/notebooks/ShrihanPaperMMD_Norm.ipynb b/training/notebooks/MMD_paper/normalization/ShrihanPaperMMD_Norm.ipynb similarity index 100% rename from notebooks/ShrihanPaperMMD_Norm.ipynb rename to training/notebooks/MMD_paper/normalization/ShrihanPaperMMD_Norm.ipynb diff --git a/notebooks/ShrihanPaperMMD_Norm2.ipynb b/training/notebooks/MMD_paper/normalization/ShrihanPaperMMD_Norm2.ipynb similarity index 100% rename from notebooks/ShrihanPaperMMD_Norm2.ipynb rename to training/notebooks/MMD_paper/normalization/ShrihanPaperMMD_Norm2.ipynb diff --git a/notebooks/mmd_to_send.ipynb b/training/notebooks/MMD_paper/original_mmdpaper_notebook.ipynb similarity 
index 100% rename from notebooks/mmd_to_send.ipynb rename to training/notebooks/MMD_paper/original_mmdpaper_notebook.ipynb diff --git a/training/scripts/__init__.py b/training/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/training/scripts/__version__.py b/training/scripts/__version__.py new file mode 100644 index 0000000..e69de29 diff --git a/training/scripts/evaluate.py b/training/scripts/evaluate.py new file mode 100644 index 0000000..664d20e --- /dev/null +++ b/training/scripts/evaluate.py @@ -0,0 +1,46 @@ +""" +Simple stub functions to use in inference +""" + +import argparse + + +def load_model(checkpoint_path): + """ + Load the entire model for prediction with an input + + :param checkpoint_path: location + :return: loaded model object that can be used with the predict function + """ + pass + + +def predict(input, model): + """ + + :param input: loaded object used for inference + :param model: loaded model + :return: Prediction + """ + return 0 + +def load_inference_object(input_path): + """ + + :param input_path: path to the object you want to predict + :return: loaded object + """ + return 0 + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint", type=str, help="Checkpoint to unloaded model checkpoint, either weights or the compressed model object") + parser.add_argument("--input", type=str, help="path to object to predict quality of") + args = parser.parse_args() + + model = load_model(args.checkpoint) + pred_obj = load_inference_object(args.input) + + prediction = predict(pred_obj, model) + print(prediction) diff --git a/training/scripts/paths.py b/training/scripts/paths.py new file mode 100644 index 0000000..8c9434e --- /dev/null +++ b/training/scripts/paths.py @@ -0,0 +1,29 @@ +""" +Exposes common paths useful for manipulating datasets and generating figures. 
+ +""" +from pathlib import Path + +# Absolute path to the top level of the repository +root = Path(__file__).resolve().parents[2].absolute() + +# Absolute path to the `src` folder +src = root / "src" + +# Absolute path to the `src/data` folder (contains datasets) +data = src / "data" + +# Absolute path to the `src/static` folder (contains static images) +static = src / "static" + +# Absolute path to the `src/scripts` folder (contains figure/pipeline scripts) +scripts = src / "scripts" + +# Absolute path to the `src/tex` folder (contains the manuscript) +tex = src / "tex" + +# Absolute path to the `src/tex/figures` folder (contains figure output) +figures = tex / "figures" + +# Absolute path to the `src/tex/output` folder (contains other user-defined output) +output = tex / "output" \ No newline at end of file diff --git a/training/scripts/train.py b/training/scripts/train.py new file mode 100644 index 0000000..a24465b --- /dev/null +++ b/training/scripts/train.py @@ -0,0 +1,39 @@ +""" +Simple stubs to use for re-train of the final model +Can leave a default data source, or specify that 'load data' loads the dataset used in the final version +""" +import argparse + + +def architecture(): + """ + :return: compiled architecture of the model you want to have trained + """ + return 0 + +def load_data(data_source): + """ + :return: data loader or full training data, split in val and train + """ + return 0, 0 + +def train_model(data_source, n_epochs): + """ + :param data_source: + :param n_epochs: + :return: trained model, or simply None, but saved trained model + """ + data = load_data(data_source) + model = architecture() + + return 0 + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data_source", type=str, help="Data used to train the model") + parser.add_argument("--n_epochs", type=int, help='Integer number of epochs to train the model') + + args = parser.parse_args() + + train_model(data_source=args.data_source, 
n_epochs=args.n_epochs) diff --git a/training/static/.gitignore b/training/static/.gitignore new file mode 100644 index 0000000..e167a34 --- /dev/null +++ b/training/static/.gitignore @@ -0,0 +1,2 @@ +# Anything is game in this folder +!* diff --git a/training/tex/.gitignore b/training/tex/.gitignore new file mode 100644 index 0000000..841b7c7 --- /dev/null +++ b/training/tex/.gitignore @@ -0,0 +1,2 @@ +# Don't track TeX temporaries +*latexindent* \ No newline at end of file diff --git a/training/tex/bib.bib b/training/tex/bib.bib new file mode 100644 index 0000000..5ca6881 --- /dev/null +++ b/training/tex/bib.bib @@ -0,0 +1,37 @@ +@article{Hunter:2007, + Author = {Hunter, J. D.}, + Title = {Matplotlib: A 2D graphics environment}, + Journal = {Computing in Science \& Engineering}, + Volume = {9}, + Number = {3}, + Pages = {90--95}, + abstract = {Matplotlib is a 2D graphics package used for Python for + application development, interactive scripting, and publication-quality + image generation across user interfaces and operating systems.}, + publisher = {IEEE COMPUTER SOC}, + doi = {10.1109/MCSE.2007.55}, + year = 2007 + } + +@article{ harris2020array, + title = {Array programming with {NumPy}}, + author = {Charles R. Harris and K. Jarrod Millman and St{\'{e}}fan J. + van der Walt and Ralf Gommers and Pauli Virtanen and David + Cournapeau and Eric Wieser and Julian Taylor and Sebastian + Berg and Nathaniel J. Smith and Robert Kern and Matti Picus + and Stephan Hoyer and Marten H. van Kerkwijk and Matthew + Brett and Allan Haldane and Jaime Fern{\'{a}}ndez del + R{\'{i}}o and Mark Wiebe and Pearu Peterson and Pierre + G{\'{e}}rard-Marchant and Kevin Sheppard and Tyler Reddy and + Warren Weckesser and Hameer Abbasi and Christoph Gohlke and + Travis E. 
Oliphant}, + year = {2020}, + month = sep, + journal = {Nature}, + volume = {585}, + number = {7825}, + pages = {357--362}, + doi = {10.1038/s41586-020-2649-2}, + publisher = {Springer Science and Business Media {LLC}}, + url = {https://doi.org/10.1038/s41586-020-2649-2} +} \ No newline at end of file diff --git a/training/tex/figures/.gitignore b/training/tex/figures/.gitignore new file mode 100644 index 0000000..9d0f65c --- /dev/null +++ b/training/tex/figures/.gitignore @@ -0,0 +1,5 @@ +# Nothing should be tracked in this folder... +* + +# Except the gitignore file itself! +!.gitignore \ No newline at end of file diff --git a/training/tex/ms.tex b/training/tex/ms.tex new file mode 100644 index 0000000..ad2222b --- /dev/null +++ b/training/tex/ms.tex @@ -0,0 +1,254 @@ + +\documentclass[twocolumn]{aastex631} + +% Import showyourwork magic +\usepackage{showyourwork} + +\usepackage[utf8]{inputenc} +\usepackage{amsmath} +\usepackage{unicode-math} + + +% Recommended, but optional, packages for figures and better typesetting: +\usepackage{microtype} +\usepackage{graphicx} +\usepackage{subfigure} +\usepackage{booktabs} % for professional tables +\usepackage{multirow} + +% hyperref makes hyperlinks in the resulting PDF. +% xurl can wrap the link if it spans a column (especially in citations). +\usepackage{hyperref} +\usepackage{xurl} + + +% This command creates a new command \editor{} that highlights any of the text in {} with a maroon color, so it can easily be spotted during internal review + +\usepackage[textsize=tiny]{todonotes} +\newcommand{\editor}[1]{{\color{purple} #1}} + +\begin{document} + +\title{DeepSkies - Template} % Define the title itself, so it may be used in headers + +\author{Author 1 \thanks{Corresponding Author, email@domain.com}} + + +\begin{abstract} + This document is meant to be used as a loose guide. + It includes useful and basic packages and formatting tips to keep you from hunting for formatting code while writing.
+ Please use this as a reference, and especially while writing without a specific journal already in mind. + This will not be the format all journals accept, so please use their defined style guides when working on your draft. + % Additionally, it's very nice to keep all your sentences on different lines. + % It makes editing a lot easier. +\end{abstract} + +\section{Basic Format and Style} + +\subsection{Format} + +The specific format of the paper is between you and your journal and your editors. +However, it is a good idea to include the basic sections of "Introduction, Methods, Conclusions". + +\subsection{Style} + +Names of coding packages are denoted with: \texttt{Package}. + + + +\editor{Here is a quick comment that may appear, indicating an addition by an editor.} + +\subsubsection{Tables} + +Tables should act as summaries, and include error bars when applicable. Captions should draw attention to the main takeaway and can provide analysis, but not necessarily give a full summary. +Please view sample table formats in the appendix ~\ref{tab:two_column} + + + + +\subsubsection{Plots and other graphics} + +When making graphics, please keep accessibility in mind. +All plots should be understandable in both black and white and color. +This requires things like using color blind friendly color packages (matplotlib's viridis for example), and changing line and marker styles for different elements of a graph. +Plots also must be clearly labeled and include legends where applicable. +Captions should both describe what the figure contains and its significance. + +When referencing a figure in the main text, please refer to it with \verb|~\ref{figure label}|. +Please view different figure layouts in the appendix ~\ref{fig:single_graphic_figure}.
+ + + +\subsubsection{Equations} + +Large equations should be numbered and included in an equation block such that +\begin{align} + E=mc^2 \label{eq:1} \\ + F=ma \label{eq:2} +\end{align} + + +Intermediate steps can omit equation numbers, such that +\begin{align*} + A = \pi r^2 +\end{align*} + +Or by using: + +\begin{align} + A + &=B \label{eq:3}\\ + &=B \notag\\ + A + &=BCD \label{eq:4}\\ + &=B \notag +\end{align} + + +Labels are used so that they can be referenced later on using the command \verb|~\ref{eq:equation label}|. Singular symbols can be added into the middle of sentences using \verb|$\symbol$|, such that \verb|\pi| becomes $\pi$. + +\section {Acknowledgements} + +Make sure to cite \cite{harris2020array} all of your sources \cite{Hunter:2007}. + + +You can also optionally provide contributions by person: + +\paragraph{Author 1} +Author 1 contributed X Y and Z + +\paragraph{Author 2} +Author 2 contributed A B and C + +If you work with the DeepSkies research group, please include the following text: + +\emph{We acknowledge the Deep Skies Lab as a community of multi-domain experts and collaborators who've facilitated an environment of open discussion, idea-generation, and collaboration.
This community was important for the development of this project.} + + +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +% bibliography +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + + +% Style of the bib may change based on the publication's requirements + +\bibliography{bib} + + + % Ending the multicol format before the appendix + +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +% APPENDIX +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +\newpage +\appendix +\section{Appendix} +You may include an appendix; it contains extra figures and tables that are not required to understand the main body but are helpful references. + +\subsection{Figure References} +\begin{figure}[h] + \centering + \includegraphics[scale=.1] + {figures/frog.jpg} + \caption{ + This is a figure (containing a cute, although not colorblind friendly, frog) with a single graphic. + Because the original image is very large, it is resized with a smaller scale.
+ } + \label{fig:single_graphic_figure} +\end{figure} + + +\begin{figure}[h] + \begin{center} + \begin{minipage}{.35\linewidth} + \includegraphics[width=\linewidth]{figures/frog2.jpg} + + \caption{An example of using minipage to caption each image in a combined figure separately.} + \end{minipage}\hfill + + \begin{minipage}{.35\linewidth} + \includegraphics[width=\linewidth]{figures/frog3.jpg} + + \caption{This frog has its own caption, so the two frogs can be referred to separately, if you were heartless enough to separate them.} + \end{minipage} + \label{multifigAB} + + \end{center} + +\end{figure} + +% Todo: Example of running a show-your-work function within the tex to produce a table + +\subsection{Table References} + +\begin{figure}[h] + \centering + \mbox{\subfigure{\includegraphics[width=.35\linewidth]{figures/frog2.jpg}}\quad + \subfigure{\includegraphics[width=.35\linewidth]{figures/frog3.jpg} }} + \caption{An example showing two images with a shared caption using subfigure. Now the frogs cannot be separated.} + \label{fig:multifigC} +\end{figure} + +\begin{table}[h] + \centering + \caption{Sample table with two columns and a header, with the caption placed on top.} + \label{tab:two_column} + \vspace{.2in} + \begin{tabular}{c | c} + \toprule + Header 1 & Header 2 \\ + \midrule + Entry 1 & 0 $\pm$ 0.001 \\ + Entry 2 & 1 $\pm$ 0.001 \\ + Entry 3 & 2 $\pm$ 0.001 \\ + \bottomrule + \end{tabular} +\end{table} + +\begin{table}[h] + \centering + \caption{A table displaying multi-row cells.
Horizontal lines can be removed, but removing them tends to lead to confusing tables.} + \vspace{.2in} + \label{tab:multirow} + \begin{tabular}{c|c|c} + + \toprule + Header 1 & Header 2 & Header 3 \\ + \midrule + + \multirow{2}*{Multi-Row} + & Row 1 & Row 1 \\ + \cline{2-3} % \cline{first_column-last_column}: partial rule spanning only those columns + & Row 2 & Row 2 \\ + + + \hline + Single-Row & Row 3 & Row 3\\ + \bottomrule + \end{tabular} + +\end{table} + +\begin{table}[h] + \centering + \caption{A table with multiple columns.} + \label{tab:multicol} + \vspace{.2in} + + \begin{tabular}{c|c|c} + \toprule + \multicolumn{2}{c|}{Multi-Column} & Column 3 \\ + \midrule + Column 1 & Column 2 & Column 3 \\ + Column 1 & Column 2 & Column 3 \\ + \bottomrule + \end{tabular} +\end{table} + +% Todo: Show your work table drawing results from a function + +\end{document} diff --git a/training/tex/output/.gitignore b/training/tex/output/.gitignore new file mode 100644 index 0000000..9d0f65c --- /dev/null +++ b/training/tex/output/.gitignore @@ -0,0 +1,5 @@ +# Nothing should be tracked in this folder... +* + +# Except the gitignore file itself! +!.gitignore \ No newline at end of file diff --git a/training/tex/showyourwork.sty b/training/tex/showyourwork.sty new file mode 100644 index 0000000..8432d72 --- /dev/null +++ b/training/tex/showyourwork.sty @@ -0,0 +1,13 @@ +\NeedsTeXFormat{LaTeX2e} +\ProvidesPackage{showyourwork}[2022/01/12 Open source science articles] + +\IfFileExists{./showyourwork.tex}{ + \input{showyourwork.tex} +}{ + \newcommand\GitHubURL{} + \newcommand\GitHubSHA{} + \newcommand\GitHubIcon{} + \newcommand\showyourwork{} + \newcommand\script[1]{} + \newcommand\variable[1]{} +} \ No newline at end of file