Skip to content

Commit

Permalink
post-process.py takes yaml as input
Browse files Browse the repository at this point in the history
  • Loading branch information
remtav committed Mar 26, 2021
1 parent b3577d3 commit 5799138
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 34 deletions.
33 changes: 33 additions & 0 deletions config_4class.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Deep learning configuration file ------------------------------------------------
# This config covers inference and post-processing for a 4-class segmentation
# model (forests, hydro, roads, buildings).
# Three sections:
# 1) Global parameters; re-used by the inference and post-processing operations
# 2) Inference parameters
# 3) Post-processing parameters

# Global parameters
global:
task: 'segmentation'
num_classes: 4
number_of_bands: 4
# Set to True if the first three channels of imagery are blue, green, red. Set to False if Red, Green, Blue
BGR_to_RGB: True
classes:
1: 'forests'
2: 'hydro'
3: 'roads'
4: 'buildings'

# Inference parameters; used in inference.py --------
inference:
img_dir_or_csv_file: /home/remi/Documents/inferences/test_13/cayouche_RGBN_test13_unet
# /path/to/model/weights/for/inference/checkpoint.pth.tar
state_dict_path: /home/remi/Documents/inferences/test_13/cayouche_RGBN_test13_unet/checkpoint_params.pth.tar

# Post-processing parameters; used in post-process.py
post-processing:
r2vect_cellsize_resamp: 0
orthogonalize_ang_thresh: 20
to_cog: True
keep_non_cog: True
28 changes: 28 additions & 0 deletions config_roads.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Deep learning configuration file ------------------------------------------------
# This config covers inference and post-processing for a single-class (roads)
# segmentation model.
# Three sections:
# 1) Global parameters; re-used by the inference and post-processing operations
# 2) Inference parameters
# 3) Post-processing parameters

# Global parameters
global:
task: 'segmentation'
number_of_bands: 4 # Number of bands in input imagery
# Set to True if the first three channels of imagery are blue, green, red. Set to False if Red, Green, Blue
BGR_to_RGB: True
classes:
1: 'roads'

# Inference parameters; used in inference.py --------
inference:
img_dir_or_csv_file: #/home/remi/Documents/inferences/test_13/cayouche_RGBN_test13_unet
state_dict_path: # /path/to/model/weights/for/inference/checkpoint.pth.tar

# Post-processing parameters; used in post-process.py
post-processing:
r2vect_cellsize_resamp: 0
orthogonalize_ang_thresh: 20
to_cog: True
keep_non_cog: True
88 changes: 55 additions & 33 deletions post-process.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from joblib import Parallel, delayed

from utils import read_parameters, get_key_def
from utils import read_parameters, get_key_def, load_checkpoint


def subprocess_command(command: str):
Expand All @@ -22,33 +22,37 @@ def main(img_path, params):

# post-processing parameters
# FIXME: as yaml input
classes = [(1, 'forest'), (2, 'hydro'), (3, 'roads'), (4, 'buildings')]
cell_size_resamp = 0
orthogonalize_ang_thresh = 20
to_cog = False
keep_non_cog = True
classes = get_key_def('classes', params['global'], expected_type=dict)
r2v_cellsize_resamp = get_key_def('r2vect_cellsize_resamp', params['post-processing'], default=0, expected_type=int)
orthog_ang_thres = get_key_def('orthogonalize_ang_thresh', params['post-processing'], default=20, expected_type=int)
to_cog = get_key_def('to_cog', params['post-processing'], default=True, expected_type=bool)
keep_non_cog = get_key_def('keep_non_cog', params['post-processing'], default=True, expected_type=bool)

# validate inputted classes
if 0 in classes.keys():
warnings.warn("Are you sure value 0 is of interest? It is usually used to set background class, "
"i.e. non-relevant class. Will add 1 to all class values inputted, e.g. 0,1,2,3 --> 1,2,3,4")
classes = {cl_val + 1: name for cl_val, name in classes}

# set name of output gpkg: myinference.tif will become myinference.gpkg
final_gpkg = Path(img_path).parent / f'{Path(img_path).stem}.gpkg'
if final_gpkg.is_file():
warnings.warn(f'Output geopackage exists: {final_gpkg}. Skipping to next inference...')
else:
if (len(classes)) != 4:
raise NotImplementedError
command = f'qgis_process run model:gdl-{len(classes)}classes -- ' \
f'srcinfraster="{img_path}" ' \
f'r2vcellsizeresamp={cell_size_resamp} ' \
f'native:package_1:dest-gpkg={final_gpkg} '

# for attrnum, class_name in classes:
# if attrnum == 0:
# warnings.warn("Are you sure value 0 is of interest? It is usually used to set background class, "
# "i.e. non-relevant class")
#
# command += f'attnum{attrnum}={attrnum} ' \
# f'class{attrnum}=\'{class_name}\' '

subprocess_command(command)
if len(classes.keys()) == 1 and classes[1] == 'roads':
command = f'qgis_process run model:gdl-roads -- ' \
f'inputraster="{img_path}" ' \
f'r2vcellsizeresamp={r2v_cellsize_resamp} ' \
f'native:package_1:dest-gpkg={final_gpkg}'
elif len(classes) == 4:
command = f'qgis_process run model:gdl-{len(classes)}classes -- ' \
f'srcinfraster="{img_path}" ' \
f'r2vcellsizeresamp={r2v_cellsize_resamp} ' \
f'native:package_1:dest-gpkg={final_gpkg}'
else:
raise NotImplementedError(f'Cannot post-process inference with {len(classes.keys())} classes')

subprocess_command(command)

# COG
if to_cog:
Expand All @@ -75,25 +79,43 @@ def main(img_path, params):
help='model_path')
args = parser.parse_args()

if args.param:
params = read_parameters(args.param_file)
if args.param: # if yaml file is provided as input
params = read_parameters(args.param[0]) # read yaml file into ordereddict object
model_ckpt = get_key_def('state_dict_path', params['inference'], expected_type=str)
elif args.input:
model = Path(args.input[0])

params = {'inference': {}}
params['inference']['state_dict_path'] = args.input[0]

model_ckpt = Path(args.input[0])
params = {}
params['inference']['img_dir_or_csv_file'] = args.input[1]
num_bands = get_key_def('num_bands', params['global'], expected_type=int)
else:
print('use the help [-h] option for correct usage')
raise SystemExit

working_folder = Path(params['inference']['state_dict_path']).parent
#num_bands = get_key_def('num_bands', params['global'], expected_type=int)
num_bands = 4
checkpoint = load_checkpoint(model_ckpt)
if 'params' not in checkpoint.keys():
raise KeyError('No parameters found in checkpoint. Use GDL version 1.3 or more.')
else:
ckpt_num_bands = checkpoint['params']['global']['number_of_bands']
num_bands = get_key_def('number_of_bands', params['global'], expected_type=int)
if not num_bands == ckpt_num_bands:
raise ValueError(f'Got number_of_bands {num_bands}. Expected {ckpt_num_bands}')
num_classes = get_key_def('num_classes', params['global'])
ckpt_num_classes = checkpoint['params']['global']['num_classes']
classes = get_key_def('classes', params['global'], expected_type=dict)
if not num_classes == ckpt_num_classes == len(classes.keys()):
raise ValueError(f'Got num_classes {num_classes}. Expected {ckpt_num_classes}')

del checkpoint

state_dict_path = get_key_def('state_dict_path', params['inference'])
working_folder = Path(state_dict_path).parent
glob_pattern = f"inference_{num_bands}bands/*_inference.tif"
globbed_imgs_paths = list(working_folder.glob(glob_pattern))

print(f"Found {len(globbed_imgs_paths)} inferences to post-process")
if not globbed_imgs_paths:
raise FileNotFoundError(f'No tif images found to post-process in {working_folder}')
else:
print(f"Found {len(globbed_imgs_paths)} inferences to post-process")

Parallel(n_jobs=len(globbed_imgs_paths))(delayed(main)(file, params=params) for file in globbed_imgs_paths)
#main(params)
22 changes: 21 additions & 1 deletion utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import torch
from ruamel_yaml import YAML


Expand Down Expand Up @@ -46,4 +47,23 @@ def get_key_def(key, config, default=None, msg=None, delete=False, expected_type
assert isinstance(val, expected_type), f"{val} is of type {type(val)}, expected {expected_type}"
if delete:
del config[key]
return val
return val


def load_checkpoint(filename):
    """Load a model checkpoint from disk.

    :param filename: path to checkpoint as .pth.tar or .pth
    :return: (dict) checkpoint ready to be loaded into model instance,
        with the state dict guaranteed to be nested under the 'model' key
    :raises FileNotFoundError: if no file exists at `filename`
    """
    try:
        print(f"=> loading model '{filename}'\n")
        # map_location='cpu' lets checkpoints saved on GPU load on CPU-only hosts
        checkpoint = torch.load(filename, map_location='cpu')
    except FileNotFoundError:
        raise FileNotFoundError(f"=> No model found at '{filename}'")
    # External models may store the state dict at the top level rather than
    # under a 'model' key; normalize so callers can always use checkpoint['model'].
    if 'model' not in checkpoint:
        checkpoint = {'model': dict(checkpoint)}
    return checkpoint

0 comments on commit 5799138

Please sign in to comment.