diff --git a/.gitignore b/.gitignore index 7192a4a5..77021954 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,8 @@ logs/ # Data files and folders data/** !data/**/ +**/*/.gif +**/*/.png # Distribution / packaging .Python diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c34ef772..2aaac2ef 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,6 +5,7 @@ repos: rev: v4.5.0 hooks: - id: check-added-large-files + args: [ '--maxkb=512', '--enforce-all' ] - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 00000000..f5bc0086 --- /dev/null +++ b/docs/README.md @@ -0,0 +1,20 @@ +# Clay Model Documentation + +This Documentation uses [Jupyter Book](https://jupyterbook.org/intro.html). + +Install it with: +```bash +pip install -U jupyter-book +``` + +Then build it with: +```bash +jupyter-book build docs/ +``` + +You can preview the site locally with: +```bash +python -m http.server --directory _build/html +``` + +There is a GitHub Action on `./github/workflows/deploy-docs.yml` that builds the site and pushes it to GitHub Pages. diff --git a/docs/_config.yml b/docs/_config.yml index bfcbd85c..1cb5dc24 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -4,11 +4,14 @@ title: Clay Foundation Model author: Clay Foundation logo: logo.png +only_build_toc_files: true -# Force re-execution of notebooks on each build. +# Only execution notebooks with no output cells on each build. # See https://jupyterbook.org/content/execute.html execute: - execute_notebooks: force + execute_notebooks: cache + exclude_patterns: + - clay-v0-*.ipynb # Define the name of the latex output file for PDF builds latex: diff --git a/docs/_toc.yml b/docs/_toc.yml index 83ade51e..dc20a3e5 100644 --- a/docs/_toc.yml +++ b/docs/_toc.yml @@ -2,12 +2,20 @@ # Learn more at https://jupyterbook.org/customize/toc.html format: jb-book -root: intro +root: index parts: +- caption: Release notes + chapters: + - title: Software release notes + file: changelog + - title: Model release notes + file: specification - caption: Getting Started chapters: - title: Installation file: installation + - title: Basic Use + file: basic_use - caption: Data Preparation chapters: - title: Creating datacubes @@ -22,13 +30,19 @@ parts: file: model_embeddings - title: Finetuning file: model_finetuning -- caption: Reference Documentation +- caption: Tutorials chapters: - - title: Changelog - file: changelog + - title: Generative AI for pixel reconstruction + file: clay-v0-reconstruction + - title: Create location embeddings + file: clay-v0-location-embeddings + - title: Interpolating images in embedding space + file: clay-v0-interpolation - caption: About Clay chapters: - title: GitHub url: https://github.com/Clay-foundation - title: LinkedIn url: https://www.linkedin.com/company/made-with-clay + - title: Website + url: https://madewithclay.org diff --git a/docs/basic_use.md b/docs/basic_use.md new file mode 100644 index 00000000..79c1ab66 --- /dev/null +++ b/docs/basic_use.md @@ -0,0 +1,42 @@ +# Basic Use + +### Running jupyter lab + + mamba activate claymodel + python -m ipykernel install --user --name claymodel # to install virtual env properly + jupyter kernelspec list --json # see if kernel is installed + jupyter lab & + + +### Running the model + +The neural network model can be ran via +[LightningCLI v2](https://pytorch-lightning.medium.com/introducing-lightningcli-v2supercharge-your-training-c070d43c7dd6). +To check out the different options available, and look at the hyperparameter +configurations, run: + + python trainer.py --help + python trainer.py test --print_config + +To quickly test the model on one batch in the validation set: + + python trainer.py validate --trainer.fast_dev_run=True + +To train the model for a hundred epochs: + + python trainer.py fit --trainer.max_epochs=100 + +To generate embeddings from the pretrained model's encoder on 1024 images +(stored as a GeoParquet file with spatiotemporal metadata): + + python trainer.py predict --ckpt_path=checkpoints/last.ckpt \ + --data.batch_size=1024 \ + --data.data_dir=s3://clay-tiles-02 \ + --trainer.limit_predict_batches=1 + +More options can be found using `python trainer.py fit --help`, or at the +[LightningCLI docs](https://lightning.ai/docs/pytorch/2.1.0/cli/lightning_cli.html). + +## Advanced + +See [Readme](https://github.com/Clay-foundation/model/blob/v0.0.1/README.md) on model root for more details. diff --git a/docs/changelog.md b/docs/changelog.md index 0b549605..7d69439d 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,6 +1,11 @@ -# Changelog +(software_release)= +# Code Model release v0.0.1 + +This changelog is a summary of the changes to the source code of the Clay model. +Released on 2024/01/12. + +> For release notes for the trained model, see [](model_release) -## Release v0.0.1 (2024/01/12) ### 💫 Highlights diff --git a/docs/clay-v0-interpolation.ipynb b/docs/clay-v0-interpolation.ipynb new file mode 100644 index 00000000..d65481dc --- /dev/null +++ b/docs/clay-v0-interpolation.ipynb @@ -0,0 +1,334 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "76ed0078-447f-4374-b6ba-a8b4a366188d", + "metadata": {}, + "source": [ + "# CLAY v0 - Interpolation between images" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ea0176a6-97a1-4af6-af75-b9e52e52fbaf", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "\n", + "sys.path.append(\"../\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5ea314d0-176a-4ee3-b738-6152d27275d9", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from pathlib import Path\n", + "\n", + "import imageio\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "from einops import rearrange\n", + "from PIL import Image\n", + "\n", + "from src.datamodule import ClayDataModule, ClayDataset\n", + "from src.model_clay import CLAYModule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "37f4a735-18e6-48d7-9b58-e8d188e96b54", + "metadata": {}, + "outputs": [], + "source": [ + "# data directory for all chips\n", + "DATA_DIR = \"../data/02\"\n", + "# path of best model checkpoint for Clay v0\n", + "CKPT_PATH = \"https://huggingface.co/made-with-clay/Clay/resolve/main/Clay_v0.1_epoch-24_val-loss-0.46.ckpt\"" + ] + }, + { + "cell_type": "markdown", + "id": "4c300730-b0b0-4c3d-8a0d-d5e3ac018641", + "metadata": {}, + "source": [ + "## Load Model & DataModule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4c5f2abf-5e9c-4def-88d9-38136307b420", + "metadata": {}, + "outputs": [], + "source": [ + "# Load the model & set in eval mode\n", + "model = CLAYModule.load_from_checkpoint(\n", + " CKPT_PATH, mask_ratio=0.0, shuffle=False\n", + ") # No masking or shuffling of patches\n", + "model.eval();" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "348c0573-7670-47a6-9e13-c6de36493b58", + "metadata": {}, + "outputs": [], + "source": [ + "data_dir = Path(DATA_DIR)\n", + "\n", + "# Load the Clay DataModule\n", + "ds = ClayDataset(chips_path=list(data_dir.glob(\"**/*.tif\")))\n", + "dm = ClayDataModule(data_dir=str(data_dir), batch_size=2)\n", + "dm.setup(stage=\"fit\")\n", + "\n", + "# Load the train DataLoader\n", + "trn_dl = iter(dm.train_dataloader())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "af3b38f7-f193-4b7d-aada-b4c8abcce7ed", + "metadata": {}, + "outputs": [], + "source": [ + "# Load the first batch of chips\n", + "batch = next(trn_dl)\n", + "batch.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "59f9a028-a789-40c9-8f23-eb2a0e1c66eb", + "metadata": {}, + "outputs": [], + "source": [ + "batch[\"pixels\"].shape, batch[\"latlon\"].shape, batch[\"timestep\"].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5e4c53bf-43a0-4e19-93da-bc47bb446e74", + "metadata": {}, + "outputs": [], + "source": [ + "def show(sample, idx=None, save=False):\n", + " Path(\"animate\").mkdir(exist_ok=True)\n", + " sample = rearrange(sample, \"c h w -> h w c\")\n", + " denorm_sample = sample * torch.as_tensor(dm.STD) + torch.as_tensor(dm.MEAN)\n", + " rgb = denorm_sample[..., [2, 1, 0]]\n", + " plt.imshow((rgb - rgb.min()) / (rgb.max() - rgb.min()))\n", + " plt.axis(\"off\")\n", + " if save:\n", + " plt.savefig(f\"animate/chip_{idx}.png\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12e22e6b-127e-4948-b3e1-12b81164a417", + "metadata": {}, + "outputs": [], + "source": [ + "sample1, sample2 = batch[\"pixels\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9648f698-7970-4f42-a9d5-093661139994", + "metadata": {}, + "outputs": [], + "source": [ + "show(sample1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cbd75d9a-b42a-4e92-b81c-705a98a07d41", + "metadata": {}, + "outputs": [], + "source": [ + "show(sample2)" + ] + }, + { + "cell_type": "markdown", + "id": "4d52f962-d20d-4966-af4d-fd5509520356", + "metadata": {}, + "source": [ + "Each batch has chips of shape `13 x 512 x 512`, normalized `lat` & `lon` coords & normalized timestep information as `year`, `month` & `day`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f5e6dfc4-57fe-4296-9a57-9bd384ec61af", + "metadata": {}, + "outputs": [], + "source": [ + "# Save a copy of batch to visualize later\n", + "_batch = batch[\"pixels\"].detach().clone().cpu().numpy()" + ] + }, + { + "cell_type": "markdown", + "id": "e7c46bb7-3e25-454d-b345-1ca3ec1efb69", + "metadata": {}, + "source": [ + "## Pass data through the CLAY model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "092d1ded-427f-424f-82ec-63bf0bccfdcc", + "metadata": {}, + "outputs": [], + "source": [ + "# Pass the pixels through the encoder & decoder of CLAY\n", + "with torch.no_grad():\n", + " # Move data from to the device of model\n", + " batch[\"pixels\"] = batch[\"pixels\"].to(model.device)\n", + " batch[\"timestep\"] = batch[\"timestep\"].to(model.device)\n", + " batch[\"latlon\"] = batch[\"latlon\"].to(model.device)\n", + "\n", + " # Pass pixels, latlon, timestep through the encoder to create encoded patches\n", + " (\n", + " unmasked_patches,\n", + " unmasked_indices,\n", + " masked_indices,\n", + " masked_matrix,\n", + " ) = model.model.encoder(batch)" + ] + }, + { + "cell_type": "markdown", + "id": "5c14323c-dc02-4255-aa1a-3ec3c8799e3f", + "metadata": {}, + "source": [ + "### Create an image based on interpolation of the embedding values between 2 images\n", + "*Images are saved inside `./animate`*" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6386a7a1-d225-4c1e-8228-23f5ce4b87e4", + "metadata": {}, + "outputs": [], + "source": [ + "for idx, alpha in enumerate(np.linspace(0, 1, 20)):\n", + " patch_break = 128\n", + " l1, l2 = unmasked_patches\n", + " l3 = alpha * l1 + (1 - alpha) * l2\n", + " l4 = torch.vstack((l1[:patch_break, :], l2[patch_break:, :]))\n", + "\n", + " # Pass the unmasked_patches through the decoder to reconstruct the pixel space\n", + " with torch.no_grad():\n", + " pixels = model.model.decoder(\n", + " rearrange(l3, \"gl d -> 1 gl d\"), unmasked_indices[[0]], masked_indices[[0]]\n", + " )\n", + "\n", + " image = rearrange(pixels, \"b c (h w) (p1 p2) -> b c (h p1) (w p2)\", h=16, p1=32)\n", + " _image = image[0].detach().cpu()\n", + " show(_image, idx, save=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ea7d8627-252d-42e2-af29-0fccb611121e", + "metadata": {}, + "outputs": [], + "source": [ + "fig, axs = plt.subplots(2, 10, figsize=(20, 4))\n", + "for ax, idx in zip(axs.flatten(), range(20)):\n", + " ax.imshow(Image.open(f\"./animate/chip_{idx}.png\"))\n", + " ax.set_title(f\"Seq {idx}\")\n", + " ax.set_axis_off()\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "id": "91606fa2-deec-485a-815c-d8c7ba07dec3", + "metadata": {}, + "source": [ + "#### Create a GIF of the interpolation of images" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "91d9aeb7-edc3-492d-a6f1-fd5309e2ab40", + "metadata": {}, + "outputs": [], + "source": [ + "img_paths = [f\"./animate/chip_{idx}.png\" for idx in range(20)]\n", + "\n", + "with imageio.get_writer(\"animate/sample.gif\", mode=\"I\", duration=100) as writer:\n", + " for img_path in img_paths:\n", + " img = imageio.imread(img_path)\n", + " writer.append_data(img)\n", + "\n", + "# Delete the images\n", + "for img_path in img_paths:\n", + " os.remove(img_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e458cd3-d1a8-41b2-9e85-7612c153a7ea", + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import Image, display\n", + "\n", + "display(Image(filename=\"./animate/sample.gif\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "730bbb55-344d-49e4-a81a-59a9da014770", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/clay-v0-location-embeddings.ipynb b/docs/clay-v0-location-embeddings.ipynb new file mode 100644 index 00000000..ad412cde --- /dev/null +++ b/docs/clay-v0-location-embeddings.ipynb @@ -0,0 +1,459 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e66cec10-75b7-4a14-b9b4-91f4d707bb6d", + "metadata": {}, + "source": [ + "# CLAY v0 - Location Embeddings" + ] + }, + { + "cell_type": "markdown", + "id": "65130e67-0868-4e6e-b181-4c456223f998", + "metadata": {}, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ea0176a6-97a1-4af6-af75-b9e52e52fbaf", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "\n", + "sys.path.append(\"../\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5ea314d0-176a-4ee3-b738-6152d27275d9", + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "from pathlib import Path\n", + "\n", + "import lightning as L\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import rasterio as rio\n", + "import torch\n", + "from sklearn.cluster import KMeans\n", + "from sklearn.decomposition import PCA\n", + "\n", + "from src.datamodule import ClayDataModule, ClayDataset\n", + "from src.model_clay import CLAYModule\n", + "\n", + "warnings.filterwarnings(\"ignore\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65b3babb-ab89-40f1-920c-ac9b88dc9738", + "metadata": {}, + "outputs": [], + "source": [ + "L.seed_everything(42)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b12b28e1-c8c6-470d-8b38-5600e4897074", + "metadata": {}, + "outputs": [], + "source": [ + "# data directory for all chips\n", + "DATA_DIR = \"../data/02\"\n", + "# path of best model checkpoint for Clay v0\n", + "CKPT_PATH = \"../checkpoints/v0/mae_epoch-24_val-loss-0.46.ckpt\"" + ] + }, + { + "cell_type": "markdown", + "id": "7bade738-5ed9-4530-a9a6-8ade3cb4d6d8", + "metadata": {}, + "source": [ + "## Load Model & DataModule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4c5f2abf-5e9c-4def-88d9-38136307b420", + "metadata": {}, + "outputs": [], + "source": [ + "# Load the model & set in eval mode\n", + "model = CLAYModule.load_from_checkpoint(CKPT_PATH, mask_ratio=0.7)\n", + "model.eval();" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "348c0573-7670-47a6-9e13-c6de36493b58", + "metadata": {}, + "outputs": [], + "source": [ + "data_dir = Path(DATA_DIR)\n", + "\n", + "# Load the Clay DataModule\n", + "ds = ClayDataset(chips_path=list(data_dir.glob(\"**/*.tif\")))\n", + "dm = ClayDataModule(data_dir=str(data_dir), batch_size=100)\n", + "dm.setup(stage=\"fit\")\n", + "\n", + "# Load the train DataLoader\n", + "trn_dl = iter(dm.train_dataloader())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0fd6d1ca-4cde-48c9-842a-8eb16279a534", + "metadata": {}, + "outputs": [], + "source": [ + "# Load the first batch of chips\n", + "batch = next(trn_dl)\n", + "batch.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f5e6dfc4-57fe-4296-9a57-9bd384ec61af", + "metadata": {}, + "outputs": [], + "source": [ + "# Save a copy of batch to visualize later\n", + "_batch = batch[\"pixels\"].detach().clone().cpu().numpy()" + ] + }, + { + "cell_type": "markdown", + "id": "db087380-c94c-40a2-8746-2e33b7903b3d", + "metadata": {}, + "source": [ + "## Pass model through the CLAY model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65424c5e-f68d-47b0-bf87-030d3cf9b7e8", + "metadata": {}, + "outputs": [], + "source": [ + "# Pass the pixels through the encoder & decoder of CLAY\n", + "with torch.no_grad():\n", + " # Move data from to the device of model\n", + " batch[\"pixels\"] = batch[\"pixels\"].to(model.device)\n", + " batch[\"timestep\"] = batch[\"timestep\"].to(model.device)\n", + " batch[\"latlon\"] = batch[\"latlon\"].to(model.device)\n", + "\n", + " # Pass pixels, latlon, timestep through the encoder to create encoded patches\n", + " (\n", + " unmasked_patches,\n", + " unmasked_indices,\n", + " masked_indices,\n", + " masked_matrix,\n", + " ) = model.model.encoder(batch)\n", + "\n", + " # Pass the unmasked_patches through the decoder to reconstruct the pixel space\n", + " pixels = model.model.decoder(unmasked_patches, unmasked_indices, masked_indices)" + ] + }, + { + "cell_type": "markdown", + "id": "c601956c-d18c-4cdf-a4fe-0a30e4b06786", + "metadata": {}, + "source": [ + "## Extract Location & Timestep Embeddings" + ] + }, + { + "cell_type": "markdown", + "id": "a8160012-d0d8-4eb8-b87a-f5c02cd980d6", + "metadata": {}, + "source": [ + "In CLAY, the encoder receives unmasked patches, latitude-longitude data, and timestep information. Notably, the last 2 embeddings from the encoder specifically represent the latitude-longitude and timestep embeddings." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "baa7adfb-3503-415b-a892-00edcfe1f127", + "metadata": {}, + "outputs": [], + "source": [ + "latlon_embeddings = unmasked_patches[:, -2, :].detach().cpu().numpy()\n", + "time_embeddings = unmasked_patches[:, -1, :].detach().cpu().numpy()\n", + "\n", + "# Get normalized latlon that were input to the model\n", + "latlon = batch[\"latlon\"].detach().cpu().numpy()" + ] + }, + { + "cell_type": "markdown", + "id": "da8ce686-9982-4845-a171-2729365ead8e", + "metadata": {}, + "source": [ + "We will just focus on location embeddings in this notebook" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "591d1d55-a335-4925-8f45-48c9584106d7", + "metadata": {}, + "outputs": [], + "source": [ + "latlon.shape, latlon_embeddings.shape" + ] + }, + { + "cell_type": "markdown", + "id": "3384f479-ef84-420d-a4e9-e3b038f05497", + "metadata": {}, + "source": [ + "> Latitude & Longitude map to 768 dimentional vector" + ] + }, + { + "cell_type": "markdown", + "id": "9e419fc9-e7d3-49de-a8ea-72912c365510", + "metadata": {}, + "source": [ + "## Preform PCA over the location embeddings to visualize them in 2 dimension" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7b7fc190-25c2-47b6-9e0a-cf7bc9d558d7", + "metadata": {}, + "outputs": [], + "source": [ + "pca = PCA(n_components=2)\n", + "latlon_embeddings = pca.fit_transform(latlon_embeddings)\n", + "latlon_embeddings.shape" + ] + }, + { + "cell_type": "markdown", + "id": "92c92119-282a-43ca-aad4-a01f2d7c278b", + "metadata": {}, + "source": [ + "## Create clusters of normalized latlon & latlon embeddings to check if there are any learned patterns in them after training" + ] + }, + { + "cell_type": "markdown", + "id": "6d5127bd-e8c1-4ddc-a135-4ddcafa974cf", + "metadata": {}, + "source": [ + "Latlon Cluster" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "330ae1a0-f94f-4d97-8a3f-eb507773585b", + "metadata": {}, + "outputs": [], + "source": [ + "kmeans = KMeans(n_clusters=5)\n", + "kmeans.fit_transform(latlon)\n", + "latlon = np.column_stack((latlon, kmeans.labels_))" + ] + }, + { + "cell_type": "markdown", + "id": "d753e7e2-3550-431d-bc65-f8e2da857be1", + "metadata": {}, + "source": [ + "Latlon Embeddings Cluster" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f95256c1-65f3-4653-80e9-9c7ad31eb475", + "metadata": {}, + "outputs": [], + "source": [ + "kmeans = KMeans(n_clusters=5)\n", + "kmeans.fit_transform(latlon_embeddings)\n", + "latlon_embeddings = np.column_stack((latlon_embeddings, kmeans.labels_))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "670c84f6-8041-4643-8c80-e039f2aa4400", + "metadata": {}, + "outputs": [], + "source": [ + "latlon.shape, latlon_embeddings.shape" + ] + }, + { + "cell_type": "markdown", + "id": "94e3739f-38c6-40cf-9457-904cd6c56324", + "metadata": {}, + "source": [ + "> We are a third dimension to latlon & latlon embeddings with cluster labels" + ] + }, + { + "cell_type": "markdown", + "id": "a259f9d0-05dc-4dab-bf49-c18a510d3d3a", + "metadata": {}, + "source": [ + "## Plot latlon clusters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "142a72ee-f88a-4cf6-bdd5-3e24fa7bd3aa", + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(figsize=(15, 15), dpi=80)\n", + "plt.scatter(latlon[:, 0], latlon[:, 1], c=latlon[:, 2], label=\"Actual\", alpha=0.3)\n", + "\n", + "for i in range(100):\n", + " txt = f\"{latlon[:,0][i]:.2f},{latlon[:, 1][i]:.2f}\"\n", + " plt.annotate(txt, (latlon[:, 0][i] + 1e-5, latlon[:, 1][i] + 1e-5))" + ] + }, + { + "cell_type": "markdown", + "id": "91e1a739-9536-4401-98ed-6285ff51b09a", + "metadata": {}, + "source": [ + "> As we see in the scatter plot above, there is nothing unique about latlon that go into the model, they are cluster based on their change in longitude values above" + ] + }, + { + "cell_type": "markdown", + "id": "df76148a-b05f-4f52-8b55-1ea5ecac9f84", + "metadata": {}, + "source": [ + "## Plot latlon embeddings cluster" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c5d51daa-1541-49a4-9b1d-285189136d34", + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(figsize=(15, 15), dpi=80)\n", + "plt.scatter(\n", + " latlon_embeddings[:, 0],\n", + " latlon_embeddings[:, 1],\n", + " c=latlon_embeddings[:, 2],\n", + " label=\"Predicted\",\n", + " alpha=0.3,\n", + ")\n", + "for i in range(100):\n", + " txt = i\n", + " plt.annotate(txt, (latlon_embeddings[:, 0][i], latlon_embeddings[:, 1][i]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0ebbad29-4abf-4a4b-9c94-455edc8dac40", + "metadata": {}, + "outputs": [], + "source": [ + "def show_cluster(ids):\n", + " fig, axes = plt.subplots(1, len(ids), figsize=(10, 5))\n", + " for i, ax in zip(ids, axes.flatten()):\n", + " img_path = batch[\"source_url\"][i]\n", + " img = rio.open(img_path).read([3, 2, 1]).transpose(1, 2, 0)\n", + " img = (img - img.min()) / (img.max() - img.min())\n", + " ax.imshow(img)\n", + " ax.set_axis_off()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "797f9b1c-17db-45c3-abcd-e7b16886f591", + "metadata": {}, + "outputs": [], + "source": [ + "show_cluster((87, 37, 40))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5ae6b021-24e9-408b-a558-f322f4e7bc2d", + "metadata": {}, + "outputs": [], + "source": [ + "show_cluster((23, 11, 41))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "224d210b-642c-4ad4-acc0-f1b009a7f0a5", + "metadata": {}, + "outputs": [], + "source": [ + "show_cluster((68, 71, 7))" + ] + }, + { + "cell_type": "markdown", + "id": "7c8d1fde-fefc-4e4c-8518-f803411baa01", + "metadata": {}, + "source": [ + "> We can see location embedding capturing semantic information as well" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4910938d-611d-47a8-abc3-e94e3e152dd1", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/clay-v0-reconstruction.ipynb b/docs/clay-v0-reconstruction.ipynb new file mode 100644 index 00000000..dbbdfd65 --- /dev/null +++ b/docs/clay-v0-reconstruction.ipynb @@ -0,0 +1,446 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "76ed0078-447f-4374-b6ba-a8b4a366188d", + "metadata": {}, + "source": [ + "# CLAY v0 - Quality of reconstruction by the model " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ea0176a6-97a1-4af6-af75-b9e52e52fbaf", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "\n", + "sys.path.append(\"../\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5ea314d0-176a-4ee3-b738-6152d27275d9", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "import einops\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "import torch\n", + "from einops import rearrange\n", + "\n", + "from src.datamodule import ClayDataModule, ClayDataset\n", + "from src.model_clay import CLAYModule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "37f4a735-18e6-48d7-9b58-e8d188e96b54", + "metadata": {}, + "outputs": [], + "source": [ + "# data directory for all chips\n", + "DATA_DIR = \"../data/02\"\n", + "# path of best model checkpoint for Clay v0\n", + "CKPT_PATH = \"../checkpoints/v0/mae_epoch-24_val-loss-0.46.ckpt\"" + ] + }, + { + "cell_type": "markdown", + "id": "4c300730-b0b0-4c3d-8a0d-d5e3ac018641", + "metadata": {}, + "source": [ + "## Load Model & DataModule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4c5f2abf-5e9c-4def-88d9-38136307b420", + "metadata": {}, + "outputs": [], + "source": [ + "# Load the model & set in eval mode\n", + "model = CLAYModule.load_from_checkpoint(CKPT_PATH, mask_ratio=0.7)\n", + "model.eval();" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "348c0573-7670-47a6-9e13-c6de36493b58", + "metadata": {}, + "outputs": [], + "source": [ + "data_dir = Path(DATA_DIR)\n", + "\n", + "# Load the Clay DataModule\n", + "ds = ClayDataset(chips_path=list(data_dir.glob(\"**/*.tif\")))\n", + "dm = ClayDataModule(data_dir=str(data_dir), batch_size=8)\n", + "dm.setup(stage=\"fit\")\n", + "\n", + "# Load the train DataLoader\n", + "trn_dl = iter(dm.train_dataloader())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "af3b38f7-f193-4b7d-aada-b4c8abcce7ed", + "metadata": {}, + "outputs": [], + "source": [ + "# Load the first batch of chips\n", + "batch = next(trn_dl)\n", + "batch.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "59f9a028-a789-40c9-8f23-eb2a0e1c66eb", + "metadata": {}, + "outputs": [], + "source": [ + "batch[\"pixels\"].shape, batch[\"latlon\"].shape, batch[\"timestep\"].shape" + ] + }, + { + "cell_type": "markdown", + "id": "4d52f962-d20d-4966-af4d-fd5509520356", + "metadata": {}, + "source": [ + "Each batch has chips of shape `13 x 512 x 512`, normalized `lat` & `lon` coords & normalized timestep information as `year`, `month` & `day`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f5e6dfc4-57fe-4296-9a57-9bd384ec61af", + "metadata": {}, + "outputs": [], + "source": [ + "# Save a copy of batch to visualize later\n", + "_batch = batch[\"pixels\"].detach().clone().cpu().numpy()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "e7c46bb7-3e25-454d-b345-1ca3ec1efb69", + "metadata": {}, + "source": [ + "## Pass data through the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3bed50b3-a481-44ba-b04f-440ba5fb57ed", + "metadata": {}, + "outputs": [], + "source": [ + "# Pass the pixels through the encoder & decoder of CLAY\n", + "with torch.no_grad():\n", + " # Move data from to the device of model\n", + " batch[\"pixels\"] = batch[\"pixels\"].to(model.device)\n", + " batch[\"timestep\"] = batch[\"timestep\"].to(model.device)\n", + " batch[\"latlon\"] = batch[\"latlon\"].to(model.device)\n", + "\n", + " # Pass pixels, latlon, timestep through the encoder to create encoded patches\n", + " (\n", + " unmasked_patches,\n", + " unmasked_indices,\n", + " masked_indices,\n", + " masked_matrix,\n", + " ) = model.model.encoder(batch)\n", + "\n", + " # Pass the unmasked_patches through the decoder to reconstruct the pixel space\n", + " pixels = model.model.decoder(unmasked_patches, unmasked_indices, masked_indices)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cb073da6-149b-4dab-b9f8-8ae5c5dd5178", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " unmasked_patches.shape,\n", + " unmasked_indices.shape,\n", + " masked_indices.shape,\n", + " masked_matrix.shape,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7b1a36db-61a8-4cfe-8970-9ce2142fb2bd", + "metadata": {}, + "outputs": [], + "source": [ + "# Reconstructed chips from 70% masked inputs to the model\n", + "pixels.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c4ed496a-1a38-4284-befc-0e74065595e0", + "metadata": {}, + "outputs": [], + "source": [ + "# Rearrange the pixels into chips of size `13 x 512 x 512`\n", + "pixels = rearrange(pixels, \"b c (h w) (p1 p2) -> b c (h p1) (w p2)\", h=16, p1=32)\n", + "pixels.shape" + ] + }, + { + "cell_type": "markdown", + "id": "cbeae41b-ffa4-4ca1-966d-529cfdc4c725", + "metadata": {}, + "source": [ + "## Plot the pixel reconstructions from the CLAY model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4f46be28-514b-4089-a592-d851ccd2da31", + "metadata": {}, + "outputs": [], + "source": [ + "def plot_pixel_reconstruction():\n", + " fig, axes = plt.subplots(16, 13, figsize=(20, 20))\n", + "\n", + " for j_ in range(8):\n", + " j = j_\n", + " inp = _batch[j]\n", + " out = pixels[j].detach().cpu().numpy()\n", + " j *= 2\n", + " for i in range(13):\n", + " axes[j, i].imshow(inp[i], cmap=\"viridis\")\n", + " axes[(j + 1), i].imshow(out[i], cmap=\"viridis\")\n", + " axes[j, i].set_axis_off()\n", + " axes[(j + 1), i].set_axis_off()\n", + "\n", + " # Set column labels\n", + " cols = [\n", + " \"Blue\",\n", + " \"Green\",\n", + " \"Red\",\n", + " \"RedEdge\",\n", + " \"RedEdge\",\n", + " \"RedEdge\",\n", + " \"NIR\",\n", + " \"RedEdge\",\n", + " \"SWIR\",\n", + " \"SWIR\",\n", + " \"VV\",\n", + " \"VH\",\n", + " \"DEM\",\n", + " ]\n", + " for ax, col in zip(axes[0], cols):\n", + " ax.set_title(col)\n", + "\n", + " plt.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "002fe404-51de-4495-8ab6-268ec94e0eeb", + "metadata": {}, + "outputs": [], + "source": [ + "plot_pixel_reconstruction()" + ] + }, + { + "cell_type": "markdown", + "id": "b390e40f-4880-4c25-8b76-992d6e049e26", + "metadata": {}, + "source": [ + "> In the figure above, each chip in the batch of eight is represented by two rows: the first row shows the actual image and the second row displays its prediction." + ] + }, + { + "cell_type": "markdown", + "id": "6a337c98-60d0-47d2-90d3-4d419cb34e7b", + "metadata": {}, + "source": [ + "## Verify quality of reconstruction" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29ca1283-71cd-4a31-9136-5af1d8935d8f", + "metadata": {}, + "outputs": [], + "source": [ + "order = 3 # pick a chip from the batch of size 8\n", + "band = 1 # pick a band from 13\n", + "\n", + "# represents the group each band falls under,\n", + "# for bands 0-2: mask_order=0,\n", + "# 3,4,5,7: mask_order=1,\n", + "# 6: mask_order=2,\n", + "# 8,9: mask_order=3,\n", + "# 10,11: masked_order=4,\n", + "# 12: mask_order=5\n", + "mask_order = 0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "95c35b54-1c95-4b97-a02d-ae716f9b8b63", + "metadata": {}, + "outputs": [], + "source": [ + "# Select one chip from the batch of inputs & reconstructed pixels\n", + "chip = batch[\"pixels\"][order].detach().cpu().numpy()\n", + "pred_chip = pixels[order].detach().cpu().numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d599a04b-25a3-43ae-8c08-c0763ecae3f2", + "metadata": {}, + "outputs": [], + "source": [ + "# Masked matrix stores the position information of masked & unmasked patches of input\n", + "mask = masked_matrix[order]\n", + "mask = rearrange(mask[mask_order], \"(h w) -> h w\", h=16).detach().cpu().numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e7322b93-ccf3-4f12-b6f5-9cfa4a900d80", + "metadata": {}, + "outputs": [], + "source": [ + "pd.DataFrame(mask).style.format(\"{:.1f}\").background_gradient(cmap=\"bwr\")" + ] + }, + { + "cell_type": "markdown", + "id": "4dde9a5c-974e-4132-9822-215f35c883e9", + "metadata": {}, + "source": [ + "> `1` represents masked patch position & `0` represents unmasked patch position " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c556442-41eb-4557-a924-d3b22092f68a", + "metadata": {}, + "outputs": [], + "source": [ + "# Scale the mask matrix to size `512 x 512`\n", + "upmask = einops.repeat(\n", + " mask, \"h w -> (h repeat_h) (w repeat_w)\", repeat_h=32, repeat_w=32\n", + ")\n", + "plt.imshow(upmask, cmap=\"bwr\")" + ] + }, + { + "cell_type": "markdown", + "id": "f5d6e897-bd10-49cd-9550-953749e06429", + "metadata": {}, + "source": [ + "> `Red`: Masked patches & `Blue`: Unmasked Patches" + ] + }, + { + "cell_type": "markdown", + "id": "5af98635-93ea-417f-af07-f276457e6200", + "metadata": {}, + "source": [ + "### Plot the quality of reconstruction" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c66858e6-9bbe-4974-b042-da60ee516ea5", + "metadata": {}, + "outputs": [], + "source": [ + "# Input to the CLAY model\n", + "masked_chip = chip[band] * (1 - upmask)\n", + "\n", + "# Reconstruction from the CLAY model\n", + "recreated_chip = pred_chip[band] * upmask\n", + "recreated_chip_with_unmasked_patches = masked_chip + recreated_chip" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "224d210b-642c-4ad4-acc0-f1b009a7f0a5", + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 4, figsize=(20, 5))\n", + "\n", + "for ax in axes:\n", + " ax.set_axis_off()\n", + "\n", + "axes[0].imshow(masked_chip)\n", + "axes[0].set_title(\"Masked Input\")\n", + "\n", + "axes[1].imshow(recreated_chip)\n", + "axes[1].set_title(\"Reconstruction\")\n", + "\n", + "axes[2].imshow(recreated_chip_with_unmasked_patches)\n", + "axes[2].set_title(\"Reconstruction + Unmasked Patches\")\n", + "\n", + "axes[3].imshow(chip[band])\n", + "axes[3].set_title(\"Original Input\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e458cd3-d1a8-41b2-9e85-7612c153a7ea", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 00000000..445dccd7 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,36 @@ +# Clay Foundation Model + +## An open source AI model for Earth + +Clay is a foundational model of Earth. It uses a Vision Transformer architecture adapted +to understand geospatial and temporal relations on Earth Observation data. The model is +trained via Self-supervised learning (SSL) using a Masked Autoencoder (MAE) method. + +The Clay model can be used in three main ways: +- Generate semantic embeddings for any location and time. +- Fine-tune the model for downstream tasks such as classification, regression, and generative tasks. +- Use the model as a backbone for other models. + + +## Where is what + +- Our **website** is [madewithclay.org](https://madewithclay.org). +- The Clay model **code** lives on [Github](https://github.com/Clay-foundation/model). + License: [Apache-w.0](https://github.com/Clay-foundation/model/blob/main/LICENSE). + The latest release is [v0.0.1](https://github.com/Clay-foundation/model/releases/tag/v0.0.1) +- The Clay model **weights** on [Hugging Face](https://huggingface.co/made-with-clay/Clay/). + License: [OpenRAIL-M](https://github.com/Clay-foundation/model/blob/main/LICENSE-MODEL.md). +- The Clay **documentation** [lives on this site](https://clay-foundation.github.io/model/index.html). + License: [CC-BY](http://creativecommons.org/licenses/by/4.0/). +- *Coming Soon* > We maintain a set of **embeddings** on [Source Cooperative](https://beta.source.coop/clay/). + License: [ODC-BY](https://opendatacommons.org/licenses/by/). + +CLAY is a fiscal sponsored project of the 501c3 non-profit +[Radiant Earth Foundation](https://www.radiant.earth). + + +--- +### Table of Contents + +```{tableofcontents} +``` diff --git a/docs/installation.md b/docs/installation.md index 7bca44d5..459235e1 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -1,5 +1,33 @@ # Installation -## Basic +## Cloud Environments + +Launch into a [JupyterLab](https://jupyterlab.readthedocs.io) environment on + +| [Binder](https://mybinder.readthedocs.io/en/latest) | [Planetary Computer](https://planetarycomputer.microsoft.com) | [SageMaker Studio Lab](https://studiolab.sagemaker.aws) | +|:--:|:--:|:--:| +| [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/Clay-foundation/model/main) | [![Open on Planetary Computer](https://img.shields.io/badge/Open-Planetary%20Computer-black?style=flat&logo=microsoft)](https://pccompute.westeurope.cloudapp.azure.com/compute/hub/user-redirect/git-pull?repo=https%3A%2F%2Fgithub.com%2FClay-foundation%2Fmodel&urlpath=lab%2Ftree%2Fmodel%2Fplaceholder.ipynb&branch=main) | [![Open in SageMaker Studio Lab](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/Clay-foundation/model/blob/main/placeholder.ipynb) | + + +## Local Environments + +Start by cloning this [repo-url](https://github.com/Clay-foundation/model) + + git clone https://github.com/Clay-foundation/model + +Then we recommend [using mamba](https://mamba.readthedocs.io/en/latest/installation.html) +to install the dependencies. + +A virtual environment will also be created with Python and +[JupyterLab](https://github.com/jupyterlab/jupyterlab) installed. + + cd model + mamba env create --file environment.yml + +Activate the virtual environment first. + + mamba activate claymodel ## Advanced + +See [Readme](https://github.com/Clay-foundation/model/blob/main/README.md) on model root for more details. diff --git a/docs/intro.md b/docs/intro.md deleted file mode 100644 index c3afc319..00000000 --- a/docs/intro.md +++ /dev/null @@ -1,13 +0,0 @@ -# Clay Foundation Model - -## An open source AI model and interface for Earth - -This is a small sample book to give you a feel for how book content is -structured. -It shows off a few of the major file types, as well as some sample content. -It does not go in-depth into any particular topic - check out [the Jupyter Book documentation](https://jupyterbook.org) for more information. - -Check out the content pages bundled with this sample book to see more. - -```{tableofcontents} -``` diff --git a/docs/specification.md b/docs/specification.md new file mode 100644 index 00000000..cdf2d55e --- /dev/null +++ b/docs/specification.md @@ -0,0 +1,173 @@ +(model_release)= +# Pretrained Model release v0.0.1 + +This changelog is a summary of the changes to the pretrained model weights for the Clay model. We follow the "Stanford [Foundation Model Transparency Index](https://github.com/stanford-crfm/fmti)" + +Model weights released on 2024/01/12. + +> For release notes for the source code, see [](software_release) + +### Summary + +Clay v0 is a self-supervised modified vision transfer model trained on stacks of Sentinel-2, Sentinel-1 & DEM data. It is trained as a Masked Autoencoder (MAE) to reconstruct the original image from a masked image. + +Each data entry is a stack of 10 bands of Sentinel-2, 2 bands of Sentinel-1 & 1 band of DEM data. The model is trained with 3 timesteps of data for each location, with a total of 1203 MGRS tiles globally distributed, each of size 10km x 10km. The data was collected from the Microsoft Planetary Computer. + +The model was trained on AWS on 4 NVIDIA A10G GPUs for 25 epochs (~14h per epoch) in December 2023. + +Model weights are available on HuggingFace [here](https://huggingface.co/made-with-clay/Clay/). + +We also generated embeddings for all trainning data, which can be found on Source Cooperative [here](https://source.coop/). + +## Model Architecture + +Clat is a Unet, with a modified ViT encoder down to embeddings, and a decoder to reconstruct the masked parts of the original image. The loss function is the MSE between the original image and the reconstructed image. + +For details, check the source code [here](https://github.com/Clay-foundation/model/blob/v0.0.1/src/model_clay.py). + +![Architecture](https://github.com/Clay-foundation/model/assets/23487320/c9b46255-c2d7-4ca4-a980-7ff3033c23e3) + +* Core Framework: [Lightning](https://lightning.ai/) and its dependencies, like PyTorch, etc. + +* Input modalities: + * Fixed spec of 10 bands of Sentinel-2, 2 bands of Sentinel-1 & 1 band of DEM data. See below for details. +* Output modalities: + * As a masked auto-enconder, fixed spec of 10 bands of Sentinel-2, 2 bands of Sentinel-1 & 1 band of DEM data, to mimic the input as close as possible. +* Model size: + * Number of parameters: `127M` + * Model size on disk: `~500MB`. +* Model license: + * Source code: [Apache 2.0](https://github.com/Clay-foundation/model/blob/v0.0.1/LICENSE) + * Model weights: [OpenRAIL-M](https://github.com/Clay-foundation/model/blob/v0.0.1/LICENSE-MODEL.md) + * Prohibited uses: See OpenRAIL-M license section 5. +* Feedback and redress mechanisms: + * Please open an issue or discussion on the [GitHub repository](https://github.com/Clay-foundation/model) or send an email to `bruno@madewithclay.org`. + +## Model Card + +For v0 of CLAY, we used the [`clay_small`](https://github.com/Clay-foundation/model/blob/0145e55bcf6bd3e9b19f5c07819a1398b6a22c35/src/model_clay.py#L713) setup model. + +``` +MASKED PATCHES = 75% +INPUT SIZE = 13 bands x 512 width x 512 height +PATCH SIZE = 32 x 32 + +OPTIMIZER + Adam + Learning rate = 1e-4 + Weight decay = 0.05 + Beta 1 = 0.9 + Beta 2 = 0.95 + +SCHEDULER + CosineAnnealingWarmRestarts + T_0 = 1000 + T_mult = 2 + eta_min = Learning rate * 10 + +ENCODER + dim = 768 + depth = 12 + heads = 12 + dim_head = 64 + mlp_ratio = 4 + dropout = 0.0 + emb_dropout = 0.0 + +DECODER + decoder_dim = 512 + decoder_depth = 8 + decoder_heads = 8 + decoder_dim_head = 64 + decoder_mlp_ratio = 4 + decoder_dropout = 0.0 +``` + +(Data_card)= +## Data Card + +We organize our input dataset creation in MGRS tiles. Each tile is a 10km x 10km area. We have `1203` tiles in total, each with 3 timesteps of data between `2017` and `2023`, so `3609 Tiles` in total. Each timestep is a stack of 10 bands of Sentinel-2, 2 bands of Sentinel-1 & 1 band of DEM data. Each tile is split into `512 x 512` chips, so we have around `~1.2 Million` chips in total. Each chip contains `13 bands`, 10 of which are the Sentinel-2 bands, 2 are Sentinel 1 bands & 1 DEM band. We store each chip as geotiff, along with their coordinate & timestamp information that is used for model training. + +![Tile locations](https://github.com/Clay-foundation/model/assets/23487320/af46a272-a102-4c66-a8bc-52bcb987c365) + +* Training dataset size: `6.4 TB` +* Training dataset source links: + * [Sentinel-2](https://planetarycomputer.microsoft.com/dataset/sentinel-2-l2a) + * [Sentinel-1](https://planetarycomputer.microsoft.com/dataset/sentinel-1-rtc) + * DEM from [Copernicus Digital Elevation Model](https://planetarycomputer.microsoft.com/dataset/cop-dem-glo-90) +* Training dataset items: + * The actual list of files used is available [here](https://gist.github.com/brunosan/62247e5dc79684bdaca11cefae679e90). +* Data source selection and curation process: + * We aim for fully open data, with global and historical coverage, with the highest spatial, temporal and spectral resolution, hosted on a cloud format that eases the process to search and download the needed sections. + * Once these sources are selected, we make a [statistical sample based on cover type](https://github.com/Clay-foundation/model/blob/0145e55bcf6bd3e9b19f5c07819a1398b6a22c35/scripts/landcover.py#L156), so that we have a good coverage of the different landscapes. The land cover data is from [ESA WorldCover 2021](https://registry.opendata.aws/esa-worldcover-vito/). +* Data augmentation: + * We do not use any data augmentation techniques like affine transformations, random crops (except the masked autoencoder task), etc. We also do not use input mixing like CutMix, MixUp, etc. + * Clouds, cloud shadows, smog, atmospheric scattering, mid-air planes and other non-ground registrations could be considered natural augmentations. We explicitly filter out large % of clouds on our chips, but small clouds and their shadows might be present. As we increase the number of observations per location, and bands, we expect the model to learn to ignore single events but register patterns (places that are often cloudy or with smog). +* PII or harmful content: + * We believe that satellites images at this resolution (`10m/px`) are not subject to PII or harmful content concerns. +* Human evaluation, wages, and annotation process: + * Besides tweaking the statistical samples as part of the model development team, and the stated dataset hosting partners, we do not use any human evaluation, or annotation process, or third party services. + +We store each chip as geotiff, along with their coordinate & timestamp information that is used for model training. + +![bands](https://github.com/Clay-foundation/model/assets/23487320/85fbc8d2-28f6-4021-855b-c1eb84dd09e3) + + +## Training Card + +* Compute Resources: + * AWS EC2 `g5.12xlarge` with 4 NVIDIA A10G GPUs +* Batch Size: + * Batch Size = `10` + * Effective Batch Size = Batch Size x Number of GPUs x Gradient Accumulation Steps = `10` x `4` x `5` = `200` +* Training Time: + * `25` epochs, each taking ~`15h` to train. +* Carbon Emissions: + * *Report not yet available from provider, expected March'24* +* Training stages: + * While developing the model we run small tests locally and on the cloud. We estimate that all testing and development compute is less than the compute used for 1 epoch of training. + * QA of the model is also done locally and on the cloud, and we estimate that it is less than the compute used for 1 epoch of training. +* Release and distribution: + * Model development happens in an open source repository on GitHub [here](https://github.com/Clay-foundation/model/). + * We release the model weights on HuggingFace [here](https://huggingface.co/made-with-clay/Clay/). + * We release the embeddings on Source Cooperative [here](https://beta.source.coop/clay/). + * We do not have other distribution channels at this time. +* Production use: + * We support our partners to build applications with the model, and we expect them to use the model in production. + * We are developing a web application and expect to release it in 2024 Q1. + + +![Learning Rate & Epoch](https://github.com/Clay-foundation/model/assets/23487320/d2a2944c-0b2c-4c19-893b-abe3fca10edc) + +![MSE Loss for Pixel Reconstruction](https://github.com/Clay-foundation/model/assets/23487320/cbbed1d1-ca7b-4352-8a2a-610b33f42d1c) + +## Results + +As a foundational model, it is designed to be used as a building block for other models. In this section we only a sample of the training objective, which is to reconstruct the original image from a 75% masked image. + +[Reconstruction](https://github.com/Clay-foundation/model/assets/23487320/491febc1-af3c-43ab-bd9a-85ef7fbf6064) + + +### Performance Metrics +The model shows the following performance characteristics for its Masked Autoencoder objective: +* Training loss: `0.52` +* Validation loss: `0.46` + +## Known Limitations and Biases + +- The model is trained on Sentinel data only. +- Sentinel data only covers land and coastal waters. +- We only train on a ver small sample of the Sentinel archives, both in terms of spatial coverage and time. +- We do not train on the poles, and we do not train on open ocean, nor ocean nor atmospheric volumetric data. +- We do not train on night time data. +- We do not explicitly include extreme events in the training data. +- We only train at most 3 different times per location. + + +## Ethical Considerations + +Our goal is to lower the barrier to use EO data for biodiversity and climate change mitigation and adaptation. We have designed our model to support this goal. + +We have also designed our model to be as open as possible, as modular as possible, as undifferentiated and general as possible, and as well documented as possible, so we can maximize the leverage of the resources needed for the creation of this model. + +As a fully open model, we cannot however control how it is used. We are aware that EO data can be used for harmful purposes, and we are committed to work with our partners to prevent this from happening. diff --git a/src/model_clay.py b/src/model_clay.py index b3a1dec9..e89fef36 100644 --- a/src/model_clay.py +++ b/src/model_clay.py @@ -56,6 +56,7 @@ def __init__( # noqa: PLR0913 mask_ratio, image_size, patch_size, + shuffle, dim, depth, heads, @@ -73,6 +74,7 @@ def __init__( # noqa: PLR0913 self.mask_ratio = mask_ratio self.image_size = image_size self.patch_size = patch_size + self.shuffle = shuffle self.dim = dim self.bands = bands self.band_groups = band_groups @@ -244,7 +246,13 @@ def mask_out(self, patches): GL == self.num_patches ), f"Expected {self.num_patches} patches, got {GL} patches." - noise = torch.randn((B, GL), device=patches.device) # [B GL] + if self.shuffle: # Shuffle the patches + noise = torch.randn((B, GL), device=patches.device) # [B GL] + else: # Don't shuffle useful for interpolation & inspection of embeddings + noise = rearrange( + torch.arange(B * GL, device=patches.device), "(B GL) -> B GL", B=B + ) + random_indices = torch.argsort(noise, dim=-1) # [B GL] reverse_indices = torch.argsort(random_indices, dim=-1) # [B GL] @@ -581,12 +589,14 @@ def __init__( # noqa: PLR0913 "sar": (10, 11), "dem": (12,), }, + shuffle=True, **kwargs, ): super().__init__() self.mask_ratio = mask_ratio self.image_size = image_size self.patch_size = patch_size + self.shuffle = shuffle self.bands = bands self.band_groups = band_groups @@ -594,6 +604,7 @@ def __init__( # noqa: PLR0913 mask_ratio=mask_ratio, image_size=image_size, patch_size=patch_size, + shuffle=shuffle, dim=dim, depth=depth, heads=heads,