Worldcover embeddings conus (#153)
* Add script to generate worldcover composite vrt files

Focus on CONUS area.

* Add initial version of batch run script

* Intermediate

* Improve print statements

* Reduce batch size and fix array index usage

* Disable workers on datamodule to save memory

* Add script to explore embeddings using lancedb

* Rename run.py file

* Index based run file

* Small fixes

* Add initial readme

* Full array size, change mem requirements

* Remove scripts from previous attempt

* Improved docs

* Use v002

* Improved docs

* Improved docs

* Improved docs

* Move worldcover readme into docs

* Make year a parameter

* Fix url formatting

* Fix url worldcover version by year

* Use S3 uri for model checkpoint
yellowcap authored Feb 28, 2024
1 parent daa5ab0 commit 472c92f
Showing 5 changed files with 408 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/_toc.yml
@@ -32,6 +32,8 @@ parts:
file: model_embeddings
- title: Finetuning
file: model_finetuning
- title: Embeddings for Contiguous US
file: worldcover-embeddings
- caption: Tutorials
chapters:
- title: Generative AI for pixel reconstruction
96 changes: 96 additions & 0 deletions docs/worldcover-embeddings.md
@@ -0,0 +1,96 @@
# Running embeddings for Worldcover Sentinel-2 Composites
This package generates embeddings from the [ESA Worldcover](https://esa-worldcover.org/en/data-access)
Sentinel-2 annual composites. The target region is the
Contiguous United States (CONUS).

We ran this script for 2020 and 2021.

## The algorithm

The `run.py` script processes one column of 512x512 pixel image chips at a
time. Each column spans the Contiguous United States from north to south.
For each chip in the column, an embedding is generated, and all embeddings
of the column are stored together in a single geoparquet file. These files
are then uploaded to the `clay-worldcover-embeddings` bucket on S3.

There are 1359 such columns to process in order to cover all of CONUS.
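
The `run.py` script itself is not shown in this commit, but a minimal sketch
of how a job could derive its column from the Batch array index, assuming the
125°W to 67°W CONUS extent and the 12000 pixels-per-degree Worldcover
resolution implied by the array-size formula further below, might look like this:

```python
import os

# Assumed constants, inferred from the array-size formula
# int((125 - 67) * 12000 / 512) in the Batch job submission below.
CHIP_SIZE = 512  # chip edge length in pixels
PIXELS_PER_DEGREE = 12000  # ESA Worldcover resolution (~10m per pixel)
XMIN, XMAX = -125, -67  # approximate CONUS longitude extent in degrees

# Total number of columns; floor division yields the 1359 jobs noted above.
total_columns = (XMAX - XMIN) * PIXELS_PER_DEGREE // CHIP_SIZE

# AWS Batch passes each job's position in an array job through this variable.
index = int(os.environ.get("AWS_BATCH_JOB_ARRAY_INDEX", 0))

# Pixel offset of this job's column from the western edge of the mosaic.
x_offset = index * CHIP_SIZE
print(f"Job {index} of {total_columns} starts at pixel offset {x_offset}")
```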

Each embedding is stored alongside the bbox of the data chip that was used
to generate it. To visualize the underlying data of an embedding, the WMS
and WMTS endpoints provided by the ESA Worldcover project can be used.
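
The exploration script below, for example, requests true color images from
the Terrascope WMS endpoint (`<bbox>` stands for the bounds of a chip):

```
https://services.terrascope.be/wms/v2?SERVICE=WMS&version=1.1.1&REQUEST=GetMap&layers=WORLDCOVER_2021_S2_TCC&BBOX=<bbox>&SRS=EPSG:3857&FORMAT=image/png&WIDTH=512&HEIGHT=512
```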

The geoparquet files therefore contain only the following two columns:

| embeddings | bbox |
|------------------|--------------|
| [0.1, 0.4, ... ] | POLYGON(...) |
| [0.2, 0.5, ... ] | POLYGON(...) |
| [0.3, 0.6, ... ] | POLYGON(...) |
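
A single strip can be inspected with `geopandas`, for example (the file name
here is hypothetical; the directory layout follows the `aws s3 sync` comment
in the exploration script below):

```python
import geopandas as gpd

# Hypothetical file name; the actual strips live in the
# clay-worldcover-embeddings bucket on S3.
gdf = gpd.read_parquet("clay-worldcover-embeddings/v002/2021/strip_0001.gpq")

print(gdf.shape)  # one row per 512x512 chip in the column
print(gdf["embeddings"][0])  # the embedding vector of the first chip
print(gdf.geometry[0])  # the bbox polygon of that chip
```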

## Exploring results

The `embeddings_db.py` script provides a way to explore the embeddings locally.
It creates a `lancedb` database and allows for similarity search. The search
results are visualized by requesting the RGB image for the bbox of each search
result from the WMS endpoint.

## Running on Batch

### Upload package to the fetch-and-run bucket
This snippet creates the zip package that is used by the fetch-and-run
instance in our ECR registry.

```bash
# Add clay src and scripts to zip file
zip -FSr batch-fetch-and-run-wc.zip src scripts -x *.pyc -x scripts/worldcover/wandb/**\*

# Add run.py at the zip root, so that fetch-and-run can see it.
zip -uj batch-fetch-and-run-wc.zip scripts/worldcover/run.py

# Upload fetch-and-run package to S3
aws s3api put-object --bucket clay-fetch-and-run-packages --key "batch-fetch-and-run-wc.zip" --body "batch-fetch-and-run-wc.zip"
```

### Push array job
This snippet sends the array job to AWS Batch to run all of the 1359 jobs
needed to cover CONUS. The array size, `int((125 - 67) * 12000 / 512)`,
corresponds to the roughly 58 degrees of longitude spanned by CONUS at the
Worldcover resolution of 12000 pixels per degree, divided into columns of
512 pixels.

```python
import boto3

batch = boto3.client("batch", region_name="us-east-1")
year = 2020
job = {
    "jobName": f"worldcover-conus-{year}",
    "jobQueue": "fetch-and-run",
    "jobDefinition": "fetch-and-run",
    "containerOverrides": {
        "command": ["run.py"],
        "environment": [
            {"name": "BATCH_FILE_TYPE", "value": "zip"},
            {
                "name": "BATCH_FILE_S3_URL",
                "value": "s3://clay-fetch-and-run-packages/batch-fetch-and-run-wc.zip",
            },
            {"name": "YEAR", "value": f"{year}"},
        ],
        "resourceRequirements": [
            {"type": "MEMORY", "value": "7500"},
            {"type": "VCPU", "value": "4"},
            # {"type": "GPU", "value": "1"},
        ],
    },
    "arrayProperties": {"size": int((125 - 67) * 12000 / 512)},
    "retryStrategy": {
        "attempts": 5,
        "evaluateOnExit": [
            {"onStatusReason": "Host EC2*", "action": "RETRY"},
            {"onReason": "*", "action": "EXIT"},
        ],
    },
}

print(batch.submit_job(**job))
```
1 change: 1 addition & 0 deletions ruff.toml
@@ -1,5 +1,6 @@
[per-file-ignores]
"docs/clay_over_aoi.ipynb" = ["E501"]
"scripts/worldcover/worldcover_vrt.py" = ["E501"]

[format]
# https://docs.astral.sh/ruff/settings/#format
58 changes: 58 additions & 0 deletions scripts/worldcover/embeddings_db.py
@@ -0,0 +1,58 @@
from pathlib import Path

import geopandas as gpd
import lancedb
import matplotlib.pyplot as plt
from skimage import io

# Set working directory
wd = "/home/usr/Desktop/"

# To download the existing embeddings run aws s3 sync
# aws s3 sync s3://clay-worldcover-embeddings /my/dir/clay-worldcover-embeddings

vector_dir = Path(wd + "clay-worldcover-embeddings/v002/2021/")

# Create new DB structure or open existing
db = lancedb.connect(wd + "worldcoverembeddings_db")

# Read all vector embeddings into a list
data = []
for strip in vector_dir.glob("*.gpq"):
    print(strip)
    tile_df = gpd.read_parquet(strip).to_crs("epsg:3857")

    for _, row in tile_df.iterrows():
        data.append(
            {"vector": row["embeddings"], "year": 2021, "bbox": row.geometry.bounds}
        )

# Show table names
print(db.table_names())

# Drop existing table if it exists
if "worldcover-2021-v001" in db.table_names():
    db.drop_table("worldcover-2021-v001")

# Create embeddings table and insert the vector data
tbl = db.create_table("worldcover-2021-v001", data=data, mode="overwrite")


# Visualize some image chips
def plot(df, cols=10):
    fig, axs = plt.subplots(1, cols, figsize=(20, 10))

    for ax, (i, row) in zip(axs.flatten(), df.iterrows()):
        bbox = row["bbox"]
        url = f"https://services.terrascope.be/wms/v2?SERVICE=WMS&version=1.1.1&REQUEST=GetMap&layers=WORLDCOVER_2021_S2_TCC&BBOX={','.join([str(dat) for dat in bbox])}&SRS=EPSG:3857&FORMAT=image/png&WIDTH=512&HEIGHT=512"  # noqa: E501
        image = io.imread(url)
        ax.imshow(image)
        ax.set_axis_off()

    plt.tight_layout()
    plt.show()


# Select a vector by index, search for the 5 most similar embeddings, and plot
v = tbl.to_pandas()["vector"].values[10540]
result = tbl.search(query=v).limit(5).to_pandas()
plot(result, 5)
