Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 9, 2024
1 parent 59b0565 commit 567a73a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
21 changes: 11 additions & 10 deletions docs/partial-inputs-flood-tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@
"multi_model = CLAYModule.load_from_checkpoint(\n",
" CKPT_PATH,\n",
" mask_ratio=0.0,\n",
" band_groups={\"rgb\": (2, 1, 0), \"nir\": (3,), \"swir\": (4,5)},\n",
" band_groups={\"rgb\": (2, 1, 0), \"nir\": (3,), \"swir\": (4, 5)},\n",
" bands=6,\n",
" strict=False, # ignore the extra parameters in the checkpoint\n",
")\n",
Expand All @@ -436,7 +436,6 @@
" 2893.86, # nir\n",
" 2303.00, # swir16\n",
" 1807.79, # swir22\n",
" \n",
" ]\n",
" STD = [\n",
" 2026.96, # red\n",
Expand Down Expand Up @@ -481,7 +480,7 @@
" batch[\"pixels\"] = batch[\"pixels\"].to(multi_model.device)\n",
" # Pass just the specific band through the model\n",
" batch[\"timestep\"] = batch[\"timestep\"].to(multi_model.device)\n",
" batch[\"timestep_non_norm\"] = batch[\"timestep_non_norm\"] #.to(rgb_model.device)\n",
" batch[\"timestep_non_norm\"] = batch[\"timestep_non_norm\"] # .to(rgb_model.device)\n",
" batch[\"latlon\"] = batch[\"latlon\"].to(multi_model.device)\n",
"\n",
" # Pass pixels, latlon, timestep through the encoder to create encoded patches\n",
Expand Down Expand Up @@ -606,7 +605,6 @@
"outputs": [],
"source": [
"from sklearn.manifold import TSNE\n",
"from sklearn.ensemble import IsolationForest\n",
"\n",
"# List to store date tuples\n",
"all_date_tuples = []\n",
Expand All @@ -617,8 +615,11 @@
" year_tensor, month_tensor, day_tensor = tensor_list\n",
"\n",
" # Reshape into list of tuples\n",
" date_tuples = [(int(year), int(month), int(day)) for year, month, day in zip(year_tensor, month_tensor, day_tensor)]\n",
" \n",
" date_tuples = [\n",
" (int(year), int(month), int(day))\n",
" for year, month, day in zip(year_tensor, month_tensor, day_tensor)\n",
" ]\n",
"\n",
" # Append to the list of all date tuples\n",
" all_date_tuples.extend(date_tuples)\n",
"\n",
Expand Down Expand Up @@ -651,11 +652,11 @@
"\n",
"# Annotate each point with the corresponding date\n",
"for i, (x, y) in enumerate(zip(X_tsne[:, 0], X_tsne[:, 1])):\n",
" plt.annotate(f'{all_date_tuples[i]}', (x, y))\n",
" plt.annotate(f\"{all_date_tuples[i]}\", (x, y))\n",
"\n",
"plt.title('t-SNE Visualization')\n",
"plt.xlabel('t-SNE Component 1')\n",
"plt.ylabel('t-SNE Component 2')\n",
"plt.title(\"t-SNE Visualization\")\n",
"plt.xlabel(\"t-SNE Component 1\")\n",
"plt.ylabel(\"t-SNE Component 2\")\n",
"plt.show()"
]
},
Expand Down
10 changes: 9 additions & 1 deletion src/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
os.environ["GDAL_DISABLE_READDIR_ON_OPEN"] = "EMPTY_DIR"
os.environ["GDAL_HTTP_MERGE_CONSECUTIVE_RANGES"] = "YES"


# %%
# Regular torch Dataset
class ClayDataset(Dataset):
Expand Down Expand Up @@ -60,7 +61,14 @@ def read_chip(self, chip_path):

# read timestep & normalize
date = chip.tags()["date"] # YYYY-MM-DD
year, month, day, year_non_norm, month_non_norm, day_non_norm = self.normalize_timestamp(date)
(
year,
month,
day,
year_non_norm,
month_non_norm,
day_non_norm,
) = self.normalize_timestamp(date)

# read lat,lon from UTM to WGS84 & normalize
bounds = chip.bounds # xmin, ymin, xmax, ymax
Expand Down

0 comments on commit 567a73a

Please sign in to comment.