load_base.py
"""
Module for loading the transform output into the dataservice. It converts the
merged source data into complete message payloads according to a given API
specification, and then sends those messages to the target server.
"""
import concurrent.futures
import json
import os
import sqlite3
from collections import defaultdict
from pprint import pformat
from threading import Lock, current_thread, main_thread
from urllib.parse import urlparse
from kf_lib_data_ingest.common.concept_schema import CONCEPT
from kf_lib_data_ingest.common.errors import InvalidIngestStageParameters
from kf_lib_data_ingest.common.misc import multisplit
from kf_lib_data_ingest.common.stage import IngestStage
from kf_lib_data_ingest.common.type_safety import (
assert_all_safe_type,
assert_safe_type,
)
from kf_lib_data_ingest.config import DEFAULT_ID_CACHE_FILENAME
from kf_lib_data_ingest.etl.configuration.base_config import (
ConfigValidationError,
)
from kf_lib_data_ingest.etl.configuration.target_api_config import (
TargetAPIConfig,
)
from pandas import DataFrame
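# Module-level locks: count_lock guards the sent_messages list and per-entity
# counters; cache_lock guards the UID cache (RAM dict + sqlite DB) when
# loading with the async thread pool.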
count_lock = Lock()
cache_lock = Lock()
class LoadStageBase(IngestStage):
def __init__(
self,
target_api_config_path,
target_url,
entities_to_load,
project_id,
cache_dir=None,
use_async=False,
dry_run=False,
resume_from=None,
clear_cache=False,
):
"""
:param target_api_config_path: path to the target service API config
:type target_api_config_path: str
:param target_url: URL for the target service
:type target_url: str
:param entities_to_load: set of which types of entities to load
:type entities_to_load: list
:param project_id: unique ID of the project being loaded
:type project_id: str
:param cache_dir: where to find the ID cache, defaults to None
:type cache_dir: str, optional
:param use_async: use asynchronous networking, defaults to False
:type use_async: bool, optional
:param dry_run: don't actually transmit, defaults to False
:type dry_run: bool, optional
:param resume_from: Dry run until the designated target ID is seen, and
then switch to normal loading. Does not overrule the dry_run flag.
Value may be a full ID or an initial substring (e.g. 'BS', 'BS_')
:type resume_from: str, optional
:param clear_cache: Clear the identifier cache file before loading,
defaults to False. Equivalent to deleting the file manually. Ignored
when using resume_from, because that needs the cache to be effective.
:type clear_cache: bool, optional
"""
super().__init__(cache_dir or os.getcwd())
self.target_api_config = TargetAPIConfig(target_api_config_path)
self._validate_entities(
entities_to_load,
"Your ingest package config says to load invalid entities:",
)
self.entities_to_load = entities_to_load
self.target_url = target_url
self.dry_run = dry_run
self.resume_from = resume_from
self.project_id = project_id
self.use_async = use_async
self._dry_id = 0
self.uid_cache_filepath = os.path.join(
self.stage_cache_dir,
# Every target gets its own cache because they don't share UIDs
self._clean_name(target_url)
# Every project gets its own cache to compartmentalize internal IDs
+ "_" + project_id + "_" + DEFAULT_ID_CACHE_FILENAME,
)
if not os.path.isfile(self.uid_cache_filepath):
self.logger.info(
"Target identifier cache file not found so a new one will be created:"
f" {self.uid_cache_filepath}"
)
elif clear_cache and not self.resume_from:
os.remove(self.uid_cache_filepath)
self.logger.info(
"Not resuming a previous run, so the identifier cache file at "
f"{self.uid_cache_filepath} has been cleared."
)
# Two-stage (RAM + disk) cache
self.uid_cache = defaultdict(dict)
self.uid_cache_db = sqlite3.connect(
self.uid_cache_filepath,
isolation_level=None,
check_same_thread=False,
)
def _clean_name(self, target_url):
target = urlparse(target_url).netloc or urlparse(target_url).path
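        # e.g. "http://localhost:5000" -> "localhost_5000": colons and slashes
        # become underscores so the URL can safely be used in a filename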
return "_".join(multisplit(target, [":", "/"]))
def _validate_entities(self, entities_to_load, msg):
"""
Validate that all entities in entities_to_load are one of the
target concepts specified in the target_api_config.all_targets
"""
target_names = {
t.class_name for t in self.target_api_config.all_targets
}
invalid_ents = set(entities_to_load) - target_names
if invalid_ents:
raise ConfigValidationError(
f"{msg} "
f"{pformat(invalid_ents)}. "
"Valid entities must be one of the target concepts: "
f"{pformat(target_names)} "
f"specified in {self.target_api_config.config_filepath}"
)
def _validate_run_parameters(self, df_dict):
"""
Validate the parameters being passed into the _run method. This method
gets executed before the body of _run is executed.
:param df_dict: a dict of DataFrames, keyed by target concepts defined
in the target_api_config
:type df_dict: dict
"""
try:
assert_safe_type(df_dict, dict)
assert_all_safe_type(df_dict.values(), DataFrame)
except TypeError as e:
raise InvalidIngestStageParameters from e
self._validate_entities(
set(df_dict.keys()) - {"default"},
"Your transform module output has invalid keys:",
)
def _prime_uid_cache(self, entity_type):
"""
Make sure that the backing cache database table exists and that the RAM
store is populated.
:param entity_type: the name of this type of entity
:type entity_type: str
"""
if entity_type not in self.uid_cache:
# Create table in DB first if necessary
self.uid_cache_db.execute(
f'CREATE TABLE IF NOT EXISTS "{entity_type}"'
" (unique_id TEXT PRIMARY KEY, target_id TEXT);"
)
# Populate RAM cache from DB
for unique_id, target_id in self.uid_cache_db.execute(
f'SELECT unique_id, target_id FROM "{entity_type}";'
):
self.uid_cache[entity_type][unique_id] = target_id
def _get_target_id_from_key(self, entity_type, entity_key):
"""
Retrieve the target service ID for a given source unique key.
:param entity_type: the name of this type of entity
:type entity_type: str
:param entity_key: source unique key for this entity
:type entity_key: str
"""
with cache_lock:
self._prime_uid_cache(entity_type)
return self.uid_cache[entity_type].get(entity_key)
def _store_target_id_for_key(
self, entity_type, entity_key, target_id, no_db
):
"""
Cache the relationship between a source unique key and its corresponding
target service ID.
:param entity_type: the name of this type of entity
:type entity_type: str
:param entity_key: source unique key for this entity
:type entity_key: str
:param target_id: target service ID for this entity
:type target_id: str
:param no_db: only store in the RAM cache, not in the db
:type no_db: bool
"""
with cache_lock:
self._prime_uid_cache(entity_type)
if self.uid_cache[entity_type].get(entity_key) != target_id:
self.uid_cache[entity_type][entity_key] = target_id
if not no_db:
self.uid_cache_db.execute(
f'INSERT OR REPLACE INTO "{entity_type}"'
" (unique_id, target_id)"
" VALUES (?,?);",
(entity_key, target_id),
)
def _get_target_id_from_record(self, entity_class, record):
"""
Find the target service ID for the given record and entity class.
:param entity_class: one of the classes contained in the all_targets list
:type entity_class: class
:param record: a record of extracted data
:type record: dict
:return: the target service ID
:rtype: str
"""
raise NotImplementedError()
def _do_target_submit(self, entity_class, body):
"""Shim for target API submission across loader versions"""
raise NotImplementedError()
def _do_target_get_key(self, entity_class, record):
"""Shim for target API key building across loader versions"""
raise NotImplementedError()
def _do_target_get_entity(self, entity_class, record, keystring):
"""Shim for target API entity building across loader versions"""
raise NotImplementedError()
def _read_output(self):
pass # TODO
def _write_output(self, output):
pass # TODO
def _load_entity(self, entity_class, record):
"""
Prepare a single entity for submission to the target service.
"""
try:
key_components = self._do_target_get_key(entity_class, record)
except Exception:
# no new key, no new entity
key_components = None
if not key_components:
self.logger.debug(
f"Skip {entity_class.class_name}. Missing key components. "
f"Failed to construct unique key from record:"
f"\n{pformat(record)}"
)
return
unique_key = str(key_components)
if unique_key in self.seen_entities[entity_class.class_name]:
# no new key, no new entity
self.logger.debug(
f"Skip {entity_class.class_name}. Duplicate record found in "
f"data:\n{record}"
)
return
self.seen_entities[entity_class.class_name].add(unique_key)
target_id = self._get_target_id_from_record(entity_class, record)
method = "UPDATE" if target_id else "CREATE"
try:
body = self._do_target_get_entity(entity_class, record, unique_key)
except Exception:
self.logger.info(
f"Failed to build {entity_class.class_name} from record {record}"
)
raise
if current_thread() is not main_thread():
current_thread().name = f"{entity_class.class_name} {unique_key}"
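        # resume_from: stay in dry-run mode until an entity whose cached
        # target ID starts with the resume_from value is seen, then switch
        # back to real loading for the remaining entities.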
if self.resume_from:
if not target_id:
raise InvalidIngestStageParameters(
"Use of the resume_from flag requires having already"
" cached target IDs for all prior entities. The resume"
" target has not yet been reached, and no cached ID"
f" was found for this entity body:\n{pformat(body)}"
)
elif target_id.startswith(self.resume_from):
self.logger.info(
f"Found resume target '{self.resume_from}'. Resuming"
" normal load."
)
self.dry_run = False
self.resume_from = None
msg = f"{method} {entity_class.class_name} ({unique_key})"
if target_id:
msg = f"{msg} [{target_id}]"
if self.dry_run:
self.logger.debug(f"Request body preview:\n{pformat(body)}")
msg = f"DRY RUN - {msg}"
if not target_id:
self._dry_id += 1
target_id = f"DRY_{entity_class.class_name}_{self._dry_id}"
else:
# send to the target service
target_id = self._do_target_submit(entity_class, body)
msg = f"{msg} --> {target_id}"
# cache source_ID:target_ID lookup
self._store_target_id_for_key(
entity_class.class_name, unique_key, target_id, self.dry_run
)
# log action
with count_lock:
self.sent_messages.append(
{
"type": entity_class.class_name,
"method": method,
"body": body,
}
)
self.counts[entity_class.class_name][method] += 1
self.logger.info(
f"{msg} (#{sum(self.counts[entity_class.class_name].values())})"
)
def _postrun_validation(self, validation_mode=None, report_kwargs={}):
# Override implemented base class method because we don't need to
# do any validation on this stage's output
pass
def _run(self, transform_output):
"""
Load Stage internal entry point. Called by IngestStage.run
:param transform_output: Output data structure from the Transform stage
:type transform_output: dict
"""
self.counts = defaultdict(dict)
self.seen_entities = defaultdict(set)
if self.dry_run:
self.logger.info(
"DRY RUN mode is ON. No entities will be loaded into the "
"target service."
)
self.resume_from = None
elif self.resume_from:
self.logger.info(
f"Will dry run until '{self.resume_from}' and then resume"
" loading from that entity."
)
self.dry_run = True
# Loop through all target concepts
self.sent_messages = []
try:
for entity_class in self.target_api_config.all_targets:
if entity_class.class_name not in self.entities_to_load:
self.logger.info(
f"Skipping load of {entity_class.class_name}. Not "
"included in ingest package config."
)
continue
self.logger.info(f"Begin loading {entity_class.class_name}")
if entity_class.class_name in transform_output:
t_key = entity_class.class_name
else:
t_key = "default"
# convert df to list of dicts
transformed_records = transform_output[t_key].to_dict("records")
if hasattr(entity_class, "transform_records_list"):
transformed_records = entity_class.transform_records_list(
transformed_records
)
# guarantee existence of the project unique key column
for r in transformed_records:
r[CONCEPT.PROJECT.ID] = self.project_id
self.counts[entity_class.class_name]["CREATE"] = 0
self.counts[entity_class.class_name]["UPDATE"] = 0
if self.use_async:
ex = concurrent.futures.ThreadPoolExecutor()
futures = []
self.logger.info(
f"Reading {len(transformed_records)} rows in '{t_key}' table."
)
for record in transformed_records:
if self.use_async and not self.resume_from:
futures.append(
ex.submit(self._load_entity, entity_class, record)
)
else:
self._load_entity(entity_class, record)
if self.use_async:
for f in concurrent.futures.as_completed(futures):
f.result()
ex.shutdown()
self.logger.info(f"End loading {entity_class.class_name}")
finally:
target = self._clean_name(self.target_url)
json_out = os.path.join(
self.stage_cache_dir, f"SentMessages_{target}.json"
)
with open(json_out, "w") as jo:
json.dump(self.sent_messages, jo, indent=2)
if self.resume_from:
self.logger.warning(
f"⚠️ Could not find resume_from target '{self.resume_from}'! "
"Nothing was actually loaded into the target service."
)
self.logger.info(f"Load Summary:\n{pformat(dict(self.counts))}")