diff --git a/config_4class.yaml b/config_4class.yaml index 867a422..858492f 100644 --- a/config_4class.yaml +++ b/config_4class.yaml @@ -20,9 +20,8 @@ global: # Inference parameters; used in inference.py -------- inference: - img_dir_or_csv_file: /home/remi/Documents/inferences/test_13/cayouche_RGBN_test13_unet - # /path/to/model/weights/for/inference/checkpoint.pth.tar - state_dict_path: /home/remi/Documents/inferences/test_13/cayouche_RGBN_test13_unet/checkpoint_params.pth.tar + img_dir_or_csv_file: #/path/to/img_dir_or_csv_file + state_dict_path: # /path/to/model/weights/for/inference/checkpoint.pth.tar # Post-processing parameters; used in post-process.py post-processing: @@ -36,5 +35,5 @@ post-processing: patterntol: 0.75 # if 1, will exclude all buildings that don't fit perfectly under one of the hardcoded patterns orthogonalize_ang_thresh: 20 # max angle formed by 3 vertices that will be orthogonalized (to 90° or 180°) - to_cog: True # Convert raster inference to cog (will compress with LZW) - keep_non_cog: True # if False, will delete inferences after they are converted to cog (only applies when to_cog=True) \ No newline at end of file + to_cog: True # (bool) Convert raster inference to cog (will compress with LZW) + keep_non_cog: False # (bool) if False, will delete inferences after they are converted to cog (only applies when to_cog=True) \ No newline at end of file diff --git a/config_buildings.yaml b/config_buildings.yaml new file mode 100644 index 0000000..6bf66be --- /dev/null +++ b/config_buildings.yaml @@ -0,0 +1,36 @@ +# Deep learning configuration file ------------------------------------------------ +# This config is to use massachusetts_buildings image out of the box WITH ONLY MANDATORY PARAMETERS. +# For that, unzip the file data/massachusetts_buildings.zip before running images_to_samples.py or other command. 
+# Three sections: +# 1) Global parameters; those are re-used amongst the next three operations (sampling, training and inference) +# 2) Inference parameters +# 3) Post-processing parameters + +# Global parameters +global: + task: 'segmentation' + number_of_bands: 4 # Number of bands in input imagery + # Set to True if the first three channels of imagery are blue, green, red. Set to False if Red, Green, Blue + BGR_to_RGB: True + classes: + 1: 'buildings' + +# Inference parameters; used in inference.py -------- +inference: + img_dir_or_csv_file: #/path/to/img_dir_or_csv_file + state_dict_path: # /path/to/model/weights/for/inference/checkpoint.pth.tar + +# Post-processing parameters; used in post-process.py +post-processing: + r2vect_cellsize_resamp: 0 # (int) Resample raster before vectorization with this amount of pixels. Default:0 + removeholesunder: 4 # (int) Remove holes under this number in all classes + simptol: 0.75 # (float) tolerance for simplification (Douglas-Peucker). Does not apply to 'buildings' (Visvalingam) + redbenddiamtol: 3 # (int) tolerance for reduce bend algorithm. 
Applies to all classes + buildings: # buildings-specific parameters + recttol: 0.80 # if 1, will exclude all buildings that are not rectangles + compacttol: 0.85 # if 1, will exclude all buildings that are not a circle + patterntol: 0.75 # if 1, will exclude all buildings that don't fit perfectly under one of the hardcoded patterns + orthogonalize_ang_thresh: 20 # max angle formed by 3 vertices that will be orthogonalized (to 90° or 180°) + + to_cog: True # (bool) Convert raster inference to cog (will compress with LZW) + keep_non_cog: False # (bool) if False, will delete inferences after they are converted to cog (only applies when to_cog=True) \ No newline at end of file diff --git a/config_roads.yaml b/config_roads.yaml index 8092796..73687c7 100644 --- a/config_roads.yaml +++ b/config_roads.yaml @@ -17,7 +17,7 @@ global: # Inference parameters; used in inference.py -------- inference: - img_dir_or_csv_file: #/home/remi/Documents/inferences/test_13/cayouche_RGBN_test13_unet + img_dir_or_csv_file: #/path/to/img_dir_or_csv_file state_dict_path: # /path/to/model/weights/for/inference/checkpoint.pth.tar # Post-processing parameters; used in post-process.py @@ -32,5 +32,5 @@ post-processing: patterntol: 0.75 # if 1, will exclude all buildings that don't fit perfectly under one of the hardcoded patterns orthogonalize_ang_thresh: 20 # max angle formed by 3 vertices that will be orthogonalized (to 90° or 180°) - to_cog: True # Convert raster inference to cog (will compress with LZW) - keep_non_cog: True # if False, will delete inferences after they are converted to cog (only applies when to_cog=True) \ No newline at end of file + to_cog: True # (bool) Convert raster inference to cog (will compress with LZW) + keep_non_cog: False # (bool) if False, will delete inferences after they are converted to cog (only applies when to_cog=True) \ No newline at end of file diff --git a/inference_pipeline_HPC.sbatch b/inference_pipeline_HPC.sbatch index 826d1f4..d7f94b9 100755 --- 
a/inference_pipeline_HPC.sbatch +++ b/inference_pipeline_HPC.sbatch @@ -14,18 +14,19 @@ #SBATCH --gpus-per-task=1 # USER VARIABLES -yaml=path/to/config.yaml +yaml=absolute/path/to/config.yaml # SET ENVIRONMENT VARIABLES (DO NOT TOUCH) # load cuda 11.2 . ssmuse-sh -d /fs/ssm/hpco/exp/cuda-11.2.0 export LD_LIBRARY_PATH=/fs/ssm/hpco/exp/cuda-11.2.0/cuda_11.2.0_all/lib64 -export PYTHONPATH=/space/partner/nrcan/geobase/work/opt/miniconda3 -export PATH=/usr/local/cuda-10.0/bin:$PATH +#export PATH=/usr/local/cuda-10.0/bin:$PATH export MKL_THREADING_LAYER=GNU +export PYTHONPATH=/space/partner/nrcan/geobase/work/opt/miniconda3 source /space/partner/nrcan/geobase/work/opt/miniconda3/bin/activate # Activate geo-deep-learning conda environment -conda activate /gpfs/fs3/nrcan/nrcan_geobase/work/envs/gdl_py38 +conda activate gdl39-beta +#conda activate /gpfs/fs3/nrcan/nrcan_geobase/work/envs/gdl_py38 # EXECUTION STEPS # Run inference with desired yaml @@ -37,7 +38,7 @@ conda deactivate conda activate qgis316 # Run post-process.py with directory containing inferences as argument (must contain file ending with '_inference.tif'!) 
-cd /gpfs/fs3/nrcan/nrcan_geobase/work/transfer/work/deep_learning/inference/ +cd /gpfs/fs3/nrcan/nrcan_geobase/work/transfer/work/deep_learning/inference/postprocess-gdl python post-process.py -p $yaml conda deactivate diff --git a/post-process.py b/post-process.py index d97f939..ac9169f 100644 --- a/post-process.py +++ b/post-process.py @@ -5,7 +5,7 @@ from joblib import Parallel, delayed -from utils import read_parameters, get_key_def, load_checkpoint +from utils import read_parameters, get_key_def, load_checkpoint, compare_config_yamls def subprocess_command(command: str): @@ -70,46 +70,54 @@ def main(img_path, params): if __name__ == '__main__': print('\n\nStart:\n\n') - parser = argparse.ArgumentParser(usage="%(prog)s [-h] [YAML] [-i MODEL IMAGE] ", - description='Post-processing of inference created by geo-deep-learning') + parser = argparse.ArgumentParser(usage="%(prog)s [-h] [-p YAML] [-i MODEL IMAGE] ", + description='Inference and Benchmark on images using trained model') parser.add_argument('-p', '--param', metavar='yaml_file', nargs=1, help='Path to parameters stored in yaml') - parser.add_argument('-i', '--input', metavar='model_pth', nargs=1, - help='model_path') + parser.add_argument('-i', '--input', metavar='model_pth img_dir', nargs=2, + help='model_path and image_dir') args = parser.parse_args() - if args.param: # if yaml file is provided as input - params = read_parameters(args.param[0]) # read yaml file into ordereddict object - model_ckpt = get_key_def('state_dict_path', params['inference'], expected_type=str) + # if a yaml is inputted, get those parameters and get model state_dict to overwrite global parameters afterwards + if args.param: + input_params = read_parameters(args.param[0]) + model_ckpt = get_key_def('state_dict_path', input_params['inference'], expected_type=str) + # load checkpoint + checkpoint = load_checkpoint(model_ckpt) + if 'params' not in checkpoint.keys(): + warnings.warn('No parameters found in checkpoint. 
Use GDL version 1.3 or more.') + else: + params = checkpoint['params'] + # overwrite with inputted parameters + compare_config_yamls(yaml1=params, yaml2=input_params, update_yaml1=True) + del checkpoint + del input_params + + # elif input is a model checkpoint and an image directory, we'll rely on the yaml saved inside the model (pth.tar) elif args.input: model_ckpt = Path(args.input[0]) - params = {} + image = args.input[1] + # load checkpoint + checkpoint = load_checkpoint(model_ckpt) + if 'params' not in checkpoint.keys(): + raise KeyError('No parameters found in checkpoint. Use GDL version 1.3 or more.') + else: + # set parameters for inference from those contained in checkpoint.pth.tar + params = checkpoint['params'] + del checkpoint + # overwrite with inputted parameters + params['inference']['state_dict_path'] = args.input[0] params['inference']['img_dir_or_csv_file'] = args.input[1] - num_bands = get_key_def('num_bands', params['global'], expected_type=int) else: print('use the help [-h] option for correct usage') raise SystemExit - checkpoint = load_checkpoint(model_ckpt) - if 'params' not in checkpoint.keys(): - raise KeyError('No parameters found in checkpoint. Use GDL version 1.3 or more.') - else: - ckpt_num_bands = checkpoint['params']['global']['number_of_bands'] - num_bands = get_key_def('number_of_bands', params['global'], expected_type=int) - if not num_bands == ckpt_num_bands: - raise ValueError(f'Got number_of_bands {num_bands}. Expected {ckpt_num_bands}') - num_classes = get_key_def('num_classes', params['global']) - ckpt_num_classes = checkpoint['params']['global']['num_classes'] - classes = get_key_def('classes', params['global'], expected_type=dict) - if not num_classes == ckpt_num_classes == len(classes.keys()): - raise ValueError(f'Got num_classes {num_classes}. 
Expected {ckpt_num_classes}') - - del checkpoint - state_dict_path = get_key_def('state_dict_path', params['inference']) working_folder = Path(state_dict_path).parent - glob_pattern = f"inference_{num_bands}bands/*_inference.tif" + ckpt_num_bands = get_key_def('num_bands', params['global'], expected_type=int) + glob_pattern = f"inference_{ckpt_num_bands}bands/*_inference.tif" + glob_pattern = f"**/*_inference.tif" globbed_imgs_paths = list(working_folder.glob(glob_pattern)) if not globbed_imgs_paths: diff --git a/qgis_models/gdl-4classes.model3 b/qgis_models/gdl-4classes.model3 new file mode 100644 index 0000000..39c49b5 --- /dev/null +++ b/qgis_models/gdl-4classes.model3 @@ -0,0 +1,752 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/qgis_models/gdl-buildings.model3 b/qgis_models/gdl-buildings.model3 new file mode 100644 index 0000000..1eab3bd --- /dev/null +++ b/qgis_models/gdl-buildings.model3 @@ -0,0 +1,535 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/qgis_models/gdl-roads.model3 b/qgis_models/gdl-roads.model3 new file mode 100644 index 0000000..7ee0676 --- /dev/null +++ b/qgis_models/gdl-roads.model3 @@ -0,0 +1,261 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/qgis_models/r2vect.model3 b/qgis_models/r2vect.model3 new file mode 100644 index 0000000..f2877a6 --- /dev/null +++ b/qgis_models/r2vect.model3 @@ -0,0 +1,248 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/qgis_models/simplify-buildings.model3 b/qgis_models/simplify-buildings.model3 new file mode 100644 index 0000000..86ee8d7 --- /dev/null +++ 
b/qgis_models/simplify-buildings.model3 @@ -0,0 +1,947 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/qgis_models/simplify-forests.model3 b/qgis_models/simplify-forests.model3 new file mode 100644 index 0000000..14aa6b0 --- /dev/null +++ b/qgis_models/simplify-forests.model3 @@ -0,0 +1,705 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/qgis_models/simplify-hydro.model3 b/qgis_models/simplify-hydro.model3 new file mode 100644 index 0000000..da4eaee --- /dev/null +++ b/qgis_models/simplify-hydro.model3 @@ -0,0 +1,705 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/qgis_models/simplify-per-class.model3 b/qgis_models/simplify-per-class.model3 new file mode 100644 index 0000000..5bb79be --- /dev/null +++ b/qgis_models/simplify-per-class.model3 @@ -0,0 +1,637 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/qgis_models/simplify-roads.model3 b/qgis_models/simplify-roads.model3 new file mode 100644 index 0000000..5cda971 --- /dev/null +++ b/qgis_models/simplify-roads.model3 @@ -0,0 +1,920 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
def _reconcile_yaml_value(target: dict, key, val2, section, label: str, update: bool) -> None:
    """
    Compare ``target[key]`` with ``val2``, warn on mismatch and optionally copy ``val2`` into ``target``.
    :param target: (dict) dictionary from yaml1 (or a detached stand-in) holding the value to check
    :param key: key to look up in ``target``
    :param val2: value found in yaml2 for the same key
    :param section: name of the top-level section, used in the warning text only
    :param label: printable key path ("param" or "param/subparam"), used in the warning text only
    :param update: (bool) if True, ``target[key]`` is set to ``val2`` whenever it differs or is absent
    """
    val1 = target.get(key)  # a missing key and an explicit None are treated alike
    if val1 != val2:
        # Values are labelled by the dictionary they come from, so the message can't be read backwards.
        warnings.warn(f"YAML value mismatch: section \"{section}\", key \"{label}\"\n"
                      f"yaml1 value: \"{val1}\"\nyaml2 value: \"{val2}\"\n"
                      f"Problems may occur.")
        if update:
            target[key] = val2
            warnings.warn('Value in yaml1 updated')
    elif update and key not in target:
        # Values agree (both None) but the key is absent: create it so yaml1 ends up with every key of yaml2.
        target[key] = val2


def compare_config_yamls(yaml1: dict, yaml2: dict, update_yaml1: bool = False) -> None:
    """
    Checks if values for same keys or subkeys (max depth of 2) of two dictionaries match.
    Every mismatch is reported with warnings.warn(). A key that is missing from yaml1 is compared as None.
    An empty section in yaml2 (value None) is skipped.
    :param yaml1: (dict) first dict to evaluate. Only modified when update_yaml1 is True.
    :param yaml2: (dict) second dict to evaluate. Never modified.
    :param update_yaml1: (bool) if True, values in yaml1 will be replaced with values in yaml2,
                         if the latter are different. Keys and sections missing from yaml1 are created.
    :return: None. Mismatches are reported through warnings only.
    :raises TypeError: if yaml1 or yaml2 is not a dictionary
    """
    # Both arguments must be dictionaries; a single non-dict argument is rejected too.
    if not (isinstance(yaml1, dict) and isinstance(yaml2, dict)):
        raise TypeError(f"Expected both yamls to be dictionaries. \n"
                        f"Yaml1's type is {type(yaml1)}\n"
                        f"Yaml2's type is {type(yaml2)}")
    for section, params in yaml2.items():  # loop through main sections of config yaml ('global', 'sample', etc.)
        if params is None:  # empty section in yaml2: nothing to compare or override
            continue
        section1 = yaml1.get(section)
        if section1 is None:
            # Compare against an empty section; attach it to yaml1 only if we are allowed to modify yaml1.
            section1 = {}
            if update_yaml1:
                yaml1[section] = section1
        for param, val2 in params.items():  # loop through parameters of each section
            val1 = section1.get(param)
            if isinstance(val2, dict) and (val1 is None or isinstance(val1, dict)):
                # Value is a dict on both sides (or absent in yaml1): compare one level deeper.
                if val1 is None:
                    val1 = {}
                    if update_yaml1:
                        section1[param] = val1
                for subparam, subval2 in val2.items():
                    _reconcile_yaml_value(val1, subparam, subval2, section, f"{param}/{subparam}", update_yaml1)
            else:
                # Plain value, or a dict in yaml2 facing a non-dict in yaml1: compare as a whole.
                _reconcile_yaml_value(section1, param, val2, section, param, update_yaml1)