-
Notifications
You must be signed in to change notification settings - Fork 0
/
merge.py
165 lines (143 loc) · 9.35 KB
/
merge.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import json
import os
from collections import defaultdict
from typing import Tuple
# Absolute path of the directory containing this script; used to locate example_result_sets/
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
def merge_example_queries():
    """Merge the ARA responses for every example query under example_result_sets/.

    For each (non-hidden) query directory, loads the original query graph and all
    ARA response JSON files, merges them, and writes out 'merged_response.json'
    plus a small 'report.json' with before/after counts.
    """
    result_sets_dir = f"{CURRENT_DIR}/example_result_sets"
    query_dir_names = [name for name in os.listdir(result_sets_dir)
                       if not name.startswith(".")]
    for query_dir_name in query_dir_names:
        print(f"\nMerging responses for query '{query_dir_name}'...")
        query_dir_path = f"{result_sets_dir}/{query_dir_name}"
        # Load the original query graph that produced these responses
        with open(f"{query_dir_path}/qg.json") as query_graph_file:
            query_graph = json.load(query_graph_file)
        # Load each ARA's TRAPI response, keyed by ARA name (file name minus extension)
        ara_responses_dir = f"{query_dir_path}/ara_responses"
        responses = dict()
        for file_name in os.listdir(ara_responses_dir):
            if not file_name.endswith(".json"):
                continue
            ara_name = file_name.replace(".json", "")
            with open(f"{ara_responses_dir}/{ara_name}.json") as response_file:
                responses[ara_name] = json.load(response_file)
        merged_response, stats_report = merge(responses, query_graph)
        # Save the merged TRAPI response
        with open(f"{query_dir_path}/merged_response.json", "w+") as merged_response_file:
            json.dump(merged_response, merged_response_file, indent=2)
        # Save a little report of result/KG counts
        with open(f"{query_dir_path}/report.json", "w+") as report_file:
            json.dump(stats_report, report_file, indent=2)
def _percent(count: int, pre_merge_count: int) -> int:
    """Return count as a rounded percentage of pre_merge_count (0 when denominator is 0)."""
    return round((count / pre_merge_count) * 100) if pre_merge_count else 0


def _get_result_hash_key(result: dict, required_qnode_keys_sorted: list):
    """Build a result's hash key from the curies bound to its required qnodes.

    Returns the '--'-joined curies (in sorted-qnode-key order), or None if some
    required qnode has no node bindings (such a result can't be keyed).
    """
    merge_curies = []
    for qnode_key in required_qnode_keys_sorted:
        nodes_fulfilling_this_qnode = {binding["id"] for binding in result["node_bindings"][qnode_key]}
        if not nodes_fulfilling_this_qnode:
            # Guard: an empty binding list would otherwise crash the indexing below
            print(f" WARNING: Found a result with no node bound to required qnode {qnode_key}! Skipping...")
            return None
        # There should only be one node fulfilling this qnode since is_set=False
        if len(nodes_fulfilling_this_qnode) > 1:
            print(f" WARNING: Result has more than one node fulfilling {qnode_key}, "
                  f"which has is_set=False: {nodes_fulfilling_this_qnode}")
            # Note: With TRAPI 1.3, multiple nodes WILL be able to fulfill an is_set=False node within
            # a single result IF they all have the same parent ID mapping ('query_id'); in that case
            # the merge curie is the 'query_id' (not implemented since TRAPI 1.3 is still in dev)
        # min() (rather than list(...)[0]) makes the key deterministic across runs
        merge_curies.append(min(nodes_fulfilling_this_qnode))
    return "--".join(merge_curies)


def merge(responses: dict, query_graph: dict) -> Tuple[dict, dict]:
    """
    Merges the 'results' and 'knowledge_graph' of the input TRAPI responses.
    :param responses: This should be a dictionary where keys are ARA's names and values are the TRAPI responses from
    those ARAs that you want to merge. (Also works with KPs.) (Note: technically you could use whatever name you
    want for each response; doesn't have to be an ARA or KP's name, though that often makes sense.)
    :param query_graph: This should be the original query graph submitted that produced the responses being merged.
    :return: A merged TRAPI response and a small report with result/node/edge counts before and after merging.
    """
    # Figure out which qnodes are 'required' (i.e., should be used for constructing the result hash key)
    required_qnode_keys = {qnode_key for qnode_key, qnode in query_graph["nodes"].items() if not qnode.get("is_set")}
    required_qnode_keys_sorted = sorted(required_qnode_keys)
    # For each response, organize the results by hash keys and merge the KG into an overarching KG
    results_by_hash_key = defaultdict(list)
    pre_merging_counts = {"results": 0, "nodes": 0, "edges": 0}
    merged_kg = {"nodes": dict(), "edges": dict()}
    for ara_name, response in responses.items():
        print(f"Starting to process {ara_name} result set...")
        results = response["message"]["results"]
        kg = response["message"]["knowledge_graph"]
        print(f" {ara_name} response contains {len(results)} results, "
              f"{len(kg['nodes'])} KG nodes, {len(kg['edges'])} KG edges")
        pre_merging_counts["results"] += len(results)
        pre_merging_counts["nodes"] += len(kg['nodes'])
        pre_merging_counts["edges"] += len(kg['edges'])
        # Organize results by their hash keys
        for result in results:
            qnode_keys_fulfilled_in_result = set(result["node_bindings"])
            if not required_qnode_keys.issubset(qnode_keys_fulfilled_in_result):
                print(f" WARNING: Found a result that doesn't fulfill all required qnode keys! Skipping...")
            else:
                result_hash_key = _get_result_hash_key(result, required_qnode_keys_sorted)
                if result_hash_key is not None:
                    results_by_hash_key[result_hash_key].append(result)
        # Merge this response's answer KG into the overarching KG (first occurrence of a key wins)
        for node_key, node in kg["nodes"].items():
            # Note: Node attributes should probably be merged; not doing that here for simplicity's sake
            if node_key not in merged_kg["nodes"]:
                merged_kg["nodes"][node_key] = node
        for edge_key, edge in kg["edges"].items():
            # Note: We don't merge any edges here for simplicity's sake; other edge merging approaches could be taken
            # Also: Technically different ARAs/KPs could use the same edge keys to refer to different edges; watch out
            if edge_key not in merged_kg["edges"]:
                merged_kg["edges"][edge_key] = edge
    # Then go through and merge all results with equivalent hash keys; we want the UNION of nodes
    merged_results = []
    print(f"Merging result sets... before merging there are {pre_merging_counts['results']} "
          f"results, {pre_merging_counts['nodes']} nodes, and {pre_merging_counts['edges']} edges.")
    for hash_key, results in results_by_hash_key.items():
        merged_result = {"node_bindings": defaultdict(list), "edge_bindings": defaultdict(list)}
        for result in results:
            # Note: Node bindings can have 'attributes', which perhaps should be merged here? Ignoring for now..
            for qnode_key, node_bindings in result["node_bindings"].items():
                merged_result["node_bindings"][qnode_key].extend(node_bindings)
            # We choose to retain ALL edges (could sub in different edge merging strategy here)
            for qedge_key, edge_bindings in result["edge_bindings"].items():
                merged_result["edge_bindings"][qedge_key].extend(edge_bindings)
        # Note: Optionally might have some way of also merging result score?
        merged_results.append(merged_result)
    # Guard against division by zero when the input responses are empty
    percent_results = _percent(len(merged_results), pre_merging_counts["results"])
    percent_nodes = _percent(len(merged_kg["nodes"]), pre_merging_counts["nodes"])
    percent_edges = _percent(len(merged_kg["edges"]), pre_merging_counts["edges"])
    # Get rid of duplicate node bindings (keep the first binding seen for each curie)
    for merged_result in merged_results:
        for qnode_key, node_bindings in merged_result["node_bindings"].items():
            deduplicated_node_bindings = []
            bound_curies = set()
            for node_binding in node_bindings:
                node_id = node_binding["id"]
                if node_id not in bound_curies:
                    deduplicated_node_bindings.append(node_binding)
                    bound_curies.add(node_id)
            # BUG FIX: the deduplicated list was previously built but never saved back
            merged_result["node_bindings"][qnode_key] = deduplicated_node_bindings
    # Figure out the 'essence' of each result (helpful for the ARAX UI)
    unpinned_required_qnodes = {qnode_key for qnode_key in required_qnode_keys
                                if not query_graph["nodes"][qnode_key].get("ids")}
    if unpinned_required_qnodes:  # guard: skip essence entirely when every required qnode is pinned
        if len(unpinned_required_qnodes) > 1:
            print(f"Hmm, more than one potential essence node. Will randomly choose out of the "
                  f"{len(unpinned_required_qnodes)} candidates.")
        # sorted() makes the choice deterministic across runs (set order is not)
        essence_qnode_key = sorted(unpinned_required_qnodes)[0]
        print(f"Essence qnode is {essence_qnode_key}")
        for result in merged_results:
            essence_node_key = result["node_bindings"][essence_qnode_key][0]["id"]
            essence_node_name = merged_kg["nodes"][essence_node_key].get("name", essence_node_key)
            result["essence"] = essence_node_name
    print(f"Done merging responses! There are {len(merged_results)} results after merging "
          f"({percent_results}%). Merged KG contains {len(merged_kg['nodes'])} nodes ({percent_nodes}%) and "
          f"{len(merged_kg['edges'])} edges ({percent_edges}%).")
    merged_response = {"message": {"results": merged_results,
                                   "query_graph": query_graph,
                                   "knowledge_graph": merged_kg}}
    stats_report = {"pre_merging": {"results": pre_merging_counts["results"],
                                    "nodes": pre_merging_counts["nodes"],
                                    "edges": pre_merging_counts["edges"]},
                    "post_merging": {"results": f"{len(merged_results)} ({percent_results}%)",
                                     "nodes": f"{len(merged_kg['nodes'])} ({percent_nodes}%)",
                                     "edges": f"{len(merged_kg['edges'])} ({percent_edges}%)"}}
    return merged_response, stats_report
def main():
    """Script entry point: merge all example query result sets."""
    merge_example_queries()
# Run the example merges only when executed as a script (not on import)
if __name__ == "__main__":
    main()