diff --git a/scripts/create_polygons.py b/scripts/create_polygons.py
index ba15a11e8..833838926 100644
--- a/scripts/create_polygons.py
+++ b/scripts/create_polygons.py
@@ -4,25 +4,12 @@
 from collections import Counter
 from urllib.parse import urlparse
 
-from aws_helper import get_bucket, get_bucket_name_from_path
+from aws_helper import get_bucket
 from linz_logger import get_log
 
 # osgeo is embbed in the Docker image
 from osgeo import gdal  # pylint: disable=import-error
 
-logger = get_log()
-
-parser = argparse.ArgumentParser()
-parser.add_argument("--uri", dest="uri", required=True)
-parser.add_argument("--destination", dest="destination", required=True)
-arguments = parser.parse_args()
-uri = arguments.uri
-destination = arguments.destination
-
-# Split the s3 destination path
-destination_bucket_name = get_bucket_name_from_path(destination)
-destination_path = destination.replace("s3://", "").replace(f"{destination_bucket_name}/", "")
-
 
 def create_mask(file_path: str, mask_dst: str) -> None:
     set_srs_command = f'gdal_edit.py -a_srs EPSG:2193 "{file_path}"'
@@ -50,42 +37,49 @@ def get_pixel_count(file_path: str) -> int:
     return data_pixels_count
 
 
-with tempfile.TemporaryDirectory() as tmp_dir:
-    source_file_name = os.path.basename(uri)
-    # Download the file
-    if str(uri).startswith("s3://"):
-        uri_parse = urlparse(uri, allow_fragments=False)
-        bucket_name = uri_parse.netloc
-        bucket = get_bucket(bucket_name)
-        uri = os.path.join(tmp_dir, "temp.tif")
-        logger.debug(
-            "download_file", source=uri_parse.path[1:], bucket=bucket_name, destination=uri, sourceFileName=source_file_name
-        )
-        bucket.download_file(uri_parse.path[1:], uri)
-
-    # Run create_mask
-    logger.debug("create_mask", source=uri_parse.path[1:], bucket=bucket_name, destination=uri)
-    mask_file = os.path.join(tmp_dir, "mask.tif")
-    create_mask(uri, mask_file)
-
-    # Run create_polygon
-    data_px_count = get_pixel_count(mask_file)
-    if data_px_count == 0:
-        # exclude extents if tif is all white or black
-        logger.debug(f"- data_px_count was zero in create_mask function for the tif {mask_file}")
-    else:
-        destination_file_name = os.path.splitext(source_file_name)[0] + ".geojson"
-        temp_file_path = os.path.join(tmp_dir, destination_file_name)
-        polygonize_command = f'gdal_polygonize.py -q "{mask_file}" "{temp_file_path}" -f GeoJSON'
-        os.system(polygonize_command)
-
-        # Upload shape file
-        destination_bucket = get_bucket(destination_bucket_name)
-        destination_file_path = os.path.join(destination_path, destination_file_name)
-        logger.debug("upload_start", destinationBucket=destination_bucket_name, destinationFile=destination_file_path)
-        try:
-            destination_bucket.upload_file(temp_file_path, destination_file_path)
-        except Exception as e:
-            logger.debug("upload_error", err=e)
-            raise e
-        logger.debug("upload_end", destinationBucket=destination_bucket_name, destinationFile=destination_file_path)
+def main() -> str:
+    logger = get_log()
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--source", dest="source", required=True)
+    arguments = parser.parse_args()
+    source = arguments.source
+
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        source_file_name = os.path.basename(source)
+        # Download the file
+        if str(source).startswith("s3://"):
+            uri_parse = urlparse(source, allow_fragments=False)
+            bucket_name = uri_parse.netloc
+            bucket = get_bucket(bucket_name)
+            source = os.path.join(tmp_dir, "temp.tif")
+            logger.debug(
+                "download_file",
+                source=uri_parse.path[1:],
+                bucket=bucket_name,
+                destination=source,
+                sourceFileName=source_file_name,
+            )
+            bucket.download_file(uri_parse.path[1:], source)
+
+        # Run create_mask
+        logger.debug("create_mask", source=source_file_name, destination=source)
+        mask_file = os.path.join(tmp_dir, "mask.tif")
+        create_mask(source, mask_file)
+        temp_file_path = ""
+        # Run create_polygon
+        data_px_count = get_pixel_count(mask_file)
+        if data_px_count == 0:
+            # exclude extents if tif is all white or black
+            logger.debug(f"- data_px_count was zero in create_mask function for the tif {mask_file}")
+        else:
+            destination_file_name = os.path.splitext(source_file_name)[0] + ".geojson"
+            temp_file_path = os.path.join(tmp_dir, destination_file_name)
+            polygonize_command = f'gdal_polygonize.py -q "{mask_file}" "{temp_file_path}" -f GeoJSON'
+            os.system(polygonize_command)
+
+        return temp_file_path
+
+
+if __name__ == "__main__":
+    main()