From 472c92f264b0c25a6e981c576a407f763a578db7 Mon Sep 17 00:00:00 2001
From: Daniel Wiesmann
Date: Wed, 28 Feb 2024 09:29:14 +0000
Subject: [PATCH] Worldcover embeddings conus (#153)

* Add script to generate worldcover composite vrt files

Focus on CONUS area.

* Add initial version of batch run script
* Intermediate
* Improve print statements
* Reduce batch size and fix array index usage
* Disable workers on datamodule to save memory
* Add script to explore embeddings using lancedb
* Rename run.py file
* Index based run file
* Small fixes
* Add initial readme
* Full array size, change mem requirements
* Remove scripts from previous attempt
* Improved docs
* Use v002
* Improved docs
* Improved docs
* Improved docs
* Move worldcover readme into docs
* Make year a parameter
* Fix url formatting
* Fix url worldcover version by year
* Use S3 uri for model checkpoint
---
 docs/_toc.yml                       |   2 +
 docs/worldcover-embeddings.md       |  96 +++++++++++
 ruff.toml                           |   1 +
 scripts/worldcover/embeddings_db.py |  58 +++++++
 scripts/worldcover/run.py           | 251 ++++++++++++++++++++++++++++
 5 files changed, 408 insertions(+)
 create mode 100644 docs/worldcover-embeddings.md
 create mode 100644 scripts/worldcover/embeddings_db.py
 create mode 100755 scripts/worldcover/run.py

diff --git a/docs/_toc.yml b/docs/_toc.yml
index e05334c7..7a11719a 100644
--- a/docs/_toc.yml
+++ b/docs/_toc.yml
@@ -32,6 +32,8 @@ parts:
         file: model_embeddings
       - title: Finetuning
         file: model_finetuning
+      - title: Embeddings for Contiguous US
+        file: worldcover-embeddings
   - caption: Tutorials
     chapters:
       - title: Generative AI for pixel reconstruction

diff --git a/docs/worldcover-embeddings.md b/docs/worldcover-embeddings.md
new file mode 100644
index 00000000..798aae69
--- /dev/null
+++ b/docs/worldcover-embeddings.md
@@ -0,0 +1,96 @@
# Running embeddings for Worldcover Sentinel-2 Composites

This package generates embeddings from the
[ESA Worldcover](https://esa-worldcover.org/en/data-access) Sentinel-2 annual
composites. The target region is the Contiguous United States (CONUS).

We ran this script for 2020 and 2021.

## The algorithm

The `run.py` script processes one column of 512x512 pixel image chips,
spanning the Contiguous United States from north to south. For each chip in
that column, embeddings are generated and stored together in one geoparquet
file. These files are then uploaded to the `clay-worldcover-embeddings`
bucket on S3.

There are 1359 such columns to process in order to cover all of CONUS.

The embeddings are stored alongside the bbox of the data chip that was used
to generate them. To visualize the underlying data for an embedding, the WMS
and WMTS endpoints provided by the ESA Worldcover project can be used.

Each geoparquet file therefore has only the following two columns:

| embeddings       | bbox         |
|------------------|--------------|
| [0.1, 0.4, ... ] | POLYGON(...) |
| [0.2, 0.5, ... ] | POLYGON(...) |
| [0.3, 0.6, ... ] | POLYGON(...) |
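
Each file can be opened directly with `geopandas`. Below is a minimal sketch
of what that looks like; the file name is hypothetical, but follows the
naming pattern used by `run.py` further down.

```python
# Minimal sketch: inspect one strip of embeddings after syncing the bucket.
# The file name is hypothetical; real files follow the pattern
# worldcover_embeddings_{year}_{index}_v002.gpq.
import geopandas as gpd

gdf = gpd.read_parquet("worldcover_embeddings_2021_500_v002.gpq")

# One row per 512x512 chip; the geometry column holds the chip bbox.
print(gdf.shape)
print(gdf.geometry.total_bounds)
```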

## Exploring results

The `embeddings_db.py` script provides a way to explore the embeddings
locally. It creates a `lancedb` database and allows for similarity search.
The search results are visualized by requesting the RGB image for the bbox
of each search result from the WMS endpoint.

## Running on Batch

### Upload package to fetch-and-run bucket

This snippet creates the zip package that is used by the fetch-and-run
instance in our ECR registry.

```bash
# Add clay src and scripts to zip file
zip -FSr batch-fetch-and-run-wc.zip src scripts -x *.pyc -x scripts/worldcover/wandb/**\*

# Add run.py to the home dir, so that fetch-and-run can see it.
zip -uj batch-fetch-and-run-wc.zip scripts/worldcover/run.py

# Upload fetch-and-run package to S3
aws s3api put-object --bucket clay-fetch-and-run-packages --key "batch-fetch-and-run-wc.zip" --body "batch-fetch-and-run-wc.zip"
```

### Push array job

This command sends the array job to AWS Batch, running all of the 1359 jobs
needed to cover the US.

```python
import boto3

batch = boto3.client("batch", region_name="us-east-1")
year = 2020
job = {
    "jobName": f"worldcover-conus-{year}",
    "jobQueue": "fetch-and-run",
    "jobDefinition": "fetch-and-run",
    "containerOverrides": {
        "command": ["run.py"],
        "environment": [
            {"name": "BATCH_FILE_TYPE", "value": "zip"},
            {
                "name": "BATCH_FILE_S3_URL",
                "value": "s3://clay-fetch-and-run-packages/batch-fetch-and-run-wc.zip",
            },
            {"name": "YEAR", "value": f"{year}"},
        ],
        "resourceRequirements": [
            {"type": "MEMORY", "value": "7500"},
            {"type": "VCPU", "value": "4"},
            # {"type": "GPU", "value": "1"},
        ],
    },
    "arrayProperties": {"size": int((125 - 67) * 12000 / 512)},
    "retryStrategy": {
        "attempts": 5,
        "evaluateOnExit": [
            {"onStatusReason": "Host EC2*", "action": "RETRY"},
            {"onReason": "*", "action": "EXIT"},
        ],
    },
}

print(batch.submit_job(**job))
```
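
The array size equals the number of chip columns: the region spans
125 - 67 = 58 Worldcover tiles from east to west, each 12000 pixels wide, so
`int(58 * 12000 / 512)` evaluates to the 1359 columns mentioned above. Each
job picks the column to process from the `AWS_BATCH_JOB_ARRAY_INDEX`
environment variable that AWS Batch sets on array jobs.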

diff --git a/ruff.toml b/ruff.toml
index 82f49e2d..82c076df 100644
--- a/ruff.toml
+++ b/ruff.toml
@@ -1,5 +1,6 @@
 [per-file-ignores]
 "docs/clay_over_aoi.ipynb" = ["E501"]
+"scripts/worldcover/worldcover_vrt.py" = ["E501"]

 [format]
 # https://docs.astral.sh/ruff/settings/#format

diff --git a/scripts/worldcover/embeddings_db.py b/scripts/worldcover/embeddings_db.py
new file mode 100644
index 00000000..1f106448
--- /dev/null
+++ b/scripts/worldcover/embeddings_db.py
@@ -0,0 +1,58 @@
from pathlib import Path

import geopandas as gpd
import lancedb
import matplotlib.pyplot as plt
from skimage import io

# Set working directory
wd = "/home/usr/Desktop/"

# To download the existing embeddings run aws s3 sync
# aws s3 sync s3://clay-worldcover-embeddings /my/dir/clay-worldcover-embeddings

vector_dir = Path(wd + "clay-worldcover-embeddings/v002/2021/")

# Create new DB structure or open existing
db = lancedb.connect(wd + "worldcoverembeddings_db")

# Read all vector embeddings into a list
data = []
for strip in vector_dir.glob("*.gpq"):
    print(strip)
    tile_df = gpd.read_parquet(strip).to_crs("epsg:3857")

    for _, row in tile_df.iterrows():
        data.append(
            {"vector": row["embeddings"], "year": 2021, "bbox": row.geometry.bounds}
        )

# Show table names
db.table_names()

# Drop the table if it already exists
db.drop_table("worldcover-2021-v001")

# Create embeddings table and insert the vector data
tbl = db.create_table("worldcover-2021-v001", data=data, mode="overwrite")


# Visualize some image chips
def plot(df, cols=10):
    fig, axs = plt.subplots(1, cols, figsize=(20, 10))

    for ax, (_, row) in zip(axs.flatten(), df.iterrows()):
        bbox = row["bbox"]
        url = f"https://services.terrascope.be/wms/v2?SERVICE=WMS&version=1.1.1&REQUEST=GetMap&layers=WORLDCOVER_2021_S2_TCC&BBOX={','.join([str(dat) for dat in bbox])}&SRS=EPSG:3857&FORMAT=image/png&WIDTH=512&HEIGHT=512"  # noqa: E501
        image = io.imread(url)
        ax.imshow(image)
        ax.set_axis_off()

    plt.tight_layout()
    plt.show()


# Select a vector by index, search for the 5 most similar entries, and plot them
v = tbl.to_pandas()["vector"].values[10540]
result = tbl.search(query=v).limit(5).to_pandas()
plot(result, 5)

diff --git a/scripts/worldcover/run.py b/scripts/worldcover/run.py
new file mode 100755
index 00000000..d1458e1b
--- /dev/null
+++ b/scripts/worldcover/run.py
@@ -0,0 +1,251 @@
#!/usr/bin/env python3

# import sys
# sys.path.append("/home/tam/Documents/repos/model")

import os
import tempfile
from math import floor

import boto3
import einops
import geopandas as gpd
import numpy
import pyarrow as pa
import rasterio
import torch
from rasterio.windows import Window
from shapely import box
from torchvision.transforms import v2

from src.datamodule import ClayDataset
from src.model_clay import CLAYModule

YEAR = int(os.environ.get("YEAR", 2020))
DATE = f"{YEAR}-06-01"
TILE_SIZE = 12000
CHIP_SIZE = 512
E_W_INDEX_START = 67
E_W_INDEX_END = 125
N_S_INDEX_START = 24
N_S_INDEX_END = 49
YORIGIN = 50.0
XORIGIN = -125.0
PXSIZE = 8.333333333333333e-05  # Degrees per pixel; one 12000 px tile spans 1 degree

RASTER_X_SIZE = (E_W_INDEX_END - E_W_INDEX_START) * TILE_SIZE
RASTER_Y_SIZE = (N_S_INDEX_END - N_S_INDEX_START) * TILE_SIZE
NODATA = 0
CKPT_PATH = "s3://clay-model-ckpt/v0/mae_epoch-24_val-loss-0.46.ckpt"
# CKPT_PATH = "https://huggingface.co/made-with-clay/Clay/resolve/main/Clay_v0.1_epoch-24_val-loss-0.46.ckpt"
VERSION = "002"
BUCKET = "clay-worldcover-embeddings"
URL = "https://esa-worldcover-s2.s3.amazonaws.com/rgbnir/{year}/N{yidx}/ESA_WorldCover_10m_{year}_v{version}_N{yidx}W{xidx}_S2RGBNIR.tif"
WC_VERSION_LOOKUP = {
    2020: 100,
    2021: 200,
}


MEAN = [
    1369.03,  # red
    1597.68,  # green
    1741.10,  # blue
    2858.43,  # nir
]
STD = [
    2026.96,  # red
    2011.88,  # green
    2146.35,  # blue
    2016.38,  # nir
]

grid = gpd.read_file(
    # "/home/tam/Desktop/usa/esa_worldcover_grid_usa.fgb"
    "https://clay-mgrs-samples.s3.amazonaws.com/esa_worldcover_grid_usa.fgb"
)
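

# Orientation for the helper below: each 512x512 chip is addressed by a
# pixel window in a virtual raster covering CONUS. A chip can straddle up
# to four of the 1x1 degree Worldcover tiles, so tiles_and_windows()
# returns one to four (url, window) pairs that together cover the chip,
# or None when the chip falls outside the CONUS grid.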

def tiles_and_windows(window: Window):
    print("Window", window)
    x_tile_index = E_W_INDEX_END - floor(window.col_off / TILE_SIZE)
    x_local_off = window.col_off % TILE_SIZE
    x_size = min(CHIP_SIZE, TILE_SIZE - x_local_off)
    x_another = x_size < CHIP_SIZE

    y_tile_index = N_S_INDEX_END - floor(window.row_off / TILE_SIZE)
    y_local_off = window.row_off % TILE_SIZE
    y_size = min(CHIP_SIZE, TILE_SIZE - y_local_off)
    y_another = y_size < CHIP_SIZE

    tile_id = f"N{y_tile_index}W{str(x_tile_index).zfill(3)}"
    if tile_id not in grid.tile.values:
        return

    result = [
        (
            URL.format(
                yidx=y_tile_index,
                xidx=str(x_tile_index).zfill(3),
                year=YEAR,
                version=WC_VERSION_LOOKUP[YEAR],
            ),
            Window(x_local_off, y_local_off, x_size, y_size),
        )
    ]

    if x_another:
        result.append(
            (
                URL.format(
                    yidx=y_tile_index,
                    xidx=str(x_tile_index - 1).zfill(3),
                    year=YEAR,
                    version=WC_VERSION_LOOKUP[YEAR],
                ),
                Window(0, y_local_off, CHIP_SIZE - x_size, y_size),
            )
        )
    if y_another:
        result.append(
            (
                URL.format(
                    yidx=y_tile_index - 1,
                    xidx=str(x_tile_index).zfill(3),
                    year=YEAR,
                    version=WC_VERSION_LOOKUP[YEAR],
                ),
                Window(x_local_off, 0, x_size, CHIP_SIZE - y_size),
            )
        )
    if x_another and y_another:
        result.append(
            (
                URL.format(
                    yidx=y_tile_index - 1,
                    xidx=str(x_tile_index - 1).zfill(3),
                    year=YEAR,
                    version=WC_VERSION_LOOKUP[YEAR],
                ),
                Window(0, 0, CHIP_SIZE - x_size, CHIP_SIZE - y_size),
            )
        )

    return result


def make_batch(result):
    pixels = []
    for url, win in result:
        with rasterio.open(url) as src:
            data = src.read(window=win)
            if NODATA in data:
                # Skip chips that contain nodata pixels
                return
            pixels.append(data)
            transform = src.window_transform(win)

    if len(pixels) == 1:
        # The chip fits entirely into one Worldcover tile
        pixels = pixels[0]
    elif len(pixels) == 2:  # noqa: PLR2004
        if pixels[0].shape[2] == CHIP_SIZE:
            # The chip straddles a horizontal tile boundary, stack along height
            pixels = einops.pack(pixels, "b * w")[0]
        else:
            # The chip straddles a vertical tile boundary, stack along width
            pixels = einops.pack(pixels, "b h *")[0]
    else:
        # The chip straddles a tile corner, assemble the four quadrants
        px1 = einops.pack(pixels[:2], "b w *")[0]
        px2 = einops.pack(pixels[2:], "b w *")[0]
        pixels = einops.pack((px1, px2), "b * w")[0]

    assert pixels.shape == (4, CHIP_SIZE, CHIP_SIZE)

    return {
        "pixels": torch.as_tensor(data=[pixels], dtype=torch.float32).to(
            rgb_model.device
        ),
        "latlon": torch.as_tensor(
            data=[ds.normalize_latlon(transform[0], transform[3])]
        ).to(rgb_model.device),
        "timestep": torch.as_tensor(data=[ds.normalize_timestamp(DATE)]).to(
            rgb_model.device
        ),
    }


# Each array job processes one column of chips
index = int(os.environ.get("AWS_BATCH_JOB_ARRAY_INDEX", 2))

# Setup model components
tfm = v2.Compose([v2.Normalize(mean=MEAN, std=STD)])
ds = ClayDataset(chips_path=[], transform=tfm)

rgb_model = CLAYModule.load_from_checkpoint(
    CKPT_PATH,
    mask_ratio=0.0,
    band_groups={"rgb": (0, 1, 2), "nir": (3,)},
    strict=False,  # ignore the extra parameters in the checkpoint
)

xoff = index * CHIP_SIZE
yoff = 0
embeddings = []
all_bounds = []
while yoff < RASTER_Y_SIZE:
    result = tiles_and_windows(Window(xoff, yoff, CHIP_SIZE, CHIP_SIZE))

    if result is None:
        yoff += CHIP_SIZE
        continue

    batch = make_batch(result)
    if batch is None:
        yoff += CHIP_SIZE
        continue

    (
        unmasked_patches,
        unmasked_indices,
        masked_indices,
        masked_matrix,
    ) = rgb_model.model.encoder(batch)

    embeddings.append(unmasked_patches.detach().cpu().numpy())
    all_bounds.append(
        (
            XORIGIN + PXSIZE * xoff,
            YORIGIN - PXSIZE * (yoff + CHIP_SIZE),
            XORIGIN + PXSIZE * (xoff + CHIP_SIZE),
            YORIGIN - PXSIZE * yoff,
        )
    )

    yoff += CHIP_SIZE

embeddings = numpy.vstack(embeddings)

# Average over the patch embeddings, dropping the last two (latlon and time) tokens
embeddings_mean = embeddings[:, :-2, :].mean(axis=1)

print(f"Average embeddings have shape {embeddings_mean.shape}")

gdf = gpd.GeoDataFrame(
    data={
        "embeddings": pa.FixedShapeTensorArray.from_numpy_ndarray(
            numpy.ascontiguousarray(embeddings_mean)
        ),
    },
    geometry=[box(*dat) for dat in all_bounds],  # This assumes same order
    crs="EPSG:4326",
)

with tempfile.TemporaryDirectory() as tmp:
    # tmp = "/home/tam/Desktop/wcctmp"

    outpath = f"{tmp}/worldcover_embeddings_{YEAR}_{index}_v{VERSION}.gpq"
    print(f"Writing embeddings to {outpath}")

    gdf.to_parquet(path=outpath, compression="ZSTD", schema_version="1.0.0")

    s3_client = boto3.client("s3")
    s3_client.upload_file(
        outpath,
        BUCKET,
        f"v{VERSION}/{YEAR}/{os.path.basename(outpath)}",
    )
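
# The uploaded files are what embeddings_db.py syncs and indexes, e.g.
# s3://clay-worldcover-embeddings/v002/2021/worldcover_embeddings_2021_{index}_v002.gpq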