#!/usr/bin/env python
import argparse
from xml.etree import ElementTree as ET
import config
from transformers import AutoTokenizer, AutoModel
import numpy as np
import torch
from pathlib import Path
from elasticsearch import Elasticsearch
import sys
########################################################################################
#
# This code connects to an ElasticSearch instance and creates the index specified by
# the indexName variable. It then reads the file specified by the termFile variable,
# which contains the set of terms used by the NeuroQuery implementation. Each term is
# passed to the SapBERT sentence embedding model, and the resulting vector is stored
# as part of the corresponding document in the index. See
#
# https://blog.accubits.com/vector-similarity-search-using-elasticsearch/
#
# for some insight into this process.
########################################################################################
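# The config module imported above is expected to supply the connection and input
# settings used below. This is only a sketch of what such a module might contain;
# the values are illustrative placeholders, not the project's actual settings:
#
#   # config.py
#   ELASTIC_IP = "localhost"
#   ELASTIC_PORT = 9200
#   NEUROBRIDGE_ELASTIC_INDEX = "neurobridge_terms"
#   NEUROQUERY_TERM_FILE = "neuroquery_terms.txt"
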
def main(args):
    # Extract params from config
    indexName = config.NEUROBRIDGE_ELASTIC_INDEX
    termFile = config.NEUROQUERY_TERM_FILE
    shards = "1"

    # Connect to elasticSearch
    esConn = connectElastic(config.ELASTIC_IP, config.ELASTIC_PORT)

    # Create the index. Note it's currently OK if the index is already there. We should
    # probably add a command line option to delete the index if it's already there.
    createIndex(indexName, shards, esConn)
    insertDataIntoIndex(termFile, indexName, shards, esConn)

def insertDataIntoIndex(termFile, indexName, shards, esConn):
    # First question: has the data already been loaded?
    res = esConn.indices.refresh(indexName)
    res = esConn.cat.count(indexName, params={"format": "json"})
    nData = res[0]['count']
    if int(nData) > 0:
        print(f"{nData} data items already loaded")
        sys.exit(0)

    # Load the SapBERT tokenizer and model used to embed each term
    tokenizer = AutoTokenizer.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext")
    model = AutoModel.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext")

    rowId = 1
    terms = open(termFile, 'r')
    for line in terms:
        cleanLine = line.strip()
        toks = tokenizer.batch_encode_plus([cleanLine],
                                           padding="max_length",
                                           max_length=25,
                                           truncation=True,
                                           return_tensors="pt")
        output = model(**toks)

        # Use the [CLS] token embedding as the representation of the term
        cls_rep = output[0][:, 0, :]
        print(type(cls_rep))
        embeddingArray = cls_rep.detach().numpy()
        print(type(embeddingArray))
        print(embeddingArray)

        # Convert the numpy vector to a plain list so it can be JSON serialized
        insertBody = {'term_name': cleanLine,
                      'term': cleanLine,
                      'term_vec': embeddingArray[0].tolist(),
                      'row_id': rowId}
        rowId += 1
        esConn.index(index=indexName, body=insertBody)
    print(f"number of rows inserted is {rowId - 1}")
    terms.close()

def connectElastic(ip, port):
    # Connect to an elasticsearch node with the given ip and port
    print(f"port is {port}")
    print(f"host is {ip}")
    esConn = Elasticsearch([{"host": ip, "port": port}])
    try:
        if esConn.ping():
            print("Connected to elasticsearch...")
        else:
            print("Elasticsearch connection error..")
            print(esConn)
            sys.exit(1)
    except Exception as ex:
        print(f"error caught: {ex}")
    return esConn

def createIndex(indexName, shards, esConn):
    # Define the index mapping
    indexBodyString = """{
        "mappings": {
            "properties": {
                "term_name": {
                    "type": "text"
                },
                "term": {
                    "type": "text"
                },
                "term_vec": {
                    "type": "dense_vector",
                    "dims": 768
                },
                "row_id": {
                    "type": "long"
                }
            }
        },
        "settings": {
            "number_of_shards": nShards,
            "number_of_replicas": 0
        }
    }
    """
    # Substitute the requested shard count into the mapping template
    indexBody = indexBodyString.replace("nShards", shards)
    print(f"indexBody: {indexBody}")

    try:
        # Create the index if it does not exist
        if not esConn.indices.exists(indexName):
            # Ignore 400 means to ignore "Index Already Exist" error.
            esConn.indices.create(
                index=indexName, body=indexBody  # ignore=[400, 404]
            )
            print("Created Index: ", indexName)
        else:
            print(f"Index {indexName} already exists")
    except Exception as ex:
        print(str(ex))

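# Once the terms are indexed, the term_vec dense_vector field can be searched by
# cosine similarity. A minimal sketch, assuming an Elasticsearch 7.x cluster that
# supports script_score queries, and a queryVector built from the query text the
# same way the term vectors are built above (searchBody and queryVector are
# illustrative names, not part of this script):
#
#   searchBody = {
#       "query": {
#           "script_score": {
#               "query": {"match_all": {}},
#               "script": {
#                   "source": "cosineSimilarity(params.query_vec, 'term_vec') + 1.0",
#                   "params": {"query_vec": queryVector}
#               }
#           }
#       }
#   }
#   hits = esConn.search(index=indexName, body=searchBody)["hits"]["hits"]
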
if __name__ == '__main__':
    args = "none"
    main(args)