Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: create polygon script return temporary file path #47

Merged
merged 2 commits into from
Jul 12, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 44 additions & 50 deletions scripts/create_polygons.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,12 @@
from collections import Counter
from urllib.parse import urlparse

from aws_helper import get_bucket, get_bucket_name_from_path
from aws_helper import get_bucket
from linz_logger import get_log

# osgeo is embbed in the Docker image
from osgeo import gdal # pylint: disable=import-error

logger = get_log()

parser = argparse.ArgumentParser()
parser.add_argument("--uri", dest="uri", required=True)
parser.add_argument("--destination", dest="destination", required=True)
arguments = parser.parse_args()
uri = arguments.uri
destination = arguments.destination

# Split the s3 destination path
destination_bucket_name = get_bucket_name_from_path(destination)
destination_path = destination.replace("s3://", "").replace(f"{destination_bucket_name}/", "")


def create_mask(file_path: str, mask_dst: str) -> None:
set_srs_command = f'gdal_edit.py -a_srs EPSG:2193 "{file_path}"'
Expand Down Expand Up @@ -50,42 +37,49 @@ def get_pixel_count(file_path: str) -> int:
return data_pixels_count


with tempfile.TemporaryDirectory() as tmp_dir:
source_file_name = os.path.basename(uri)
# Download the file
if str(uri).startswith("s3://"):
uri_parse = urlparse(uri, allow_fragments=False)
bucket_name = uri_parse.netloc
bucket = get_bucket(bucket_name)
uri = os.path.join(tmp_dir, "temp.tif")
logger.debug(
"download_file", source=uri_parse.path[1:], bucket=bucket_name, destination=uri, sourceFileName=source_file_name
)
bucket.download_file(uri_parse.path[1:], uri)
def main() -> str:
logger = get_log()

parser = argparse.ArgumentParser()
parser.add_argument("--source", dest="source", required=True)
arguments = parser.parse_args()
source = arguments.source

with tempfile.TemporaryDirectory() as tmp_dir:
source_file_name = os.path.basename(source)
# Download the file
if str(source).startswith("s3://"):
uri_parse = urlparse(source, allow_fragments=False)
bucket_name = uri_parse.netloc
bucket = get_bucket(bucket_name)
source = os.path.join(tmp_dir, "temp.tif")
logger.debug(
"download_file",
source=uri_parse.path[1:],
bucket=bucket_name,
destination=source,
sourceFileName=source_file_name,
)
bucket.download_file(uri_parse.path[1:], source)

# Run create_mask
logger.debug("create_mask", source=uri_parse.path[1:], bucket=bucket_name, destination=source)
mask_file = os.path.join(tmp_dir, "mask.tif")
create_mask(source, mask_file)

# Run create_polygon
data_px_count = get_pixel_count(mask_file)
if data_px_count == 0:
# exclude extents if tif is all white or black
logger.debug(f"- data_px_count was zero in create_mask function for the tif {mask_file}")
else:
destination_file_name = os.path.splitext(source_file_name)[0] + ".geojson"
temp_file_path = os.path.join(tmp_dir, destination_file_name)
polygonize_command = f'gdal_polygonize.py -q "{mask_file}" "{temp_file_path}" -f GeoJSON'
os.system(polygonize_command)

# Run create_mask
logger.debug("create_mask", source=uri_parse.path[1:], bucket=bucket_name, destination=uri)
mask_file = os.path.join(tmp_dir, "mask.tif")
create_mask(uri, mask_file)
return temp_file_path

# Run create_polygon
data_px_count = get_pixel_count(mask_file)
if data_px_count == 0:
# exclude extents if tif is all white or black
logger.debug(f"- data_px_count was zero in create_mask function for the tif {mask_file}")
else:
destination_file_name = os.path.splitext(source_file_name)[0] + ".geojson"
temp_file_path = os.path.join(tmp_dir, destination_file_name)
polygonize_command = f'gdal_polygonize.py -q "{mask_file}" "{temp_file_path}" -f GeoJSON'
os.system(polygonize_command)

# Upload shape file
destination_bucket = get_bucket(destination_bucket_name)
destination_file_path = os.path.join(destination_path, destination_file_name)
logger.debug("upload_start", destinationBucket=destination_bucket_name, destinationFile=destination_file_path)
try:
destination_bucket.upload_file(temp_file_path, destination_file_path)
except Exception as e:
logger.debug("upload_error", err=e)
raise e
logger.debug("upload_end", destinationBucket=destination_bucket_name, destinationFile=destination_file_path)
if __name__ == "__main__":
main()