Worldcover embeddings conus #153

Merged: 25 commits, Feb 28, 2024
2 changes: 2 additions & 0 deletions docs/_toc.yml
@@ -32,6 +32,8 @@ parts:
file: model_embeddings
- title: Finetuning
file: model_finetuning
- title: Embeddings for Contiguous US
file: worldcover-embeddings
- caption: Tutorials
chapters:
- title: Generative AI for pixel reconstruction
96 changes: 96 additions & 0 deletions docs/worldcover-embeddings.md
@@ -0,0 +1,96 @@
# Running embeddings for Worldcover Sentinel-2 Composites
This package generates embeddings from the [ESA Worldcover](https://esa-worldcover.org/en/data-access)
Sentinel-2 annual composites. The target region is the contiguous United States.

We ran this script for 2020 and 2021.

## The algorithm

The `run.py` script processes a column of 512x512 pixel image chips. Each run
handles one column that spans the Contiguous United States from north to
south. For each chip in that column, embeddings are generated and stored
together in one geoparquet file. These files are then uploaded to the
`clay-worldcover-embeddings` bucket on S3.

There are 1359 such columns to process in order to cover all of the contiguous US (CONUS).
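
As a back-of-the-envelope check, that count follows from the longitude span and chip size that also determine the Batch array size further down (assuming the 67°W to 125°W span and the roughly 12000 pixels per degree of the WorldCover grid implied there):

```python
# Column count sketch; the constants mirror the AWS Batch array size below.
pixels_per_degree = 12000  # ESA WorldCover grid, ~1/12000 degree (~10 m) per pixel
chip_width = 512  # chip size in pixels
columns = int((125 - 67) * pixels_per_degree / chip_width)  # 125°W to 67°W
print(columns)  # 1359
```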

The embeddings are stored alongside the bbox of the data chip used for
generating the embedding. To visualize the underlying data of an embedding,
the WMS and WMTS endpoints provided by the ESA Worldcover project can be used.

So the geoparquet files only have the following two columns:

| embeddings | bbox |
|------------------|--------------|
| [0.1, 0.4, ... ] | POLYGON(...) |
| [0.2, 0.5, ... ] | POLYGON(...) |
| [0.3, 0.6, ... ] | POLYGON(...) |
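
A minimal sketch of how one of these per-column files can be inspected and its first chip rendered through the WMS endpoint (the filename is a placeholder; the request mirrors the one used in `embeddings_db.py` below):

```python
import geopandas as gpd
from skimage import io

# Placeholder filename; each column has its own geoparquet file in the bucket.
gdf = gpd.read_parquet("clay-worldcover-embeddings/v002/2021/some_column.gpq")
print(gdf.columns)  # the embeddings plus the chip geometry

# Fetch the RGB composite for the first chip from the ESA Worldcover WMS endpoint.
bbox = gdf.to_crs("epsg:3857").geometry.iloc[0].bounds
url = (
    "https://services.terrascope.be/wms/v2?SERVICE=WMS&version=1.1.1&REQUEST=GetMap"
    "&layers=WORLDCOVER_2021_S2_TCC"
    f"&BBOX={','.join(str(b) for b in bbox)}"
    "&SRS=EPSG:3857&FORMAT=image/png&WIDTH=512&HEIGHT=512"
)
image = io.imread(url)
```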

## Exploring results

The `embeddings_db.py` script provides a way to explore the embeddings locally.
It will create a `lancedb` database and allow for similarity search. The search
results are visualized by requesting the RGB image from the WMS endpoint for the
bbox of each search result.
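
A typical local workflow (the destination directory is a placeholder) is to sync the embeddings and then run the script:

```bash
# Download the existing embeddings (destination path is a placeholder)
aws s3 sync s3://clay-worldcover-embeddings /my/dir/clay-worldcover-embeddings

# Build the lancedb database and run an example similarity search
python scripts/worldcover/embeddings_db.py
```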

## Running on Batch

### Upload the package to the fetch-and-run bucket
This snippet creates the zip package that is used by the fetch-and-run
instance in our ECR registry.

```bash
# Add clay src and scripts to zip file
zip -FSr batch-fetch-and-run-wc.zip src scripts -x *.pyc -x scripts/worldcover/wandb/**\*

# Add run to home dir, so that fetch-and-run can see it.
zip -uj batch-fetch-and-run-wc.zip scripts/worldcover/run.py

# Upload fetch-and-run package to S3
aws s3api put-object --bucket clay-fetch-and-run-packages --key "batch-fetch-and-run-wc.zip" --body "batch-fetch-and-run-wc.zip"
```
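
To confirm that the package landed in the bucket, the object metadata can be checked with a standard `head-object` call:

```bash
aws s3api head-object --bucket clay-fetch-and-run-packages --key "batch-fetch-and-run-wc.zip"
```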

### Push array job
This command sends the array job to AWS Batch to run all of the
1359 jobs needed to cover the US.

```python
import boto3

batch = boto3.client("batch", region_name="us-east-1")
year = 2020
job = {
    "jobName": f"worldcover-conus-{year}",
    "jobQueue": "fetch-and-run",
    "jobDefinition": "fetch-and-run",
    "containerOverrides": {
        "command": ["run.py"],
        "environment": [
            {"name": "BATCH_FILE_TYPE", "value": "zip"},
            {
                "name": "BATCH_FILE_S3_URL",
                "value": "s3://clay-fetch-and-run-packages/batch-fetch-and-run-wc.zip",
            },
            {"name": "YEAR", "value": f"{year}"},
        ],
        "resourceRequirements": [
            {"type": "MEMORY", "value": "7500"},
            {"type": "VCPU", "value": "4"},
            # {"type": "GPU", "value": "1"},
        ],
    },
    "arrayProperties": {
        "size": int((125 - 67) * 12000 / 512)
    },
    "retryStrategy": {
        "attempts": 5,
        "evaluateOnExit": [
            {"onStatusReason": "Host EC2*", "action": "RETRY"},
            {"onReason": "*", "action": "EXIT"},
        ],
    },
}

print(batch.submit_job(**job))
```
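
The array size evaluates to the 1359 columns described above; each array child presumably picks its column from the `AWS_BATCH_JOB_ARRAY_INDEX` environment variable that AWS Batch injects, alongside the `YEAR` variable set here. Once submitted, the parent job (or a single child, addressed as `<jobId>:<index>`) can be polled with `describe_jobs`; the job ID below is a placeholder:

```python
import boto3

batch = boto3.client("batch", region_name="us-east-1")

# Placeholder job ID; use the "jobId" returned by submit_job above.
response = batch.describe_jobs(jobs=["00000000-0000-0000-0000-000000000000"])
for job_detail in response["jobs"]:
    print(job_detail["jobName"], job_detail["status"])
```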
1 change: 1 addition & 0 deletions ruff.toml
@@ -1,5 +1,6 @@
[per-file-ignores]
"docs/clay_over_aoi.ipynb" = ["E501"]
"scripts/worldcover/worldcover_vrt.py" = ["E501"]

[format]
# https://docs.astral.sh/ruff/settings/#format
58 changes: 58 additions & 0 deletions scripts/worldcover/embeddings_db.py
@@ -0,0 +1,58 @@
from pathlib import Path

import geopandas as gpd
import lancedb
import matplotlib.pyplot as plt
from skimage import io

# Set working directory
wd = "/home/usr/Desktop/"
Collaborator (@chuckwondo):
@yellowcap, I know this is already merged, but can you avoid such absolute/hardcoded paths?

Member Author (@yellowcap), Feb 29, 2024:
Thanks for the feedback @chuckwondo, you are right, this isn't great, but it is supposed to be a placeholder. Sometimes I use fake paths like /path/to/your/working/directory to show what this is supposed to be, so that people running the script can replace it.

But I am very happy to learn about better ways to do this. What is your favorite solution for this kind of thing? In cases like this with scripts, maybe env vars could be an option?

Collaborator (@chuckwondo):
I'm happy to propose some ideas, but I need some context first. How is this script intended to be used? Is the intent to have the user first run the aws s3 sync command shown in the code comment below this line, and then just directly call this script?

Member Author (@yellowcap), Mar 1, 2024:
Yes exactly, the idea is that someone with access to the embeddings downloads them using aws s3 sync to a local folder, and then runs the script pointing to the embedding files and to a folder where the lancedb data should be stored.

I.e. the script needs two folders:

  1. A place to download the raw data
  2. A folder to create / store the lancedb data

I have made many scripts like this where some local working directories are needed, and I never really found a very satisfactory way of handling this. Env vars seem a bit clunky and are not always easy to set up. Constants work, but then the script requires a hard-coded default value.

So if you have good ideas on how to approach this issue, they are most welcome!

Collaborator (@chuckwondo):
Add command-line arguments. For very simple scripts like this one, simply use Python's standard argparse module, so you don't have to add any dependencies. In this case, it sounds like you might want to use argparse to support a syntax like the following for running the script:

scripts/worldcover/embeddings_db.py --input-dir path/to/input/dir --db-dir path/to/db/dir

Both options should be required, and the script should also create the dir specified for --db-dir, so the user doesn't have to do so manually beforehand.
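
A minimal sketch of that suggestion, assuming the two options map onto the existing `vector_dir` and lancedb locations in the script (option names are illustrative only, not part of the merged code):

```python
# Illustrative argparse sketch, not part of the merged script.
import argparse
from pathlib import Path

parser = argparse.ArgumentParser(description="Explore worldcover embeddings with lancedb")
parser.add_argument("--input-dir", type=Path, required=True, help="Folder containing the synced .gpq embedding files")
parser.add_argument("--db-dir", type=Path, required=True, help="Folder in which to create/store the lancedb database")
args = parser.parse_args()

# Create the db dir so the user doesn't have to do so manually beforehand.
args.db_dir.mkdir(parents=True, exist_ok=True)

vector_dir = args.input_dir  # replaces Path(wd + "clay-worldcover-embeddings/v002/2021/")
# db = lancedb.connect(str(args.db_dir))  # replaces the wd-based lancedb path
```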


# To download the existing embeddings run aws s3 sync
# aws s3 sync s3://clay-worldcover-embeddings /my/dir/clay-worldcover-embeddings

vector_dir = Path(wd + "clay-worldcover-embeddings/v002/2021/")

# Create new DB structure or open existing
db = lancedb.connect(wd + "worldcoverembeddings_db")

# Read all vector embeddings into a list
data = []
for strip in vector_dir.glob("*.gpq"):
    print(strip)
    tile_df = gpd.read_parquet(strip).to_crs("epsg:3857")

    for _, row in tile_df.iterrows():
        data.append(
            {"vector": row["embeddings"], "year": 2021, "bbox": row.geometry.bounds}
        )

# Show table names
db.table_names()

# Drop existing table if exists
db.drop_table("worldcover-2021-v001")

# Create embeddings table and insert the vector data
tbl = db.create_table("worldcover-2021-v001", data=data, mode="overwrite")


# Visualize some image chips
def plot(df, cols=10):
    fig, axs = plt.subplots(1, cols, figsize=(20, 10))

    for ax, (i, row) in zip(axs.flatten(), df.iterrows()):
        bbox = row["bbox"]
        url = f"https://services.terrascope.be/wms/v2?SERVICE=WMS&version=1.1.1&REQUEST=GetMap&layers=WORLDCOVER_2021_S2_TCC&BBOX={','.join([str(dat) for dat in bbox])}&SRS=EPSG:3857&FORMAT=image/png&WIDTH=512&HEIGHT=512"  # noqa: E501
        image = io.imread(url)
        ax.imshow(image)
        ax.set_axis_off()

    plt.tight_layout()
    plt.show()


# Select a vector by index, search for the 5 most similar embeddings, and plot them
v = tbl.to_pandas()["vector"].values[10540]
result = tbl.search(query=v).limit(5).to_pandas()
plot(result, 5)