-
Notifications
You must be signed in to change notification settings - Fork 69
/
elasticsearch.py
266 lines (218 loc) · 8.93 KB
/
elasticsearch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
# -*- coding: utf-8 -*-
#
# This file is part of Invenio.
# Copyright (C) 2020 CERN.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.
"""Elasticsearch source dumper.
Dumper used to dump/load the body of an Elasticsearch document.
"""
from copy import deepcopy
from datetime import date, datetime
from uuid import UUID
import arrow
import pytz
from sqlalchemy.sql.sqltypes import JSON, Boolean, DateTime, Integer, String, \
Text
from sqlalchemy.sql.type_api import Variant
from sqlalchemy_utils.types.uuid import UUIDType
from ..systemfields.model import ModelField
from .base import Dumper
class ElasticsearchDumperExt:
    """Base interface implemented by Elasticsearch dumper extensions.

    Both hooks are no-ops here; an extension overrides the ones it needs.
    """

    def dump(self, record, data):
        """Hook called with the record and the data being dumped.

        The default implementation does nothing.
        """
        return None

    def load(self, data, record_cls):
        """Hook called with the dumped data and the record class.

        An implementation is expected to reverse whatever its ``dump``
        method changed. The default implementation does nothing.
        """
        return None
class ElasticsearchDumper(Dumper):
    """Elasticsearch source dumper.

    On top of whatever the base ``Dumper.dump`` returns, the dumper adds a
    configurable set of SQLAlchemy model fields to the dump, and removes
    them again (rebuilding the model) when loading. Extensions can hook
    into both directions.
    """

    def __init__(self, extensions=None, model_fields=None):
        """Initialize the dumper.

        :param extensions: Optional list of extension objects. Each one is
            called as ``ext.dump(record, data)`` at the end of a dump and as
            ``ext.load(data, record_cls)`` at the start of a load.
        :param model_fields: Optional dict mapping a model field name to a
            ``(dump_key, dump_type)`` tuple. It is merged over the default
            mapping, so it can both add fields and override the defaults.
        """
        self._extensions = extensions or []
        self._model_fields = {
            'id': ('uuid', UUID),
            'version_id': ('version_id', int),
            'created': ('created', datetime),
            'updated': ('updated', datetime),
            # is_deleted is purposely not added (deleted record isnt indexed)
        }
        self._model_fields.update(model_fields or {})

    @staticmethod
    def _sa_type(model_cls, model_field_name):
        """Introspection of SQLAlchemy column data type.

        :param model_cls: The SQLALchemy model.
        :param model_field_name: The name of the field on the SQLAlchemy
            model.
        :returns: The Python type matching the column (``datetime``,
            ``bool``, ``int``, ``UUID``, ``str`` or ``dict``), or ``None``
            if the column is missing or its type is not recognized.
        """
        try:
            sa_type = \
                model_cls.__table__.columns[model_field_name].type
            sa_type_class = sa_type.__class__
            # Deal with variant class: inspect the wrapped implementation
            # type instead of the variant itself.
            if issubclass(sa_type_class, Variant):
                sa_type = sa_type.impl
                sa_type_class = sa_type.__class__

            if issubclass(sa_type_class, DateTime):
                return datetime
            elif issubclass(sa_type_class, Boolean):
                return bool
            elif issubclass(sa_type_class, Integer):
                return int
            elif issubclass(sa_type_class, UUIDType):
                return UUID
            elif issubclass(sa_type_class, String):
                return str
            elif issubclass(sa_type_class, JSON):
                return dict
            return None
        except (KeyError, AttributeError):
            # Unknown column name, or the model has no table definition.
            return None

    @staticmethod
    def _serialize(value, dump_type):
        """Serialize a value according to its data type.

        :param value: Value to serialize.
        :param dump_type: Data type use for serialization (supported: str,
            int, bool, float, datetime, date, uuid). ``None`` leaves the
            value untouched.
        :returns: The serialized value; ``None`` is passed through as is.
        """
        if value is None:
            return value
        if dump_type is datetime:
            # The value is localized as UTC (pytz requires a naive datetime
            # here) and rendered as an ISO 8601 string.
            return pytz.utc.localize(value).isoformat()
        elif dump_type is date:
            # ``date(value)`` is not a valid conversion, so dates need their
            # own branch instead of falling through to the generic one.
            return value.isoformat()
        elif dump_type is UUID:
            return str(value)
        elif dump_type is not None:
            return dump_type(value)
        return value

    @staticmethod
    def _deserialize(value, dump_type):
        """Deserialize a value according to its data type.

        :param value: Value to deserialize.
        :param dump_type: Data type use for deserialization (supported: str,
            int, bool, float, datetime, date, uuid). ``None`` leaves the
            value untouched.
        :returns: The deserialized value; ``None`` is passed through as is.
        """
        if value is None:
            return value
        if dump_type is datetime:
            # Reverse of ``_serialize``: parse and drop the timezone so the
            # result is a naive datetime again.
            return arrow.get(value).datetime.replace(tzinfo=None)
        elif dump_type is date:
            # Reverse of the ``date`` branch in ``_serialize``.
            return arrow.get(value).date()
        elif dump_type is not None:
            # Covers UUID as well: ``UUID(str)`` rebuilds the identifier.
            return dump_type(value)
        return value

    def _dump_model_field(self, record, model_field_name, dump, dump_key,
                          dump_type):
        """Helper method to dump model fields.

        :param record: The record being dumped.
        :param model_field_name: The name of the SQLAlchemy model field on
            the record's model.
        :param dump: The dictionary of the current dump (modified in place).
        :param dump_key: The key to use in the dump.
        :param dump_type: The data type used for serialization. If ``None``
            the type is introspected from the model column.
        """
        # If model is not defined, we dump None into the field value.
        if record.model is None:
            dump[dump_key] = None
            return

        # Retrieve value of the field on the model.
        val = getattr(record.model, model_field_name)

        # Determine data type if not set.
        if dump_type is None:
            dump_type = self._sa_type(record.model_cls, model_field_name)

        # Serialize (according to data type) and set value in output on the
        # specified key.
        dump[dump_key] = self._serialize(val, dump_type)

    def _load_model_field(self, record_cls, model_field_name, dump, dump_key,
                          load_type):
        """Helper method to load model fields from dump.

        :param record_cls: The record class being used for loading.
        :param model_field_name: The name of the SQLAlchemy model field on
            the record's model.
        :param dump: The dictionary of the dump. The key is removed from it.
        :param dump_key: The key to use in the dump.
        :param load_type: The data type used for deserialization. If
            ``None`` the type is introspected from the model column.
        :returns: The deserialized value (``None`` if the dumped value was
            ``None``).
        :raises KeyError: If ``dump_key`` is not present in the dump.
        """
        # Retrieve the value (and remove it from the dump).
        val = dump.pop(dump_key)

        # Return None values immediately.
        if val is None:
            return val

        # Determine dump data type if not provided
        if load_type is None:
            load_type = self._sa_type(record_cls.model_cls, model_field_name)

        # Deserialize the value
        return self._deserialize(val, load_type)

    @staticmethod
    def _iter_modelfields(record_cls):
        """Internal helper method to extract all model fields.

        Yields every attribute of ``record_cls`` that is a ``ModelField``
        with a truthy ``dump`` attribute.
        """
        for attr_name in dir(record_cls):
            systemfield = getattr(record_cls, attr_name)
            if isinstance(systemfield, ModelField) and systemfield.dump:
                yield systemfield

    def dump(self, record, data):
        """Dump a record.

        The method adds the following keys (if the record has an associated
        model):

        - ``uuid`` - UUID of the record.
        - ``version_id`` - the revision id of the record.
        - ``created`` - Creation timestamp in UTC.
        - ``updated`` - Modification timestamp in UTC.

        If the record has no model, the keys are still added but set to
        ``None``.

        :param record: The record being dumped.
        :param data: Passed through to the base class ``dump``.
        :returns: The dictionary of the dump.
        """
        # Copy data first, otherwise we modify the record.
        dump_data = super().dump(record, data)

        # Dump model fields explicitly requested
        it = self._model_fields.items()
        for model_field_name, (dump_key, dump_type) in it:
            self._dump_model_field(
                record,
                model_field_name,
                dump_data,
                dump_key,
                dump_type,
            )

        # Dump model fields defined as system fields.
        for systemfield in self._iter_modelfields(record.__class__):
            self._dump_model_field(
                record,
                systemfield.model_field_name,
                dump_data,
                systemfield.dump_key,
                systemfield.dump_type,
            )

        # Allow extensions to integrate as well.
        for e in self._extensions:
            e.dump(record, dump_data)

        return dump_data

    def load(self, dump_data, record_cls):
        """Load a record from an Elasticsearch document source.

        The method reverses the changes made during the dump. If a model was
        associated, a model will also be initialized.

        Note that ``dump_data`` is modified in place: the model field keys
        are removed from it before it is handed to the record.

        .. warning::

            The model is not added to the SQLAlchemy session. If you plan on
            using the model, you must merge it into the session using e.g.:

            .. code-block:: python

                db.session.merge(record.model)

        :param dump_data: The dictionary of the dump.
        :param record_cls: The record class to instantiate.
        :returns: An instance of ``record_cls``.
        :raises KeyError: If an expected model field key is missing from
            ``dump_data``.
        """
        # First allow extensions to modify the data.
        for e in self._extensions:
            e.load(dump_data, record_cls)

        # Load explicitly defined model fields.
        model_data = {}
        it = self._model_fields.items()
        for model_field_name, (dump_key, load_type) in it:
            model_data[model_field_name] = self._load_model_field(
                record_cls, model_field_name, dump_data, dump_key, load_type)

        # Load model fields defined as system fields
        for systemfield in self._iter_modelfields(record_cls):
            model_data[systemfield.model_field_name] = self._load_model_field(
                record_cls, systemfield.model_field_name, dump_data,
                systemfield.dump_key, systemfield.load_type)

        # Initialize model if an id was provided.
        if model_data.get('id') is not None:
            model_data['data'] = dump_data
            model = record_cls.model_cls(**model_data)
        else:
            model = None

        return record_cls(dump_data, model=model)