Implement multi-sv detection #2510
Changes from 8 commits
@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Managing the embeddings."""
from dataclasses import dataclass
import logging
import os
import re
from typing import Dict, List, Union

from datasets import load_dataset
@@ -22,11 +24,60 @@
from sentence_transformers.util import semantic_search
import torch

from nl_server import query_util
import nl_server.gcs as gcs
from server.lib.nl import utils

TEMP_DIR = '/tmp/'
MODEL_NAME = 'all-MiniLM-L6-v2'

# A value higher than the highest score.
_HIGHEST_SCORE = 1.0
_INIT_SCORE = (_HIGHEST_SCORE + 0.1)

# Scores below this are ignored.
_SV_SCORE_THRESHOLD = 0.5

# If the difference between successive scores exceeds this threshold, then SVs at
# the lower score and below are ignored.
_MULTI_SV_SCORE_DIFFERENTIAL = 0.05

_NUM_CANDIDATES_PER_NSPLIT = 3


# List of SV candidates, along with scores.
@dataclass
class VarCandidates:
  # The below are sorted and parallel lists.
  svs: List[str]
  scores: List[float]
  sv2sentences: Dict[str, List[str]]


# One part of a single multi-var candidate and its
# associated SVs and scores.
@dataclass
class MultiVarCandidatePart:
  query_part: str
  svs: List[str]
  scores: List[float]


# One multi-var candidate containing multiple parts.
@dataclass
class MultiVarCandidate:
  parts: List[MultiVarCandidatePart]
  # Aggregate score
  aggregate_score: float
  # Is this candidate based on a split computed from delimiters?
  delim_based: bool


# List of multi-var candidates.
@dataclass
class MultiVarCandidates:
  candidates: List[MultiVarCandidate]


class Embeddings:
  """Manages the embeddings."""

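To make the nested dataclasses concrete: a hypothetical MultiVarCandidates instance for the query "hispanic women phd" (the example used later in this file) might look like the sketch below; the SV DCIDs and scores are illustrative, not actual model output.

# Hypothetical result for "hispanic women phd" split into two parts
# (DCIDs and scores are illustrative only).
example = MultiVarCandidates(candidates=[
    MultiVarCandidate(
        parts=[
            MultiVarCandidatePart(query_part='hispanic women',
                                  svs=['Count_Person_HispanicOrLatino_Female'],
                                  scores=[0.78]),
            MultiVarCandidatePart(
                query_part='phd',
                svs=['Count_Person_EducationalAttainmentDoctorateDegree'],
                scores=[0.72]),
        ],
        # Average of the top scores of the two parts: (0.78 + 0.72) / 2
        aggregate_score=0.75,
        # False because this split came from length-based enumeration,
        # not from a delimiter word like "vs".
        delim_based=False),
])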
@@ -80,52 +131,201 @@ def get_embedding_at_index(self, index: int) -> List[float]:
  def get_embedding(self, query: str) -> List[float]:
    return self.model.encode(query).tolist()

  def detect_svs(self, query: str) -> Dict[str, Union[Dict, List]]:
    query_embeddings = self.model.encode([query])
  #
  # Given a list of queries, searches the in-memory embeddings index
  # and returns a map of candidates keyed by input queries.
  #
  def _search_embeddings(self, queries: List[str]) -> Dict[str, VarCandidates]:
Review comment: Question: why do the refactor of this function in the same PR as the introduction of a new function for detecting multi SVs? Is there some functionality that was missing which this refactor addresses?
Reply: The refactor takes multiple queries and returns results, which was needed for the multi-SV case.
    query_embeddings = self.model.encode(queries)
    hits = semantic_search(query_embeddings, self.dataset_embeddings, top_k=20)

    # Note: multiple results may map to the same DCID. As well, the same string may
    # map to multiple DCIDs with the same score.
    sv2score = {}
    # Also track the sv to index so that embeddings can later be retrieved.
    sv2index = {}
    # Also add the full list of SVs and sentences that matched (for debugging).
    all_svs_sentences: Dict[str, List[str]] = {}
    for e in hits[0]:
      for d in self.dcids[e['corpus_id']].split(','):
        s = e['score']
        ind = e['corpus_id']
        sentence = ""
        try:
          sentence = self.sentences[e['corpus_id']] + f" ({s})"
        except Exception as exp:
          logging.info(exp)
        # Prefer the top score.
        if d not in sv2score:
          sv2score[d] = s
          sv2index[d] = ind

        # Add to the debug map anyway.
        existing_sentences = []
        if d in all_svs_sentences:
          existing_sentences = all_svs_sentences[d]

        if sentence not in existing_sentences:
          existing_sentences.append(sentence)
          all_svs_sentences[d] = existing_sentences

    # Sort by scores
    sv2score_sorted = sorted(sv2score.items(),
                             key=lambda item: item[1],
                             reverse=True)
    svs_sorted = [k for (k, _) in sv2score_sorted]
    scores_sorted = [v for (_, v) in sv2score_sorted]

    sv_index_sorted = [sv2index[k] for (k, _) in sv2score_sorted]
    # A map from input query -> SV DCID -> matched sentence -> score for that match
    query2sv2sentence2score: Dict[str, Dict[str, Dict[str, float]]] = {}
Review comment: Have a comment with an example of query2sv2sentence2score?
Reply: Described it.
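For illustration, the query2sv2sentence2score map described in the comment above might take a shape like this; the DCIDs, sentences, and scores here are made up.

# Illustrative shape only: input query -> SV DCID -> matched sentence -> score.
query2sv2sentence2score = {
    'hispanic women': {
        'Count_Person_HispanicOrLatino_Female': {
            'population of hispanic women': 0.83,
            'hispanic female population': 0.80,
        },
    },
    'phd': {
        'Count_Person_EducationalAttainmentDoctorateDegree': {
            'number of people with doctorate degrees': 0.76,
        },
    },
}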
    # A map from input query -> SV DCID -> highest matched score
    query2sv2score: Dict[str, Dict[str, float]] = {}
    for i, hit in enumerate(hits):
      q = queries[i]
      query2sv2score[q] = {}
      query2sv2sentence2score[q] = {}
      for ent in hit:
        score = ent['score']
        for dcid in self.dcids[ent['corpus_id']].split(','):
          # Prefer the top score.
          if dcid not in query2sv2score[q]:
            query2sv2score[q][dcid] = score
            query2sv2sentence2score[q][dcid] = {}

          if ent['corpus_id'] >= len(self.sentences):
            continue
          sentence = self.sentences[ent['corpus_id']]
          query2sv2sentence2score[q][dcid][sentence] = score

    query2result: Dict[str, VarCandidates] = {}

    # Go over the map and prepare parallel lists of
    # SVs and scores in query2result.
    for q, sv2score in query2sv2score.items():
Review comment: Have a comment for this and the next large for loop on what it processes?
Reply: Done.
      sv2score_sorted = [(k, v) for (
          k,
          v) in sorted(sv2score.items(), key=lambda item: item[1], reverse=True)
                        ]
      svs = [k for (k, _) in sv2score_sorted]
      scores = [v for (_, v) in sv2score_sorted]
      query2result[q] = VarCandidates(svs=svs, scores=scores, sv2sentences={})

    # Go over the results and prepare the sv2sentences map in
    # query2result.
    for q, sv2sentence2score in query2sv2sentence2score.items():
      query2result[q].sv2sentences = {}
      for sv, sentence2score in sv2sentence2score.items():
        query2result[q].sv2sentences[sv] = []
        for sentence, score in sorted(sentence2score.items(),
                                      key=lambda item: item[1],
                                      reverse=True):
          score = round(score, 4)
          query2result[q].sv2sentences[sv].append(sentence + f' ({score})')

    return query2result

  #
  # The main entry point to detect SVs.
  #
  def detect_svs(self, orig_query: str) -> Dict[str, Union[Dict, List]]:
    # Remove all stop-words.
    query_monovar = utils.remove_stop_words(orig_query,
                                            query_util.ALL_STOP_WORDS)

    # Search embeddings for a single SV.
    result_monovar = self._search_embeddings([query_monovar])[query_monovar]

    # Try to detect multiple SVs. Use the original query so that
    # the logic can rely on stop-words like `vs`, `and`, etc as hints
    # for SV delimiters.
    result_multivar = self._detect_multiple_svs(orig_query)

    # TODO: Rename SV_to_Sentences for consistency.
    return {
        'SV': svs_sorted,
        'CosineScore': scores_sorted,
        'EmbeddingIndex': sv_index_sorted,
        'SV_to_Sentences': all_svs_sentences,
        'SV': result_monovar.svs,
        'CosineScore': result_monovar.scores,
        'SV_to_Sentences': result_monovar.sv2sentences,
Review comment: Use CamelCase too?
Reply: This is from before, and changing it would touch a few unrelated places, so it can change in a follow-on PR. Added a TODO.
        'MultiSV': _multivar_candidates_to_dict(result_multivar)
    }

  #
  # Detects one or more SVs from the query.
  # TODO: Fix the query upstream to ensure the punctuations aren't stripped.
  #
  def _detect_multiple_svs(self, query: str) -> MultiVarCandidates:
    #
    # Prepare a combination of query-sets.
    #
    querysets = query_util.prepare_multivar_querysets(query)
Review comment: Reading the comments below about a queryset, wouldn't it be better to use the combinatorial API built into Python? https://www.geeksforgeeks.org/itertools-combinations-module-python-print-possible-combinations/
Reply: Good idea. Doing it in the follow-on PR.
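As a minimal sketch of the reviewer's itertools suggestion (not part of this PR): contiguous n-way splits of a query can be enumerated by choosing split points with itertools.combinations. The helper name contiguous_splits below is hypothetical.

import itertools
from typing import List


def contiguous_splits(words: List[str], nsplits: int) -> List[List[str]]:
  """Enumerates all ways to split `words` into `nsplits` contiguous parts."""
  splits = []
  # Choosing (nsplits - 1) cut positions among the gaps between words gives
  # every contiguous split of the requested length.
  for cuts in itertools.combinations(range(1, len(words)), nsplits - 1):
    bounds = [0] + list(cuts) + [len(words)]
    splits.append(
        [' '.join(words[bounds[i]:bounds[i + 1]]) for i in range(nsplits)])
  return splits


# contiguous_splits(['hispanic', 'women', 'phd'], 2) returns:
#   [['hispanic', 'women phd'], ['hispanic women', 'phd']]
# and the 3-way split yields the single combination
#   [['hispanic', 'women', 'phd']].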
    result = MultiVarCandidates(candidates=[])

    # Make a unique list of query strings
    all_queries = set()
    for qs in querysets:
      for c in qs.combinations:
        for p in c.parts:
          all_queries.add(p)
    if not all_queries:
      return result

    query2result = self._search_embeddings(list(all_queries))

    #
    # A queryset is the set of all combinations of query
    # splits of a given length. For example, a query like
    # "hispanic women phd" may have a queryset for a
    # 2-way split, like:
    #   QuerySet(nsplits=2,
    #            combinations=[
    #              ['hispanic women', 'phd'],
    #              ['hispanic', 'women phd'],
    #            ])
    # The 3-way split has only one combination.
    #
    # We take the average score from the top SV from the query-parts in a
    # queryset (ignoring any queryset with a score below threshold). Then
    # sort all candidates by that score.
    #
    # TODO: Come up with a better ranking function.
    #
    for qs in querysets:
Review comment: Some examples of "querysets" would be helpful, since its fields are very deep.
Reply: Done, PTAL.
      candidates: List[MultiVarCandidate] = []
      for c in qs.combinations:
        if not c or not c.parts:
          continue

        total = 0
        candidate = MultiVarCandidate(parts=[],
                                      delim_based=qs.delim_based,
                                      aggregate_score=-1)
        lowest = _INIT_SCORE
        for q in c.parts:
          r = query2result.get(
              q, VarCandidates(svs=[], scores=[], sv2sentences={}))
          part = MultiVarCandidatePart(query_part=q, svs=[], scores=[])
          score = 0
          if r.svs:
            # Pick the top-K SVs.
            limit = _pick_top_k(r)
            if limit > 0:
              part.svs = r.svs[:limit]
              part.scores = [round(s, 4) for s in r.scores[:limit]]
              score = r.scores[0]

          if score < lowest:
            lowest = score
          total += score
          candidate.parts.append(part)

        if lowest < _SV_SCORE_THRESHOLD:
          # A query-part's best SV did not cross our score threshold,
          # so drop this candidate.
          continue

        # The candidate level score is the average.
        candidate.aggregate_score = total / len(c.parts)
        candidates.append(candidate)

      if candidates:
        # Pick the top candidate.
        candidates.sort(key=lambda c: c.aggregate_score, reverse=True)
        # Pick upto some number of candidates. Could be just 1
        # eventually.
        result.candidates.extend(candidates[:_NUM_CANDIDATES_PER_NSPLIT])

    # Sort the results by score.
    result.candidates.sort(key=lambda c: c.aggregate_score, reverse=True)
    return result

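To illustrate the ranking above with hypothetical scores: if a 2-way candidate's parts have best SV scores of 0.78 and 0.72, both clear _SV_SCORE_THRESHOLD (0.5), so the candidate is kept with aggregate_score = (0.78 + 0.72) / 2 = 0.75; if the second part's best score were 0.45 instead, the whole candidate would be dropped. Within a part, scores of [0.78, 0.76, 0.70] would make _pick_top_k return 2, since only 0.76 is within _MULTI_SV_SCORE_DIFFERENTIAL (0.05) of the top score.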
#
# Given a list of variables select only those SVs that do not deviate
# from the best SV by more than a certain threshold.
#
def _pick_top_k(candidates: VarCandidates) -> int:
  k = 1
  first = candidates.scores[0]
  for i in range(1, len(candidates.scores)):
    if first - candidates.scores[i] > _MULTI_SV_SCORE_DIFFERENTIAL:
      break
    k += 1
  return k


def _multivar_candidates_to_dict(candidates: MultiVarCandidates) -> Dict:
  result = {'Candidates': []}
  for c in candidates.candidates:
    c_dict = {
        'Parts': [],
        'AggCosineScore': round(c.aggregate_score, 4),
        'DelimBased': c.delim_based,
    }
    for p in c.parts:
      p_dict = {'QueryPart': p.query_part, 'SV': p.svs, 'CosineScore': p.scores}
      c_dict['Parts'].append(p_dict)
    result['Candidates'].append(c_dict)
  return result
Review comment: Might not be worth a change for now, but why not have a list of {sv, score, sentences} structs?
Reply: It has been that way from before, so not changing it now.
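For reference, the dict serialized by _multivar_candidates_to_dict (returned under the 'MultiSV' key of detect_svs) has roughly the shape below; the DCIDs and scores are illustrative.

# Illustrative output shape; keys mirror _multivar_candidates_to_dict.
{
    'Candidates': [{
        'Parts': [{
            'QueryPart': 'hispanic women',
            'SV': ['Count_Person_HispanicOrLatino_Female'],
            'CosineScore': [0.78],
        }, {
            'QueryPart': 'phd',
            'SV': ['Count_Person_EducationalAttainmentDoctorateDegree'],
            'CosineScore': [0.72],
        }],
        'AggCosineScore': 0.75,
        'DelimBased': False,
    }]
}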