Skip to content

Commit

Permalink
copy argparse reading from inference.py (GDL)
Browse files Browse the repository at this point in the history
add qgis_models
  • Loading branch information
remtav committed Mar 31, 2021
1 parent fd9e554 commit 16f38d2
Show file tree
Hide file tree
Showing 15 changed files with 5,841 additions and 41 deletions.
9 changes: 4 additions & 5 deletions config_4class.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ global:

# Inference parameters; used in inference.py --------
inference:
img_dir_or_csv_file: /home/remi/Documents/inferences/test_13/cayouche_RGBN_test13_unet
# /path/to/model/weights/for/inference/checkpoint.pth.tar
state_dict_path: /home/remi/Documents/inferences/test_13/cayouche_RGBN_test13_unet/checkpoint_params.pth.tar
img_dir_or_csv_file: #/path/to/img_dir_or_csv_file
state_dict_path: # /path/to/model/weights/for/inference/checkpoint.pth.tar

# Post-processing parameters; used in post-process.py
post-processing:
Expand All @@ -36,5 +35,5 @@ post-processing:
patterntol: 0.75 # if 1, will exclude all buildings that don't fit perfectly under one of the hardcoded patterns
orthogonalize_ang_thresh: 20 # max angle formed by 3 vertices that will be orthogonalized (to 90° or 180°)

to_cog: True # Convert raster inference to cog (will compress with LZW)
keep_non_cog: True # if False, will delete inferences after they are converted to cog (only applies when to_cog=True)
to_cog: True # (bool) Convert raster inference to cog (will compress with LZW)
keep_non_cog: False # (bool) if False, will delete inferences after they are converted to cog (only applies when to_cog=True)
36 changes: 36 additions & 0 deletions config_buildings.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Deep learning configuration file ------------------------------------------------
# This config is to use massachusetts_buildings image out of the box WITH ONLY MANDATORY PARAMETERS.
# For that, unzip the file data/massachusetts_buildings.zip before running images_to_samples.py or other command.
# Five sections :
# 1) Global parameters; those are re-used amongst the next three operations (sampling, training and inference)
# 2) Inference parameters
# 3) Post-processing parameters

# Global parameters
global:
task: 'segmentation'
number_of_bands: 4 # Number of bands in input imagery
# Set to True if the first three channels of imagery are blue, green, red. Set to False if Red, Green, Blue
BGR_to_RGB: True
classes:
1: 'buildings'

# Inference parameters; used in inference.py --------
inference:
img_dir_or_csv_file: #/path/to/img_dir_or_csv_file
state_dict_path: # /path/to/model/weights/for/inference/checkpoint.pth.tar

# Post-processing parameters; used in post-process.py
post-processing:
r2vect_cellsize_resamp: 0 # (int) Resample raster before vectorization with this amount of pixels. Default:0
removeholesunder: 4 # (int) Remove holes under this number in all classes
simptol: 0.75 # (float) tolerance for simplification (Douglas-Peucker). Does not apply to 'buildings' (Visvalingam)
redbenddiamtol: 3 # (int) tolerance for reduce bend algorithm. Applies to all classes
buildings: # buildings-specific parameters
recttol: 0.80 # if 1, will exclude all buildings that are not rectangles
compacttol: 0.85 # if 1, will exclude all buildings that are not a circle
patterntol: 0.75 # if 1, will exclude all buildings that don't fit perfectly under one of the hardcoded patterns
orthogonalize_ang_thresh: 20 # max angle formed by 3 vertices that will be orthogonalized (to 90° or 180°)

to_cog: True # (bool) Convert raster inference to cog (will compress with LZW)
keep_non_cog: False # (bool) if False, will delete inferences after they are converted to cog (only applies when to_cog=True)
6 changes: 3 additions & 3 deletions config_roads.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ global:

# Inference parameters; used in inference.py --------
inference:
img_dir_or_csv_file: #/home/remi/Documents/inferences/test_13/cayouche_RGBN_test13_unet
img_dir_or_csv_file: #/path/to/img_dir_or_csv_file
state_dict_path: # /path/to/model/weights/for/inference/checkpoint.pth.tar

# Post-processing parameters; used in post-process.py
Expand All @@ -32,5 +32,5 @@ post-processing:
patterntol: 0.75 # if 1, will exclude all buildings that don't fit perfectly under one of the hardcoded patterns
orthogonalize_ang_thresh: 20 # max angle formed by 3 vertices that will be orthogonalized (to 90° or 180°)

to_cog: True # Convert raster inference to cog (will compress with LZW)
keep_non_cog: True # if False, will delete inferences after they are converted to cog (only applies when to_cog=True)
to_cog: True # (bool) Convert raster inference to cog (will compress with LZW)
keep_non_cog: False # (bool) if False, will delete inferences after they are converted to cog (only applies when to_cog=True)
11 changes: 6 additions & 5 deletions inference_pipeline_HPC.sbatch
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,19 @@
#SBATCH --gpus-per-task=1

# USER VARIABLES
yaml=path/to/config.yaml
yaml=absolute/path/to/config.yaml

# SET ENVIRONMENT VARIABLES (DO NOT TOUCH)
# load cuda 11.2
. ssmuse-sh -d /fs/ssm/hpco/exp/cuda-11.2.0
export LD_LIBRARY_PATH=/fs/ssm/hpco/exp/cuda-11.2.0/cuda_11.2.0_all/lib64
export PYTHONPATH=/space/partner/nrcan/geobase/work/opt/miniconda3
export PATH=/usr/local/cuda-10.0/bin:$PATH
#export PATH=/usr/local/cuda-10.0/bin:$PATH
export MKL_THREADING_LAYER=GNU
export PYTHONPATH=/space/partner/nrcan/geobase/work/opt/miniconda3
source /space/partner/nrcan/geobase/work/opt/miniconda3/bin/activate
# Activate geo-deep-learning conda environment
conda activate /gpfs/fs3/nrcan/nrcan_geobase/work/envs/gdl_py38
conda activate gdl39-beta
#conda activate /gpfs/fs3/nrcan/nrcan_geobase/work/envs/gdl_py38

# EXECUTION STEPS
# Run inference with desired yaml
Expand All @@ -37,7 +38,7 @@ conda deactivate
conda activate qgis316

# Run post-process.py with directory containing inferences as argument (must contain file ending with '_inference.tif'!)
cd /gpfs/fs3/nrcan/nrcan_geobase/work/transfer/work/deep_learning/inference/
cd /gpfs/fs3/nrcan/nrcan_geobase/work/transfer/work/deep_learning/inference/postprocess-gdl
python post-process.py -p $yaml
conda deactivate

Expand Down
62 changes: 35 additions & 27 deletions post-process.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from joblib import Parallel, delayed

from utils import read_parameters, get_key_def, load_checkpoint
from utils import read_parameters, get_key_def, load_checkpoint, compare_config_yamls


def subprocess_command(command: str):
Expand Down Expand Up @@ -70,46 +70,54 @@ def main(img_path, params):

if __name__ == '__main__':
print('\n\nStart:\n\n')
parser = argparse.ArgumentParser(usage="%(prog)s [-h] [YAML] [-i MODEL IMAGE] ",
description='Post-processing of inference created by geo-deep-learning')
parser = argparse.ArgumentParser(usage="%(prog)s [-h] [-p YAML] [-i MODEL IMAGE] ",
description='Inference and Benchmark on images using trained model')

parser.add_argument('-p', '--param', metavar='yaml_file', nargs=1,
help='Path to parameters stored in yaml')
parser.add_argument('-i', '--input', metavar='model_pth', nargs=1,
help='model_path')
parser.add_argument('-i', '--input', metavar='model_pth img_dir', nargs=2,
help='model_path and image_dir')
args = parser.parse_args()

if args.param: # if yaml file is provided as input
params = read_parameters(args.param[0]) # read yaml file into ordereddict object
model_ckpt = get_key_def('state_dict_path', params['inference'], expected_type=str)
# if a yaml is inputted, get those parameters and get model state_dict to overwrite global parameters afterwards
if args.param:
input_params = read_parameters(args.param[0])
model_ckpt = get_key_def('state_dict_path', input_params['inference'], expected_type=str)
# load checkpoint
checkpoint = load_checkpoint(model_ckpt)
if 'params' not in checkpoint.keys():
warnings.warn('No parameters found in checkpoint. Use GDL version 1.3 or more.')
else:
params = checkpoint['params']
# overwrite with inputted parameters
compare_config_yamls(yaml1=params, yaml2=input_params, update_yaml1=True)
del checkpoint
del input_params

# elif input is a model checkpoint and an image directory, we'll rely on the yaml saved inside the model (pth.tar)
elif args.input:
model_ckpt = Path(args.input[0])
params = {}
image = args.input[1]
# load checkpoint
checkpoint = load_checkpoint(model_ckpt)
if 'params' not in checkpoint.keys():
raise KeyError('No parameters found in checkpoint. Use GDL version 1.3 or more.')
else:
# set parameters for inference from those contained in checkpoint.pth.tar
params = checkpoint['params']
del checkpoint
# overwrite with inputted parameters
params['inference']['state_dict_path'] = args.input[0]
params['inference']['img_dir_or_csv_file'] = args.input[1]
num_bands = get_key_def('num_bands', params['global'], expected_type=int)
else:
print('use the help [-h] option for correct usage')
raise SystemExit

checkpoint = load_checkpoint(model_ckpt)
if 'params' not in checkpoint.keys():
raise KeyError('No parameters found in checkpoint. Use GDL version 1.3 or more.')
else:
ckpt_num_bands = checkpoint['params']['global']['number_of_bands']
num_bands = get_key_def('number_of_bands', params['global'], expected_type=int)
if not num_bands == ckpt_num_bands:
raise ValueError(f'Got number_of_bands {num_bands}. Expected {ckpt_num_bands}')
num_classes = get_key_def('num_classes', params['global'])
ckpt_num_classes = checkpoint['params']['global']['num_classes']
classes = get_key_def('classes', params['global'], expected_type=dict)
if not num_classes == ckpt_num_classes == len(classes.keys()):
raise ValueError(f'Got num_classes {num_classes}. Expected {ckpt_num_classes}')

del checkpoint

state_dict_path = get_key_def('state_dict_path', params['inference'])
working_folder = Path(state_dict_path).parent
glob_pattern = f"inference_{num_bands}bands/*_inference.tif"
ckpt_num_bands = get_key_def('num_bands', params['global'], expected_type=int)
glob_pattern = f"inference_{ckpt_num_bands}bands/*_inference.tif"
glob_pattern = f"**/*_inference.tif"
globbed_imgs_paths = list(working_folder.glob(glob_pattern))

if not globbed_imgs_paths:
Expand Down
Loading

0 comments on commit 16f38d2

Please sign in to comment.