feat: Non Visual QA script (TDE-309) (#40)
* refactor: implement a reusable component to run a GDAL command (with AWS permissions)

* wip

* feat: Non Visual QA script running on single file

* test: add Non Visual QA Unit Tests for the check methods

* feat: Non Visual QA script accepts multiple files
paulfouquet authored Jul 12, 2022
1 parent 34c210a commit 8411779
Showing 6 changed files with 357 additions and 22 deletions.
4 changes: 4 additions & 0 deletions scripts/aws_helper.py
@@ -81,3 +81,7 @@ def get_bucket_name_from_path(path: str) -> str:
def parse_path(path: str) -> Tuple[str, str]:
    parse = urlparse(path, allow_fragments=False)
    return parse.netloc, parse.path


def is_s3(path: str) -> bool:
    return path.startswith("s3://")
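
For context, a minimal sketch of how these helpers compose; the S3 path is a hypothetical example:

# Hypothetical usage of the helpers above
bucket, key = parse_path("s3://example-bucket/imagery/tile.tif")  # -> ("example-bucket", "/imagery/tile.tif")
assert is_s3("s3://example-bucket/imagery/tile.tif")
assert not is_s3("./local/imagery/tile.tif")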
5 changes: 5 additions & 0 deletions scripts/file_helper.py
@@ -0,0 +1,5 @@
import os


def get_file_name_from_path(path: str) -> str:
    return os.path.basename(path)
74 changes: 74 additions & 0 deletions scripts/gdal_helper.py
@@ -0,0 +1,74 @@
import os
import subprocess
from typing import List

from aws_helper import get_bucket_name_from_path, get_credentials, is_s3
from linz_logger import get_log


def get_vfs_path(path: str) -> str:
    """Convert the path to a GDAL Virtual File Systems path.
    Args:
        path (str): a path to a file.
    Returns:
        str: the path modified to comply with the corresponding storage service.
    """
    return path.replace("s3://", "/vsis3/")


def command_to_string(command: List[str]) -> str:
    """Format the command, with each argument separated by a space.
    Args:
        command (List[str]): the arguments of the command as a list of strings.
    Returns:
        str: the formatted command.
    """
    return " ".join(command)


def run_gdal(command: List[str], input_file: str = "", output_file: str = "") -> "subprocess.CompletedProcess[bytes]":
    """Run the GDAL command. The permissions needed to access the input file are applied to the GDAL environment.
    Args:
        command (List[str]): the arguments of the GDAL command.
        input_file (str, optional): the input file path. Defaults to "".
        output_file (str, optional): the output file path. Defaults to "".
    Raises:
        subprocess.CalledProcessError: raised if the command fails during execution.
    Returns:
        subprocess.CompletedProcess: the completed process.
    """
    gdal_env = os.environ.copy()

    if input_file:
        if is_s3(input_file):
            # Set the credentials for GDAL to be able to read the input file
            credentials = get_credentials(get_bucket_name_from_path(input_file))
            gdal_env["AWS_ACCESS_KEY_ID"] = credentials.access_key
            gdal_env["AWS_SECRET_ACCESS_KEY"] = credentials.secret_key
            gdal_env["AWS_SESSION_TOKEN"] = credentials.token
        command.append(get_vfs_path(input_file))

    if output_file:
        command.append(output_file)

    try:
        get_log().debug("run_gdal", command=command_to_string(command))
        proc = subprocess.run(
            command,
            env=gdal_env,
            check=True,
            capture_output=True,
        )
    except subprocess.CalledProcessError as cpe:
        get_log().error("run_gdal_failed", command=command_to_string(command))
        raise cpe
    get_log().debug("run_gdal_succeeded", command=command_to_string(command))

    return proc
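
A minimal usage sketch of the new reusable runner, assuming a hypothetical S3 source file. run_gdal appends the /vsis3/ path itself and injects the AWS credentials into the subprocess environment, so callers only pass the GDAL arguments:

# Hypothetical example: read metadata from an S3 object via run_gdal
process = run_gdal(["gdalinfo", "-json"], input_file="s3://example-bucket/imagery/tile.tif")
metadata = process.stdout  # raw bytes of gdalinfo's JSON output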
124 changes: 124 additions & 0 deletions scripts/non_visual_qa.py
@@ -0,0 +1,124 @@
import argparse
import json
from typing import Any, Dict, List

from format_source import format_source
from gdal_helper import run_gdal
from linz_logger import get_log


def check_no_data(gdalinfo: Dict[str, Any], errors_list: List[str]) -> None:
    """Add an error to errors_list if the first band has no "noDataValue" or its "noDataValue" is not 255.
    Args:
        gdalinfo (Dict[str, Any]): JSON returned by gdalinfo as a Python Dictionary.
        errors_list (List[str]): List of errors as strings.
    """
    bands = gdalinfo["bands"]
    if "noDataValue" in bands[0]:
        current_nodata_val = bands[0]["noDataValue"]
        if current_nodata_val != 255:
            errors_list.append(f"noDataValue is {int(current_nodata_val)} not 255")
    else:
        errors_list.append("noDataValue not set")


def check_band_count(gdalinfo: Dict[str, Any], errors_list: List[str]) -> None:
    """Add an error to errors_list if there are not exactly 3 bands.
    Args:
        gdalinfo (Dict[str, Any]): JSON returned by gdalinfo as a Python Dictionary.
        errors_list (List[str]): List of errors as strings.
    """
    bands = gdalinfo["bands"]
    bands_num = len(bands)
    if bands_num != 3:
        errors_list.append(f"not 3 bands, {bands_num} bands found")


def check_srs(gdalsrsinfo: bytes, gdalsrsinfo_tif: bytes, errors_list: List[str]) -> None:
    """Add an error to errors_list if gdalsrsinfo and gdalsrsinfo_tif values are different.
    Args:
        gdalsrsinfo (bytes): value returned by gdalsrsinfo for the reference EPSG code.
        gdalsrsinfo_tif (bytes): value returned by gdalsrsinfo for the tif.
        errors_list (List[str]): List of errors as strings.
    """
    if gdalsrsinfo_tif != gdalsrsinfo:
        errors_list.append("different srs")


def check_color_interpretation(gdalinfo: Dict[str, Any], errors_list: List[str]) -> None:
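    """Add an error to errors_list if the color interpretation of the 3 bands is not (Red, Green, Blue).
    Args:
        gdalinfo (Dict[str, Any]): JSON returned by gdalinfo as a Python Dictionary.
        errors_list (List[str]): List of errors as strings.
    """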
    bands = gdalinfo["bands"]
    missing_bands = []
    band_colour_ints = {1: "Red", 2: "Green", 3: "Blue"}
    n = 1
    for band in bands:
        colour_int = band["colorInterpretation"]
        if n in band_colour_ints:
            if colour_int != band_colour_ints[n]:
                missing_bands.append(f"band {n} {colour_int}")
        else:
            missing_bands.append(f"band {n} {colour_int}")
        n += 1
    if missing_bands:
        missing_bands.sort()
        errors_list.append(f"unexpected color interpretation bands; {', '.join(missing_bands)}")


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--source", dest="source", nargs="+", required=True)
    arguments = parser.parse_args()
    source = arguments.source

    source = format_source(source)

    # Get srs
    gdalsrsinfo_command = ["gdalsrsinfo", "-o", "wkt", "EPSG:2193"]
    gdalsrsinfo_result = run_gdal(gdalsrsinfo_command)
    if gdalsrsinfo_result.stderr:
        raise Exception(
            f"Error trying to retrieve srs from epsg code, no files have been checked\n{gdalsrsinfo_result.stderr!r}"
        )
    srs = gdalsrsinfo_result.stdout

    for file in source:
        gdalinfo_command = ["gdalinfo", "-stats", "-json"]
        gdalinfo_process = run_gdal(gdalinfo_command, file)
        gdalinfo_result = {}
        try:
            gdalinfo_result = json.loads(gdalinfo_process.stdout)
        except json.JSONDecodeError as e:
            get_log().error("load_gdalinfo_result_error", file=file, error=e)
            continue

        gdalinfo_errors = gdalinfo_process.stderr

        # Check result
        errors: List[str] = []
        # No data
        check_no_data(gdalinfo_result, errors)

        # Band count
        check_band_count(gdalinfo_result, errors)

        # srs
        gdalsrsinfo_tif_command = ["gdalsrsinfo", "-o", "wkt"]
        gdalsrsinfo_tif_result = run_gdal(gdalsrsinfo_tif_command, file)
        check_srs(srs, gdalsrsinfo_tif_result.stdout, errors)

        # Color interpretation
        check_color_interpretation(gdalinfo_result, errors)

        # gdal errors
        if gdalinfo_errors:
            errors.append(f"{gdalinfo_errors!r}")

        if len(errors) > 0:
            get_log().info("non_visual_qa_errors_found", file=file, result=errors)
        else:
            get_log().info("non_visual_qa_no_error", file=file)


if __name__ == "__main__":
    main()
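
The unit tests mentioned in the commit message are among the 6 changed files but are not expanded above; a sketch of what tests for the check methods could look like, with hypothetical gdalinfo fixtures:

# Hypothetical unit tests for the check methods (the actual test file is not shown)
from typing import List

from non_visual_qa import check_band_count, check_no_data


def test_check_no_data_not_set() -> None:
    errors: List[str] = []
    check_no_data({"bands": [{}]}, errors)
    assert errors == ["noDataValue not set"]


def test_check_band_count_wrong_count() -> None:
    errors: List[str] = []
    check_band_count({"bands": [{}, {}]}, errors)
    assert errors == ["not 3 bands, 2 bands found"]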
28 changes: 6 additions & 22 deletions scripts/standardising.py
@@ -1,10 +1,11 @@
 import argparse
 import os
-import subprocess
 import tempfile

-from aws_helper import get_bucket, get_credentials, parse_path
+from aws_helper import get_bucket, parse_path
+from file_helper import get_file_name_from_path
 from format_source import format_source
+from gdal_helper import run_gdal
 from linz_logger import get_log

 parser = argparse.ArgumentParser()
@@ -25,20 +26,10 @@
 for file in source:
     with tempfile.TemporaryDirectory() as tmp_dir:
         src_bucket_name, src_file_path = parse_path(file)
         get_log().debug("processing_file", bucket=src_bucket_name, file_path=src_file_path)
-        standardized_file_name = f"standardized_{os.path.basename(src_file_path)}"
+        standardized_file_name = f"standardized_{get_file_name_from_path(src_file_path)}"
         tmp_file_path = os.path.join(tmp_dir, standardized_file_name)
-        src_gdal_path = file.replace("s3://", "/vsis3/")
-
-        # Set the credentials for GDAL to be able to read the source file
-        credentials = get_credentials(src_bucket_name)
-        gdal_env["AWS_ACCESS_KEY_ID"] = credentials.access_key
-        gdal_env["AWS_SECRET_ACCESS_KEY"] = credentials.secret_key
-        gdal_env["AWS_SESSION_TOKEN"] = credentials.token
-
-        # Run GDAL to standardized the file
-        get_log().debug("run_gdal_translate", src=src_gdal_path, output=tmp_file_path)
-        gdal_command = [
+        command = [
             "gdal_translate",
             "-q",
             "-scale",
@@ -58,15 +49,8 @@
             "3",
             "-co",
             "compress=lzw",
-            src_gdal_path,
-            tmp_file_path,
         ]
-        try:
-            proc = subprocess.run(gdal_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=gdal_env, check=True)
-        except subprocess.CalledProcessError as cpe:
-            get_log().error("run_gdal_translate_failed", command=" ".join(gdal_command))
-            raise cpe
-        get_log().debug("run_gdal_translate_succeded", command=" ".join(gdal_command))
+        run_gdal(command, file, tmp_file_path)

         # Upload the standardized file to destination
         dst_file_path = os.path.join(dst_path, standardized_file_name).strip("/")