diff --git a/README.md b/README.md index f0e31382..1e2d52d1 100644 --- a/README.md +++ b/README.md @@ -6,8 +6,6 @@ An open source AI model and interface for Earth. -# Getting started - ## Quickstart Launch into a [JupyterLab](https://jupyterlab.readthedocs.io) environment on @@ -74,3 +72,26 @@ To train the model: More options can be found using `python trainer.py fit --help`, or at the [LightningCLI docs](https://lightning.ai/docs/pytorch/2.1.0/cli/lightning_cli.html). + +## Contributing + +### Writing documentation + +Our documentation uses [Jupyter Book](https://jupyterbook.org/intro.html). + +Install it with: +```bash +pip install -U jupyter-book +``` + +Then build it with: +```bash +jupyter-book build docs/ +``` + +You can preview the site locally with: +```bash +python -m http.server --directory docs/_build/html +``` + +There is a GitHub Action at `.github/workflows/deploy-docs.yml` that builds the site and pushes it to GitHub Pages. diff --git a/docs/README.md b/docs/README.md deleted file mode 100644 index f5bc0086..00000000 --- a/docs/README.md +++ /dev/null @@ -1,20 +0,0 @@ -# Clay Model Documentation - -This Documentation uses [Jupyter Book](https://jupyterbook.org/intro.html). - -Install it with: -```bash -pip install -U jupyter-book -``` - -Then build it with: -```bash -jupyter-book build docs/ -``` - -You can preview the site locally with: -```bash -python -m http.server --directory _build/html -``` - -There is a GitHub Action on `./github/workflows/deploy-docs.yml` that builds the site and pushes it to GitHub Pages. 
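Review note on the documentation workflow moved into `README.md` above: the preview step can also be scripted. The following is only a minimal sketch using the Python standard library, assuming the book was built from the repository root with `jupyter-book build docs/`, which writes its HTML under `docs/_build/html`. The helper names `make_preview_server` and `serve_in_background` are hypothetical and not part of this repository.

```python
import functools
import http.server
import threading
from pathlib import Path


def make_preview_server(html_dir, port=0):
    """Create (but do not start) an HTTP server rooted at html_dir.

    port=0 asks the operating system for any free port; the port that was
    actually bound is available as server.server_address[1].
    """
    root = Path(html_dir)
    if not root.is_dir():
        # Fail early instead of serving 404s from a missing build folder
        raise FileNotFoundError(f"{root} not found; build the book first")
    handler = functools.partial(
        http.server.SimpleHTTPRequestHandler, directory=str(root)
    )
    return http.server.ThreadingHTTPServer(("127.0.0.1", port), handler)


def serve_in_background(server):
    """Run the server in a daemon thread so the caller is not blocked."""
    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()
    return thread
```

A possible use is `server = make_preview_server("docs/_build/html", port=8000)` followed by `server.serve_forever()`. Raising on a missing directory is a deliberate choice here: a wrong path otherwise starts a server that answers every request with a 404.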
diff --git a/docs/_toc.yml b/docs/_toc.yml index 276d32b4..3b3b726b 100644 --- a/docs/_toc.yml +++ b/docs/_toc.yml @@ -7,25 +7,25 @@ parts: - caption: Release notes chapters: - title: Software release notes - file: changelog + file: release-notes/changelog - title: Model release notes - file: specification + file: release-notes/specification - title: Data sampling strategy - file: data_sampling + file: release-notes/data_sampling - caption: Getting Started chapters: - title: Installation - file: installation + file: getting-started/installation - title: Basic Use - file: basic_use + file: getting-started/basic_use - caption: Tutorials chapters: - title: Clay v1 wall-to-wall example - file: clay-v1-wall-to-wall + file: tutorials/clay-v1-wall-to-wall - title: Explore embeddings from Clay Encoder - file: visualize-embeddings + file: tutorials/visualize-embeddings - title: Clay MAE reconstruction - file: reconstruction + file: tutorials/reconstruction - caption: About Clay chapters: - title: GitHub diff --git a/docs/clay-v1-wall-to-wall.ipynb b/docs/clay-v1-wall-to-wall.ipynb deleted file mode 100644 index 4dcac9b3..00000000 --- a/docs/clay-v1-wall-to-wall.ipynb +++ /dev/null @@ -1,1648 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "0cc5e729-9116-4ec9-bf1e-8346cbccdf7b", - "metadata": {}, - "source": [ - "## Run Clay v1\n", - "\n", - "This notebook shows how to run Clay v1 wall-to-wall, from downloading imagery\n", - "to training a tiny fine tuning head. This will include the following steps:\n", - "\n", - "1. Set a location and date range of interest\n", - "2. Download Sentinel-2 imagery for this specification\n", - "3. Load the model checkpoint\n", - "4. Prepare data into a format for the model\n", - "5. Run the model on the imagery\n", - "6. Analyise the model embeddings output using PCA\n", - "7. 
Train a Support Vector Machines fine tuning head" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "add63cd9", - "metadata": {}, - "outputs": [], - "source": [ - "# Add the repo root to the sys path for the model import below\n", - "import sys\n", - "\n", - "sys.path.append(\"..\")" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "6a17b8a8-a9c6-4053-833e-de97287fae49", - "metadata": {}, - "outputs": [], - "source": [ - "import math\n", - "\n", - "import geopandas as gpd\n", - "import numpy as np\n", - "import pandas as pd\n", - "import pystac_client\n", - "import stackstac\n", - "import torch\n", - "import yaml\n", - "from box import Box\n", - "from matplotlib import pyplot as plt\n", - "from rasterio.enums import Resampling\n", - "from shapely import Point\n", - "from sklearn import decomposition, svm\n", - "from stacchip.processors.prechip import normalize_timestamp\n", - "from torchvision.transforms import v2\n", - "\n", - "from src.model import ClayMAEModule" - ] - }, - { - "cell_type": "markdown", - "id": "beac6394-9762-422b-9f5d-82d226018c0c", - "metadata": {}, - "source": [ - "### Specify location and date of interest\n", - "In this example we will use a location in Portugal where a forest fire happened. We will run the model over the time period of the fire and analyse the model embeddings." - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "08d7787d-1506-4de7-89dc-c1054910acf7", - "metadata": {}, - "outputs": [], - "source": [ - "# Point over Monchique Portugal\n", - "lat, lon = 37.30939, -8.57207\n", - "\n", - "# Dates of a large forest fire\n", - "start = \"2018-07-01\"\n", - "end = \"2018-09-01\"" - ] - }, - { - "cell_type": "markdown", - "id": "2bd226c9-003b-4867-a64a-8ae887e7e20a", - "metadata": {}, - "source": [ - "### Get data from STAC catalog\n", - "\n", - "Based on the location and date we can obtain a stack of imagery using stackstac. 
Let's start with finding the STAC items we want to analyse." - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "2e80743c-7c77-459b-9984-f6c26cdff549", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/tam/apps/miniforge3/envs/claymodel/lib/python3.11/site-packages/pystac_client/item_search.py:850: FutureWarning: get_all_items() is deprecated, use item_collection() instead.\n", - " warnings.warn(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Found 12 items\n" - ] - } - ], - "source": [ - "STAC_API = \"https://earth-search.aws.element84.com/v1\"\n", - "COLLECTION = \"sentinel-2-l2a\"\n", - "\n", - "# Search the catalogue\n", - "catalog = pystac_client.Client.open(STAC_API)\n", - "search = catalog.search(\n", - " collections=[COLLECTION],\n", - " datetime=f\"{start}/{end}\",\n", - " bbox=(lon - 1e-5, lat - 1e-5, lon + 1e-5, lat + 1e-5),\n", - " max_items=100,\n", - " query={\"eo:cloud_cover\": {\"lt\": 80}},\n", - ")\n", - "\n", - "all_items = search.get_all_items()\n", - "\n", - "# Reduce to one per date (there might be some duplicates\n", - "# based on the location)\n", - "items = []\n", - "dates = []\n", - "for item in all_items:\n", - " if item.datetime.date() not in dates:\n", - " items.append(item)\n", - " dates.append(item.datetime.date())\n", - "\n", - "print(f\"Found {len(items)} items\")" - ] - }, - { - "cell_type": "markdown", - "id": "5b7c68ae-7c8a-446a-8bc7-5afba70183c2", - "metadata": {}, - "source": [ - "### Create a bounding box around the point of interest\n", - "\n", - "This is needed in the projection of the data so that we can generate image chips of the right size." 
- ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "0f3573b5-5a00-47d9-a648-5c4d7cd2c996", - "metadata": {}, - "outputs": [], - "source": [ - "# Extract coordinate system from first item\n", - "epsg = items[0].properties[\"proj:epsg\"]\n", - "\n", - "# Convert point of interest into the image projection\n", - "# (assumes all images are in the same projection)\n", - "poidf = gpd.GeoDataFrame(\n", - " pd.DataFrame(),\n", - " crs=\"EPSG:4326\",\n", - " geometry=[Point(lon, lat)],\n", - ").to_crs(epsg)\n", - "\n", - "coords = poidf.iloc[0].geometry.coords[0]\n", - "\n", - "# Create bounds in projection\n", - "size = 256\n", - "gsd = 10\n", - "bounds = (\n", - " coords[0] - (size * gsd) // 2,\n", - " coords[1] - (size * gsd) // 2,\n", - " coords[0] + (size * gsd) // 2,\n", - " coords[1] + (size * gsd) // 2,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "bbbd3f67-5f2c-46dc-9ee1-2ef1f50fa032", - "metadata": {}, - "source": [ - "### Retrieve the imagery data." - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "8b8d3824-e48c-4f9d-9c7b-181c0800f96f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Working with stack of size (12, 4, 256, 256)\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
<xarray.DataArray 'stackstac-7cbad7c129d678be53c9b6676bee564b' (time: 12,\n",
-       "                                                                band: 4,\n",
-       "                                                                y: 256, x: 256)> Size: 13MB\n",
-       "array([[[[ 9136.,  9232.,  9544., ...,  1258.,  1120.,   930.],\n",
-       "         [ 9616.,  9768.,  9840., ...,  1230.,  1208.,  1030.],\n",
-       "         [ 9992., 10008., 10000., ...,  1418.,  1336.,  1242.],\n",
-       "         ...,\n",
-       "         [  811.,   655.,   688., ...,   385.,   362.,   461.],\n",
-       "         [  798.,   675.,   727., ...,   394.,   415.,   402.],\n",
-       "         [  888.,   673.,   642., ...,   403.,   454.,   393.]],\n",
-       "\n",
-       "        [[ 8656.,  8656.,  8864., ...,  1500.,  1428.,  1220.],\n",
-       "         [ 9016.,  9160.,  9224., ...,  1546.,  1522.,  1360.],\n",
-       "         [ 9248.,  9328.,  9384., ...,  1620.,  1542.,  1482.],\n",
-       "         ...,\n",
-       "         [ 1010.,   831.,   853., ...,   277.,   276.,   336.],\n",
-       "         [ 1016.,   930.,   927., ...,   276.,   317.,   293.],\n",
-       "         [ 1112.,   885.,   827., ...,   299.,   369.,   293.]],\n",
-       "\n",
-       "        [[ 8416.,  8416.,  8640., ...,  1598.,  1466.,  1138.],\n",
-       "         [ 8744.,  8880.,  8928., ...,  1498.,  1522.,  1284.],\n",
-       "         [ 8952.,  8944.,  8960., ...,  1542.,  1478.,  1448.],\n",
-       "         ...,\n",
-       "...\n",
-       "         [  652.,   640.,   638., ...,   590.,   821.,  1008.],\n",
-       "         [  622.,   676.,   630., ...,   606.,  1092.,   726.],\n",
-       "         [  864.,   786.,   569., ...,   766.,  1068.,   630.]],\n",
-       "\n",
-       "        [[  201.,   213.,   195., ...,  1138.,  1058.,   749.],\n",
-       "         [  196.,   198.,   169., ...,   861.,   784.,   768.],\n",
-       "         [  216.,   178.,   191., ...,   870.,   806.,   820.],\n",
-       "         ...,\n",
-       "         [  857.,   838.,   846., ...,   622.,   800.,  1332.],\n",
-       "         [  922.,   848.,   771., ...,   786.,  1046.,   912.],\n",
-       "         [ 1118.,  1010.,   735., ...,   755.,   977.,   686.]],\n",
-       "\n",
-       "        [[ 3264.,  3352.,  3304., ...,  3160.,  3296.,  3376.],\n",
-       "         [ 3356.,  3300.,  3212., ...,  3188.,  3272.,  3064.],\n",
-       "         [ 3288.,  3372.,  3344., ...,  3136.,  3200.,  2932.],\n",
-       "         ...,\n",
-       "         [ 1320.,  1468.,  1298., ...,  2492.,  2556.,  3018.],\n",
-       "         [ 1630.,  1694.,  1250., ...,  2318.,  2684.,  2894.],\n",
-       "         [ 2190.,  2072.,  1288., ...,  2544.,  2942.,  2928.]]]],\n",
-       "      dtype=float32)\n",
-       "Coordinates: (12/53)\n",
-       "  * time                                     (time) datetime64[ns] 96B 2018-0...\n",
-       "    id                                       (time) <U24 1kB 'S2B_29SNB_20180...\n",
-       "  * band                                     (band) <U5 80B 'blue' ... 'nir'\n",
-       "  * x                                        (x) float64 2kB 5.366e+05 ... 5....\n",
-       "  * y                                        (y) float64 2kB 4.131e+06 ... 4....\n",
-       "    platform                                 (time) <U11 528B 'sentinel-2b' ....\n",
-       "    ...                                       ...\n",
-       "    gsd                                      int64 8B 10\n",
-       "    title                                    (band) <U20 320B 'Blue (band 2) ...\n",
-       "    common_name                              (band) <U5 80B 'blue' ... 'nir'\n",
-       "    center_wavelength                        (band) float64 32B 0.49 ... 0.842\n",
-       "    full_width_half_max                      (band) float64 32B 0.098 ... 0.145\n",
-       "    epsg                                     int64 8B 32629\n",
-       "Attributes:\n",
-       "    spec:        RasterSpec(epsg=32629, bounds=(536640.79691545, 4128000.7407...\n",
-       "    crs:         epsg:32629\n",
-       "    transform:   | 10.00, 0.00, 536640.80|\\n| 0.00,-10.00, 4130560.74|\\n| 0.0...\n",
-       "    resolution:  10
" - ], - "text/plain": [ - " Size: 13MB\n", - "array([[[[ 9136., 9232., 9544., ..., 1258., 1120., 930.],\n", - " [ 9616., 9768., 9840., ..., 1230., 1208., 1030.],\n", - " [ 9992., 10008., 10000., ..., 1418., 1336., 1242.],\n", - " ...,\n", - " [ 811., 655., 688., ..., 385., 362., 461.],\n", - " [ 798., 675., 727., ..., 394., 415., 402.],\n", - " [ 888., 673., 642., ..., 403., 454., 393.]],\n", - "\n", - " [[ 8656., 8656., 8864., ..., 1500., 1428., 1220.],\n", - " [ 9016., 9160., 9224., ..., 1546., 1522., 1360.],\n", - " [ 9248., 9328., 9384., ..., 1620., 1542., 1482.],\n", - " ...,\n", - " [ 1010., 831., 853., ..., 277., 276., 336.],\n", - " [ 1016., 930., 927., ..., 276., 317., 293.],\n", - " [ 1112., 885., 827., ..., 299., 369., 293.]],\n", - "\n", - " [[ 8416., 8416., 8640., ..., 1598., 1466., 1138.],\n", - " [ 8744., 8880., 8928., ..., 1498., 1522., 1284.],\n", - " [ 8952., 8944., 8960., ..., 1542., 1478., 1448.],\n", - " ...,\n", - "...\n", - " [ 652., 640., 638., ..., 590., 821., 1008.],\n", - " [ 622., 676., 630., ..., 606., 1092., 726.],\n", - " [ 864., 786., 569., ..., 766., 1068., 630.]],\n", - "\n", - " [[ 201., 213., 195., ..., 1138., 1058., 749.],\n", - " [ 196., 198., 169., ..., 861., 784., 768.],\n", - " [ 216., 178., 191., ..., 870., 806., 820.],\n", - " ...,\n", - " [ 857., 838., 846., ..., 622., 800., 1332.],\n", - " [ 922., 848., 771., ..., 786., 1046., 912.],\n", - " [ 1118., 1010., 735., ..., 755., 977., 686.]],\n", - "\n", - " [[ 3264., 3352., 3304., ..., 3160., 3296., 3376.],\n", - " [ 3356., 3300., 3212., ..., 3188., 3272., 3064.],\n", - " [ 3288., 3372., 3344., ..., 3136., 3200., 2932.],\n", - " ...,\n", - " [ 1320., 1468., 1298., ..., 2492., 2556., 3018.],\n", - " [ 1630., 1694., 1250., ..., 2318., 2684., 2894.],\n", - " [ 2190., 2072., 1288., ..., 2544., 2942., 2928.]]]],\n", - " dtype=float32)\n", - "Coordinates: (12/53)\n", - " * time (time) datetime64[ns] 96B 2018-0...\n", - " id (time) " - ] - }, - "execution_count": 21, - "metadata": 
{}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# Run PCA\n", - "pca = decomposition.PCA(n_components=1)\n", - "pca_result = pca.fit_transform(embeddings)\n", - "\n", - "plt.xticks(rotation=-45)\n", - "\n", - "# Plot all points in blue first\n", - "plt.scatter(stack.time, pca_result, color=\"blue\")\n", - "\n", - "# Re-plot cloudy images in green\n", - "plt.scatter(stack.time[0], pca_result[0], color=\"green\")\n", - "plt.scatter(stack.time[2], pca_result[2], color=\"green\")\n", - "\n", - "# Color all images after fire in red\n", - "plt.scatter(stack.time[-5:], pca_result[-5:], color=\"red\")" - ] - }, - { - "cell_type": "markdown", - "id": "b38b70a6-2156-41f8-967e-a490cc8e2778", - "metadata": {}, - "source": [ - "### And finally, some finetuning\n", - "\n", - "We are going to train a classifier head on the embeddings and use it to detect fires." - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "1da07de0-b8f2-46c9-bd2a-58b15ca2224f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Matched 5 out of 5 correctly\n" - ] - } - ], - "source": [ - "# Label the images we downloaded\n", - "# 0 = Cloud\n", - "# 1 = Forest\n", - "# 2 = Fire\n", - "labels = np.array([0, 1, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2])\n", - "\n", - "# Split into fit and test manually, ensuring we have all 3 classes in both sets\n", - "fit = [0, 1, 3, 4, 7, 8, 9]\n", - "test = [2, 5, 6, 10, 11]\n", - "\n", - "# Train a support vector machine model\n", - "clf = svm.SVC()\n", - "clf.fit(embeddings[fit] + 100, labels[fit])\n", - "\n", - "# Predict classes on test set\n", - "prediction = clf.predict(embeddings[test] + 100)\n", - "\n", - "# Perfect match for SVM\n", - "match = np.sum(labels[test] == prediction)\n", - "print(f\"Matched {match} out of {len(test)} correctly\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "claymodel", - "language": "python", - "name": "python3" - }, - 
"language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.8" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/basic_use.md b/docs/getting-started/basic_use.md similarity index 50% rename from docs/basic_use.md rename to docs/getting-started/basic_use.md index 79c1ab66..83308511 100644 --- a/docs/basic_use.md +++ b/docs/getting-started/basic_use.md @@ -1,6 +1,6 @@ # Basic Use -### Running jupyter lab +## Running jupyter lab mamba activate claymodel python -m ipykernel install --user --name claymodel # to install virtual env properly @@ -8,35 +8,21 @@ jupyter lab & -### Running the model - +## Running the model The neural network model can be ran via [LightningCLI v2](https://pytorch-lightning.medium.com/introducing-lightningcli-v2supercharge-your-training-c070d43c7dd6). To check out the different options available, and look at the hyperparameter configurations, run: python trainer.py --help - python trainer.py test --print_config To quickly test the model on one batch in the validation set: - python trainer.py validate --trainer.fast_dev_run=True - -To train the model for a hundred epochs: - - python trainer.py fit --trainer.max_epochs=100 + python trainer.py fit --model ClayMAEModule --data ClayDataModule --config configs/config.yaml --trainer.fast_dev_run=True -To generate embeddings from the pretrained model's encoder on 1024 images -(stored as a GeoParquet file with spatiotemporal metadata): +To train the model: - python trainer.py predict --ckpt_path=checkpoints/last.ckpt \ - --data.batch_size=1024 \ - --data.data_dir=s3://clay-tiles-02 \ - --trainer.limit_predict_batches=1 + python trainer.py fit --model ClayMAEModule --data ClayDataModule --config configs/config.yaml More options can be found using `python trainer.py fit --help`, or at the [LightningCLI 
docs](https://lightning.ai/docs/pytorch/2.1.0/cli/lightning_cli.html). - -## Advanced - -See [Readme](https://github.com/Clay-foundation/model/blob/v0.0.1/README.md) on model root for more details. diff --git a/docs/installation.md b/docs/getting-started/installation.md similarity index 100% rename from docs/installation.md rename to docs/getting-started/installation.md diff --git a/docs/changelog.md b/docs/release-notes/changelog.md similarity index 100% rename from docs/changelog.md rename to docs/release-notes/changelog.md diff --git a/docs/data_sampling.md b/docs/release-notes/data_sampling.md similarity index 100% rename from docs/data_sampling.md rename to docs/release-notes/data_sampling.md diff --git a/docs/specification.md b/docs/release-notes/specification.md similarity index 97% rename from docs/specification.md rename to docs/release-notes/specification.md index 21e6e6b1..1b5f050d 100644 --- a/docs/specification.md +++ b/docs/release-notes/specification.md @@ -1,6 +1,6 @@ # Pretrained Model release v1.0 -This changelog is a summary of the changes to the pretrained model weights for the Clay model. We follow the "Stanford [Foundation Model Transparency Index](https://github.com/stanford-crfm/fmti)" +This changelog is a summary of the changes to the pretrained model weights for the Clay model. We follow the "[Stanford Foundation Model Transparency Index](https://github.com/stanford-crfm/fmti)" Model weights released on 2024/05/12. @@ -130,7 +130,7 @@ The data used for this model is described in detail in the [](training-data) sec ## Results -As a foundation model, it is designed to be used as a building block for other models. We have examples of what the embedding space & reconstruction looks like for the base model in the docs [here](visualize-embedding.ipynb) & [here](reconstruction.ipynb). +As a foundation model, it is designed to be used as a building block for other models. 
We have documented examples of what the [embedding space](../tutorials/visualize-embeddings.ipynb) and the [reconstructions](../tutorials/reconstruction.ipynb) look like for the base model. ### Performance Metrics diff --git a/docs/tutorials/clay-v1-wall-to-wall.ipynb b/docs/tutorials/clay-v1-wall-to-wall.ipynb new file mode 100644 index 00000000..378aeb43 --- /dev/null +++ b/docs/tutorials/clay-v1-wall-to-wall.ipynb @@ -0,0 +1,601 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0cc5e729-9116-4ec9-bf1e-8346cbccdf7b", + "metadata": {}, + "source": [ + "## Run Clay v1\n", + "\n", + "This notebook shows how to run Clay v1 wall-to-wall, from downloading imagery\n", + "to training a tiny fine tuning head. This will include the following steps:\n", + "\n", + "1. Set a location and date range of interest\n", + "2. Download Sentinel-2 imagery for this specification\n", + "3. Load the model checkpoint\n", + "4. Prepare data into a format for the model\n", + "5. Run the model on the imagery\n", + "6. Analyse the model embeddings output using PCA\n", + "7. 
Train a Support Vector Machines fine tuning head" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "add63cd9", + "metadata": {}, + "outputs": [], + "source": [ + "# Add the repo root to the sys path for the model import below\n", + "import sys\n", + "\n", + "sys.path.append(\"..\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "6a17b8a8-a9c6-4053-833e-de97287fae49", + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "\n", + "import geopandas as gpd\n", + "import numpy as np\n", + "import pandas as pd\n", + "import pystac_client\n", + "import stackstac\n", + "import torch\n", + "import yaml\n", + "from box import Box\n", + "from matplotlib import pyplot as plt\n", + "from rasterio.enums import Resampling\n", + "from shapely import Point\n", + "from sklearn import decomposition, svm\n", + "from torchvision.transforms import v2\n", + "\n", + "from src.model import ClayMAEModule" + ] + }, + { + "cell_type": "markdown", + "id": "beac6394-9762-422b-9f5d-82d226018c0c", + "metadata": {}, + "source": [ + "### Specify location and date of interest\n", + "In this example we will use a location in Portugal where a forest fire happened. We will run the model over the time period of the fire and analyse the model embeddings." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "08d7787d-1506-4de7-89dc-c1054910acf7", + "metadata": {}, + "outputs": [], + "source": [ + "# Point over Monchique Portugal\n", + "lat, lon = 37.30939, -8.57207\n", + "\n", + "# Dates of a large forest fire\n", + "start = \"2018-07-01\"\n", + "end = \"2018-09-01\"" + ] + }, + { + "cell_type": "markdown", + "id": "2bd226c9-003b-4867-a64a-8ae887e7e20a", + "metadata": {}, + "source": [ + "### Get data from STAC catalog\n", + "\n", + "Based on the location and date we can obtain a stack of imagery using stackstac. Let's start with finding the STAC items we want to analyse." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "2e80743c-7c77-459b-9984-f6c26cdff549", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/tam/apps/miniforge3/envs/claymodel/lib/python3.11/site-packages/pystac_client/item_search.py:850: FutureWarning: get_all_items() is deprecated, use item_collection() instead.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 12 items\n" + ] + } + ], + "source": [ + "STAC_API = \"https://earth-search.aws.element84.com/v1\"\n", + "COLLECTION = \"sentinel-2-l2a\"\n", + "\n", + "# Search the catalogue\n", + "catalog = pystac_client.Client.open(STAC_API)\n", + "search = catalog.search(\n", + " collections=[COLLECTION],\n", + " datetime=f\"{start}/{end}\",\n", + " bbox=(lon - 1e-5, lat - 1e-5, lon + 1e-5, lat + 1e-5),\n", + " max_items=100,\n", + " query={\"eo:cloud_cover\": {\"lt\": 80}},\n", + ")\n", + "\n", + "all_items = search.get_all_items()\n", + "\n", + "# Reduce to one per date (there might be some duplicates\n", + "# based on the location)\n", + "items = []\n", + "dates = []\n", + "for item in all_items:\n", + " if item.datetime.date() not in dates:\n", + " items.append(item)\n", + " dates.append(item.datetime.date())\n", + "\n", + "print(f\"Found {len(items)} items\")" + ] + }, + { + "cell_type": "markdown", + "id": "5b7c68ae-7c8a-446a-8bc7-5afba70183c2", + "metadata": {}, + "source": [ + "### Create a bounding box around the point of interest\n", + "\n", + "This is needed in the projection of the data so that we can generate image chips of the right size." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "0f3573b5-5a00-47d9-a648-5c4d7cd2c996", + "metadata": {}, + "outputs": [], + "source": [ + "# Extract coordinate system from first item\n", + "epsg = items[0].properties[\"proj:epsg\"]\n", + "\n", + "# Convert point of interest into the image projection\n", + "# (assumes all images are in the same projection)\n", + "poidf = gpd.GeoDataFrame(\n", + " pd.DataFrame(),\n", + " crs=\"EPSG:4326\",\n", + " geometry=[Point(lon, lat)],\n", + ").to_crs(epsg)\n", + "\n", + "coords = poidf.iloc[0].geometry.coords[0]\n", + "\n", + "# Create bounds in projection\n", + "size = 256\n", + "gsd = 10\n", + "bounds = (\n", + " coords[0] - (size * gsd) // 2,\n", + " coords[1] - (size * gsd) // 2,\n", + " coords[0] + (size * gsd) // 2,\n", + " coords[1] + (size * gsd) // 2,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "bbbd3f67-5f2c-46dc-9ee1-2ef1f50fa032", + "metadata": {}, + "source": [ + "### Retrieve the imagery data." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "8b8d3824-e48c-4f9d-9c7b-181c0800f96f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Size: 13MB\n", + "dask.array\n", + "Coordinates: (12/53)\n", + " * time (time) datetime64[ns] 96B 2018-0...\n", + " id (time) " + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Run PCA\n", + "pca = decomposition.PCA(n_components=1)\n", + "pca_result = pca.fit_transform(embeddings)\n", + "\n", + "plt.xticks(rotation=-45)\n", + "\n", + "# Plot all points in blue first\n", + "plt.scatter(stack.time, pca_result, color=\"blue\")\n", + "\n", + "# Re-plot cloudy images in green\n", + "plt.scatter(stack.time[0], pca_result[0], color=\"green\")\n", + "plt.scatter(stack.time[2], pca_result[2], color=\"green\")\n", + "\n", + "# Color all images after fire in red\n", + "plt.scatter(stack.time[-5:], pca_result[-5:], color=\"red\")" + ] + }, + { + "cell_type": "markdown", + "id": "b38b70a6-2156-41f8-967e-a490cc8e2778", + "metadata": {}, + "source": [ + "### And finally, some finetuning\n", + "\n", + "We are going to train a classifier head on the embeddings and use it to detect fires." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "1da07de0-b8f2-46c9-bd2a-58b15ca2224f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Matched 5 out of 5 correctly\n" + ] + } + ], + "source": [ + "# Label the images we downloaded\n", + "# 0 = Cloud\n", + "# 1 = Forest\n", + "# 2 = Fire\n", + "labels = np.array([0, 1, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2])\n", + "\n", + "# Split into fit and test manually, ensuring we have all 3 classes in both sets\n", + "fit = [0, 1, 3, 4, 7, 8, 9]\n", + "test = [2, 5, 6, 10, 11]\n", + "\n", + "# Train a support vector machine model\n", + "clf = svm.SVC()\n", + "clf.fit(embeddings[fit] + 100, labels[fit])\n", + "\n", + "# Predict classes on test set\n", + "prediction = clf.predict(embeddings[test] + 100)\n", + "\n", + "# Perfect match for SVM\n", + "match = np.sum(labels[test] == prediction)\n", + "print(f\"Matched {match} out of {len(test)} correctly\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" 
+ }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/reconstruction.ipynb b/docs/tutorials/reconstruction.ipynb similarity index 100% rename from docs/reconstruction.ipynb rename to docs/tutorials/reconstruction.ipynb diff --git a/docs/visualize-embeddings.ipynb b/docs/tutorials/visualize-embeddings.ipynb similarity index 100% rename from docs/visualize-embeddings.ipynb rename to docs/tutorials/visualize-embeddings.ipynb
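Review note on the STAC search cell in `docs/tutorials/clay-v1-wall-to-wall.ipynb`: the cell keeps one item per acquisition date by checking membership in a growing list. The sketch below shows the same idea as a small reusable helper that tracks seen dates in a set. The name `first_item_per_date` and the stand-in items are hypothetical, used only so the example runs without network access. The helper assumes nothing beyond each item exposing a `datetime` attribute, as the items in the notebook do.

```python
from datetime import datetime
from types import SimpleNamespace


def first_item_per_date(items):
    """Keep the first item seen for each calendar date, preserving order."""
    seen = set()
    unique = []
    for item in items:
        day = item.datetime.date()
        if day not in seen:
            seen.add(day)
            unique.append(item)
    return unique


# Stand-in objects carrying only the attribute the helper needs (made-up ids)
fake_items = [
    SimpleNamespace(id="a", datetime=datetime(2018, 7, 1, 11, 21)),
    SimpleNamespace(id="b", datetime=datetime(2018, 7, 1, 11, 22)),
    SimpleNamespace(id="c", datetime=datetime(2018, 7, 6, 11, 21)),
]
print([item.id for item in first_item_per_date(fake_items)])  # ['a', 'c']
```

In the notebook this could be called as `first_item_per_date(search.item_collection())`. `item_collection()` is the replacement that the `FutureWarning` in the cell's recorded output names for the deprecated `get_all_items()`, so switching to it should also silence that warning.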