Skip to content

Commit

Permalink
feat: create polygon script return temporary file path (#47)
Browse files Browse the repository at this point in the history
* feat: create polygon script return temporary file path

* fix: main return value and formatting
  • Loading branch information
paulfouquet authored Jul 12, 2022
1 parent 30faaa1 commit 34c210a
Showing 1 changed file with 44 additions and 50 deletions.
94 changes: 44 additions & 50 deletions scripts/create_polygons.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,12 @@
from collections import Counter
from urllib.parse import urlparse

from aws_helper import get_bucket, get_bucket_name_from_path
from aws_helper import get_bucket
from linz_logger import get_log

# osgeo is embbed in the Docker image
from osgeo import gdal # pylint: disable=import-error

logger = get_log()

parser = argparse.ArgumentParser()
parser.add_argument("--uri", dest="uri", required=True)
parser.add_argument("--destination", dest="destination", required=True)
arguments = parser.parse_args()
uri = arguments.uri
destination = arguments.destination

# Split the s3 destination path
destination_bucket_name = get_bucket_name_from_path(destination)
destination_path = destination.replace("s3://", "").replace(f"{destination_bucket_name}/", "")


def create_mask(file_path: str, mask_dst: str) -> None:
set_srs_command = f'gdal_edit.py -a_srs EPSG:2193 "{file_path}"'
Expand Down Expand Up @@ -50,42 +37,49 @@ def get_pixel_count(file_path: str) -> int:
return data_pixels_count


with tempfile.TemporaryDirectory() as tmp_dir:
source_file_name = os.path.basename(uri)
# Download the file
if str(uri).startswith("s3://"):
uri_parse = urlparse(uri, allow_fragments=False)
bucket_name = uri_parse.netloc
bucket = get_bucket(bucket_name)
uri = os.path.join(tmp_dir, "temp.tif")
logger.debug(
"download_file", source=uri_parse.path[1:], bucket=bucket_name, destination=uri, sourceFileName=source_file_name
)
bucket.download_file(uri_parse.path[1:], uri)
def main() -> str:
    """Create a GeoJSON footprint polygon for a GeoTIFF and return its path.

    Reads `--source` from the command line (either an `s3://` URI or a local
    file path), masks the image, and polygonizes the data pixels into a
    GeoJSON file inside a temporary directory.

    Returns:
        Path to the generated GeoJSON file, or an empty string when the
        source image contains no data pixels (all black/white), in which
        case no file is produced.
    """
    logger = get_log()

    parser = argparse.ArgumentParser()
    parser.add_argument("--source", dest="source", required=True)
    arguments = parser.parse_args()
    source = arguments.source

    # Bug fix: the original used `with tempfile.TemporaryDirectory()`, which
    # deletes the directory on exit -- the returned path pointed at a file
    # that no longer existed. mkdtemp persists so the caller can read the
    # result; the caller is responsible for cleanup.
    tmp_dir = tempfile.mkdtemp()
    source_file_name = os.path.basename(source)

    # Download the file when it lives in S3; local paths are used as-is.
    if str(source).startswith("s3://"):
        uri_parse = urlparse(source, allow_fragments=False)
        bucket_name = uri_parse.netloc
        bucket = get_bucket(bucket_name)
        source = os.path.join(tmp_dir, "temp.tif")
        logger.debug(
            "download_file",
            source=uri_parse.path[1:],
            bucket=bucket_name,
            destination=source,
            sourceFileName=source_file_name,
        )
        bucket.download_file(uri_parse.path[1:], source)

    # Run create_mask.
    # Bug fix: the original logged uri_parse/bucket_name here, which are
    # unbound for non-S3 sources (NameError). Log the resolved source instead.
    logger.debug("create_mask", source=source)
    mask_file = os.path.join(tmp_dir, "mask.tif")
    create_mask(source, mask_file)

    # Run create_polygon.
    data_px_count = get_pixel_count(mask_file)
    if data_px_count == 0:
        # Exclude extents if tif is all white or black: nothing to polygonize.
        # Bug fix: the original fell through to `return temp_file_path` with
        # temp_file_path unbound (NameError); return "" to signal "no output".
        logger.debug(f"- data_px_count was zero in create_mask function for the tif {mask_file}")
        return ""

    destination_file_name = os.path.splitext(source_file_name)[0] + ".geojson"
    temp_file_path = os.path.join(tmp_dir, destination_file_name)
    polygonize_command = f'gdal_polygonize.py -q "{mask_file}" "{temp_file_path}" -f GeoJSON'
    os.system(polygonize_command)

    return temp_file_path

# Run create_polygon
data_px_count = get_pixel_count(mask_file)
if data_px_count == 0:
# exclude extents if tif is all white or black
logger.debug(f"- data_px_count was zero in create_mask function for the tif {mask_file}")
else:
destination_file_name = os.path.splitext(source_file_name)[0] + ".geojson"
temp_file_path = os.path.join(tmp_dir, destination_file_name)
polygonize_command = f'gdal_polygonize.py -q "{mask_file}" "{temp_file_path}" -f GeoJSON'
os.system(polygonize_command)

# Upload shape file
destination_bucket = get_bucket(destination_bucket_name)
destination_file_path = os.path.join(destination_path, destination_file_name)
logger.debug("upload_start", destinationBucket=destination_bucket_name, destinationFile=destination_file_path)
try:
destination_bucket.upload_file(temp_file_path, destination_file_path)
except Exception as e:
logger.debug("upload_error", err=e)
raise e
logger.debug("upload_end", destinationBucket=destination_bucket_name, destinationFile=destination_file_path)
# Script entry point: run polygon creation when executed directly.
# NOTE(review): main()'s returned file path is discarded here -- presumably a
# wrapping workflow imports and calls main() to consume it; confirm callers.
if __name__ == "__main__":
    main()

0 comments on commit 34c210a

Please sign in to comment.