-
Notifications
You must be signed in to change notification settings - Fork 12
/
cli.py
394 lines (306 loc) · 13 KB
/
cli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
import click, os, logging, yaml, json
from datetime import datetime as dt
from click import command, option, Option, UsageError
from deepsentinel.utils.utils import get_from_dict, set_in_dict, make_nested_dict
logging.basicConfig(level=logging.INFO)
@click.group()
def cli():
    # Root click group: subcommands are registered below via @cli.command().
    # Intentionally no docstring — adding one would change the `--help` output.
    pass
@cli.command()
@click.option('--conf', default=os.path.join(os.getcwd(),'conf','DATA_CONFIG.yaml'), help='path to DATA_CONFIG.yaml')
@click.option('--n-orbits', help='The number of orbits to spread the simulated points over.', type=int) # one of N_orbits, end_date, pts per orbit
@click.option('--end-date', help='the end date to stop sampling points, as YYYY-mm-dd', type=str)
@click.option('--iso2', help='A comma-separated list of iso-a2 country codes for geographic subsampling', type=str)
@click.argument('start_date', type=str)
@click.argument('n-points', type=int)
@click.argument('name', type=str)
def generate_points(name, n_points, start_date, iso2, end_date, n_orbits, conf):
    """
    Seed points for a new dataset.
    \b
    PARAMETERS
    ----------
    START_DATE: str
        The start date for data collection in the form YYYY-mm-dd.
    N_POINTS: int
        The number of data points to generate.
    NAME: str
        The name of the new dataset.
    """
    from deepsentinel.utils.point_generator import PointGenerator
    logger = logging.getLogger('GENERATE_POINTS')
    # Exactly one of end_date / n_orbits must be supplied. UsageError (not
    # assert) so validation survives `python -O` and click renders it nicely.
    if not (end_date or n_orbits):
        raise UsageError('One of n_orbits or end_date must be provided.')
    if end_date and n_orbits:
        raise UsageError('Only one of n_orbits or end_date must be provided.')
    # Error-check date formats; chain the original parse error for context.
    try:
        start_date = dt.strptime(start_date, '%Y-%m-%d')
    except ValueError as err:
        raise ValueError('Ensure start_date is in the correct format, YYYY-mm-dd') from err
    if end_date is not None:
        try:
            end_date = dt.strptime(end_date, '%Y-%m-%d')
        except ValueError as err:
            raise ValueError('Ensure end_date is in the correct format, YYYY-mm-dd') from err
    logger.info('Generating points with:')
    logger.info(f'NAME:{name}')
    logger.info(f'N_POINTS:{n_points}')
    logger.info(f'START_DATE:{start_date}')
    logger.info(f'iso2:{iso2}')
    logger.info(f'end_date:{end_date}')
    logger.info(f'n_orbits:{n_orbits}')
    logger.info(f'conf:{conf}')
    if iso2:
        iso2 = iso2.split(',')
    logger.info('Initialising generator')
    generator = PointGenerator(iso_geographies=iso2, conf=conf)
    if not n_orbits:
        # Derive n_orbits from the date span and the configured orbit period.
        n_orbits = (end_date - start_date).days // generator.CONFIG['orbit_period']
    # +1 deliberately overshoots so at least n_points are generated in total.
    pts_per_orbit = n_points // n_orbits + 1
    logger.info(f'Running generator for {name} from {start_date.isoformat()} for {n_orbits} orbits with {pts_per_orbit} points per orbit')
    generator.main_generator(start_date, n_orbits, pts_per_orbit, name)
@cli.command()
@click.option('--conf', default=os.path.join(os.getcwd(),'conf','DATA_CONFIG.yaml'), help='path to DATA_CONFIG.yaml')
@click.argument('gdf_path', type=str)
@click.argument('name', type=str)
@click.argument('start_date', type=str)
@click.argument('end-date', type=str)
def geopandas_to_points(gdf_path, name, start_date, end_date, conf):
    """
    Seed points for a new dataset from a GeoPandas GeoDataFrame.
    \b
    PARAMETERS
    ----------
    GDF_PATH: str
        The path to the GeoPandas GeoDataFrame to load (with gpd.read_file).
    NAME: str
        The name of the new dataset.
    START_DATE: str
        The start date for data collection in the form YYYY-mm-dd.
    END_DATE: str
        The end date for data collection in the form YYYY-mm-dd.
    """
    from deepsentinel.utils.gdf2points import GDF2Points
    logger = logging.getLogger('POINTS_FROM_GDF')
    # Error-check date formats; chain the original parse error for context.
    try:
        start_date = dt.strptime(start_date, '%Y-%m-%d')
    except ValueError as err:
        raise ValueError('Ensure start_date is in the correct format, YYYY-mm-dd') from err
    try:
        end_date = dt.strptime(end_date, '%Y-%m-%d')
    except ValueError as err:
        raise ValueError('Ensure end_date is in the correct format, YYYY-mm-dd') from err
    logger.info('Generating points with:')
    logger.info(f'GDF_PATH:{gdf_path}')
    logger.info(f'NAME:{name}')
    logger.info(f'START_DATE:{start_date}')
    logger.info(f'END_DATE:{end_date}')
    logger.info(f'conf:{conf}')
    logger.info('Initialising generator')
    generator = GDF2Points(conf=conf)
    logger.info(f'Running generator for {name} sampling tiles for {gdf_path} from {start_date.isoformat()} to {end_date.isoformat()}')
    generator.generate_from_gdf(gdf_path, start_date, end_date, name)
@cli.command()
@click.option('--conf', default=os.path.join(os.getcwd(),'conf','DATA_CONFIG.yaml'), help='path to DATA_CONFIG.yaml')
@click.argument('name', type=str)
@click.argument('sources', type=str)
@click.argument('destinations', type=str)
def generate_samples(name, sources, destinations, conf):
    """
    Download imagery samples for a seeded dataset.
    \b
    PARAMETERS
    ----------
    NAME: str
        The name of the dataset to download.
    SOURCES: str
        A comma-separated list of sources to download the matching data from. Must be in ['dl','gee','osm','clc']:
            dl: DescartesLabs (https://www.descarteslabs.com/)
            gee: Google Earth Engine (https://earthengine.google.com/)
            osm: OpenStreetMap (https://www.openstreetmap.org/, https://github.com/Lkruitwagen/deepsentinel-osm)
            clc: Copernicus Land Cover (https://land.copernicus.eu/pan-european/corine-land-cover, mirrored on DescartesLabs)
    DESTINATIONS: str
        A comma-separated list of destinations for the generated data. Must be in ['local','gcp','azure']:
            local: saved to <data_root>/<name>/
            gcp: saved to a Google Cloud Storage Bucket
            azure: saved to an Azure Cloud Storage Container
    """
    from deepsentinel.utils.sample_generator import SampleDownloader
    logger = logging.getLogger('SAMPLE_IMAGERY')
    # Split once (the original split each string twice) and validate with
    # UsageError so the checks survive `python -O` and carry a message.
    sources = sources.split(',')
    for source in sources:
        if source not in ['dl', 'gee', 'osm', 'clc']:
            raise UsageError(f"Invalid source '{source}': must be one of ['dl','gee','osm','clc']")
    destinations = destinations.split(',')
    for dest in destinations:
        if dest not in ['local', 'gcp', 'azure']:
            raise UsageError(f"Invalid destination '{dest}': must be one of ['local','gcp','azure']")
    logger.info('Sampling imagery with:')
    logger.info(f'NAME:{name}')
    logger.info(f'SOURCES:{sources}')
    logger.info(f'DESTINATIONS:{destinations}')
    downloader = SampleDownloader(version=name, destinations=destinations, conf=conf)
    # Each source is optional; run only the downloads that were requested.
    if 'dl' in sources:
        logger.info('doing dl')
        downloader.download_samples_DL()
    if 'gee' in sources:
        logger.info('gee')
        downloader.download_samples_GEE()
    if 'clc' in sources:
        downloader.download_samples_LC()
    if 'osm' in sources:
        downloader.download_samples_OSM()
    logger.info('DONE!')
@cli.command(
    context_settings=dict(
        ignore_unknown_options=True,
        allow_extra_args=True,
    )
)
@click.option('--conf', default=os.path.join(os.getcwd(),'conf','ML_CONFIG.yaml'), help='path to ML_CONFIG.yaml')
@click.option('--observers', default='local,gcp', help='Comma-separated list of observers to add to experiment, from ["local","gcp"]')
@click.option('--name', )
@click.pass_context
def train(ctx, conf, observers, name):
    """
    Run the model training scripts with Sacred and a YAML config file.
    \b
    Any additional parameters can also be specified:
        --device=cuda
    Nested parameters can be specified like so:
        --model_config--VAE--z_dim=16
        --model_config--VAE={\"z_dim\":16}
    """
    from deepsentinel.main import ex
    from sacred.observers import FileStorageObserver
    from sacred.observers import GoogleCloudStorageObserver
    logger = logging.getLogger('TRAINING')
    # Load the YAML config with a context manager so the handle is closed
    # (the original leaked the open file object).
    with open(conf, 'r') as f:
        CONFIG = yaml.load(f, Loader=yaml.SafeLoader)
    logger.info(f'Adding config from {conf}')
    ex.add_config(conf)
    ctx_conf = {}
    for item in ctx.args:
        # Split on the FIRST '=' only, so values may themselves contain '='
        # (e.g. JSON payloads or paths).
        kks, vv = item.split('=', 1)
        kks = kks.split('--')[1:]
        # Cast vv to the type found at the same nested path in the YAML config.
        vv_type = type(get_from_dict(CONFIG, kks))
        # Override special cases.
        if vv_type == dict:
            vv = json.loads(vv)
        elif kks[0] in ['pretrain', 'finetune', 'load_run'] and vv == 'None':
            vv = None
        elif kks[0] == 'load_run':
            vv = int(vv)
        elif kks[0] in ['pretrain', 'finetune']:
            vv = str(vv)
        else:
            vv = vv_type(vv)
        # Set in our update dict; if the deep path cannot be set directly,
        # retry with progressively shallower paths, nesting the remainder
        # as a fresh dict. Narrowed from bare `except:` so SystemExit /
        # KeyboardInterrupt are not swallowed; fallback order is unchanged.
        try:
            set_in_dict(ctx_conf, kks, vv)
        except Exception:
            for ii_k in range(1, len(kks)):
                try:
                    set_in_dict(ctx_conf, kks[:-ii_k], make_nested_dict(kks[-1*ii_k:], vv))
                    break
                except Exception:
                    pass
    logger.info(f'Adding additional CLI config: {ctx_conf}')
    # Add observers requested via --observers.
    if 'local' in observers:
        logger.info(f'Adding local observer at {CONFIG["sacred"]["local"]}')
        ex.observers.append(FileStorageObserver(CONFIG['sacred']['local']))
    if 'gcp' in observers:
        logger.info(f'Adding Google Cloud Observer at {CONFIG["sacred"]["gcp_bucket"]}/{CONFIG["sacred"]["gcp_basedir"]}')
        ex.observers.append(GoogleCloudStorageObserver(bucket=CONFIG['sacred']['gcp_bucket'], basedir=CONFIG['sacred']['gcp_basedir']))
    r = ex.run(config_updates=ctx_conf)
@cli.command(
    context_settings=dict(
        ignore_unknown_options=True,
        allow_extra_args=True,
    )
)
@click.option('--ml_conf', default=os.path.join(os.getcwd(),'conf','ML_CONFIG.yaml'), help='path to ML_CONFIG.yaml')
@click.option('--test_conf', default=os.path.join(os.getcwd(),'conf','TEST_CONFIG.yaml'), help='path to TEST_CONFIG.yaml')
@click.pass_context
def test(ctx, ml_conf, test_conf):
    """
    Run the final models to obtain final test results.
    \b
    Any additional parameters can also be specified for the test results.
    Nested parameters can be specified like so:
        --model_config--VAE--z_dim=16
        --model_config--VAE={\"z_dim\":16}
    """
    # get_from_dict / set_in_dict / make_nested_dict come from the module-level
    # import; the original redundantly re-imported them here.
    from deepsentinel.test import test
    # Context managers so the config file handles are closed.
    with open(test_conf, 'r') as f:
        CONFIG = yaml.load(f, Loader=yaml.SafeLoader)
    # NOTE(review): ML_CONFIG is loaded but never read below — kept so that a
    # missing ml_conf file still fails loudly; confirm whether it is needed.
    with open(ml_conf, 'r') as f:
        ML_CONFIG = yaml.load(f, Loader=yaml.SafeLoader)
    logger = logging.getLogger('testing_runs')
    logger.info(f'Using config from {test_conf}')
    for item in ctx.args:
        # Split on the FIRST '=' only, so values may themselves contain '='.
        kks, vv = item.split('=', 1)
        kks = kks.split('--')[1:]
        # Cast vv to the type found at the same nested path in the config.
        vv_type = type(get_from_dict(CONFIG, kks))
        # Override special cases.
        if vv_type == dict:
            vv = json.loads(vv)
        elif kks[0] in ['pretrain', 'finetune', 'load_run'] and vv == 'None':
            vv = None
        elif kks[0] == 'load_run':
            vv = int(vv)
        elif kks[0] in ['pretrain', 'finetune']:
            vv = str(vv)
        else:
            vv = vv_type(vv)
        # Apply the override directly into the loaded config.
        set_in_dict(CONFIG, kks, vv)
    test(**CONFIG)
@cli.command(
    context_settings=dict(
        ignore_unknown_options=True,
        allow_extra_args=True,
    )
)
@click.option('--conf', default=os.path.join(os.getcwd(),'conf','ML_CONFIG.yaml'), help='path to ML_CONFIG.yaml')
@click.pass_context
def mines_coal(ctx, conf):
    """
    Use a finetuned model to predict whether mining polygons are coal or not.
    \b
    Any additional parameters can also be specified:
        --device=cuda
    Nested parameters can be specified like so:
        --model_config--VAE--z_dim=16
        --model_config--VAE={\"z_dim\":16}
    """
    # get_from_dict / set_in_dict / make_nested_dict come from the module-level
    # import; the original redundantly re-imported them here.
    from deepsentinel.classify_mines import classify_mines, mines_postprocess
    # Context manager so the config file handle is closed.
    with open(conf, 'r') as f:
        CONFIG = yaml.load(f, Loader=yaml.SafeLoader)
    logger = logging.getLogger('run_mines')
    logger.info(f'Using config from {conf}')
    for item in ctx.args:
        # Split on the FIRST '=' only, so values may themselves contain '='.
        kks, vv = item.split('=', 1)
        kks = kks.split('--')[1:]
        # Cast vv to the type found at the same nested path in the config.
        vv_type = type(get_from_dict(CONFIG, kks))
        # Override special cases.
        if vv_type == dict:
            vv = json.loads(vv)
        elif kks[0] in ['pretrain', 'finetune', 'load_run'] and vv == 'None':
            vv = None
        elif kks[0] == 'load_run':
            vv = int(vv)
        elif kks[0] in ['pretrain', 'finetune']:
            vv = str(vv)
        else:
            vv = vv_type(vv)
        # Apply the override directly into the loaded config.
        set_in_dict(CONFIG, kks, vv)
    classify_mines(**CONFIG)
    mines_postprocess()
if __name__ == "__main__":
    # Entry point when executed directly (e.g. `python cli.py`).
    cli()