
Implement multi-sv detection #2510

Merged
merged 9 commits on Mar 30, 2023
288 changes: 244 additions & 44 deletions nl_server/embeddings.py
@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Managing the embeddings."""
from dataclasses import dataclass
import logging
import os
import re
from typing import Dict, List, Union

from datasets import load_dataset
@@ -22,11 +24,60 @@
from sentence_transformers.util import semantic_search
import torch

from nl_server import query_util
import nl_server.gcs as gcs
from server.lib.nl import utils

TEMP_DIR = '/tmp/'
MODEL_NAME = 'all-MiniLM-L6-v2'

_HIGHEST_SCORE = 1.0
# A sentinel value higher than the highest possible score.
_INIT_SCORE = (_HIGHEST_SCORE + 0.1)

# Scores below this are ignored.
_SV_SCORE_THRESHOLD = 0.5

# If the difference between successive scores exceeds this threshold, then SVs at
# the lower score and below are ignored.
_MULTI_SV_SCORE_DIFFERENTIAL = 0.05

_NUM_CANDIDATES_PER_NSPLIT = 3


# List of SV candidates, along with scores.
@dataclass
class VarCandidates:
  # svs and scores are parallel lists, sorted by decreasing score.
  svs: List[str]
  scores: List[float]
  sv2sentences: Dict[str, List[str]]
Comment on lines +52 to +54
Contributor: Might not be worth a change for now, but why not have a list of {sv, score, sentences} structs?
Contributor Author: It has been that way from before, so not changing it now.

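For reference, the alternative the reviewer suggests would look roughly like this (a hypothetical sketch, not what this PR adopts):

@dataclass
class VarCandidate:
  # One SV along with its best score and the sentences it matched.
  sv: str
  score: float
  sentences: List[str]

VarCandidates would then hold a single List[VarCandidate] instead of the parallel svs/scores lists and the sv2sentences map.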


# One part of a single multi-var candidate and its
# associated SVs and scores.
@dataclass
class MultiVarCandidatePart:
  query_part: str
  svs: List[str]
  scores: List[float]


# One multi-var candidate containing multiple parts.
@dataclass
class MultiVarCandidate:
  parts: List[MultiVarCandidatePart]
  # Aggregate score
  aggregate_score: float
  # Is this candidate based on a split computed from delimiters?
  delim_based: bool


# List of multi-var candidates.
@dataclass
class MultiVarCandidates:
  candidates: List[MultiVarCandidate]


class Embeddings:
"""Manages the embeddings."""
@@ -80,52 +131,201 @@ def get_embedding_at_index(self, index: int) -> List[float]:
def get_embedding(self, query: str) -> List[float]:
return self.model.encode(query).tolist()

def detect_svs(self, query: str) -> Dict[str, Union[Dict, List]]:
query_embeddings = self.model.encode([query])
#
# Given a list of queries, searches the in-memory embeddings index
# and returns a map of candidates keyed by input queries.
#
def _search_embeddings(self, queries: List[str]) -> Dict[str, VarCandidates]:
Contributor: Question: why do the refactor of this function in the same PR as the introduction of a new function for detecting multi svs? Is there some functionality that was missing which this refactor addresses?
Contributor Author: The refactor takes multiple queries and returns results, which was needed for the multi SV case.

query_embeddings = self.model.encode(queries)
hits = semantic_search(query_embeddings, self.dataset_embeddings, top_k=20)

# Note: multiple results may map to the same DCID. As well, the same string may
# map to multiple DCIDs with the same score.
sv2score = {}
# Also track the sv to index so that embeddings can later be retrieved.
sv2index = {}
# Also add the full list of SVs and sentences that matched (for debugging).
all_svs_sentences: Dict[str, List[str]] = {}
for e in hits[0]:
for d in self.dcids[e['corpus_id']].split(','):
s = e['score']
ind = e['corpus_id']
sentence = ""
try:
sentence = self.sentences[e['corpus_id']] + f" ({s})"
except Exception as exp:
logging.info(exp)
# Prefer the top score.
if d not in sv2score:
sv2score[d] = s
sv2index[d] = ind

# Add to the debug map anyway.
existing_sentences = []
if d in all_svs_sentences:
existing_sentences = all_svs_sentences[d]

if sentence not in existing_sentences:
existing_sentences.append(sentence)
all_svs_sentences[d] = existing_sentences

# Sort by scores
sv2score_sorted = sorted(sv2score.items(),
key=lambda item: item[1],
reverse=True)
svs_sorted = [k for (k, _) in sv2score_sorted]
scores_sorted = [v for (_, v) in sv2score_sorted]

sv_index_sorted = [sv2index[k] for (k, _) in sv2score_sorted]
# A map from input query -> SV DCID -> matched sentence -> score for that match
query2sv2sentence2score: Dict[str, Dict[str, Dict[str, float]]] = {}
Contributor: Have a comment with an example of query2sv2sentence2score?
Contributor Author: Described it.

# A map from input query -> SV DCID -> highest matched score
query2sv2score: Dict[str, Dict[str, float]] = {}
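# Illustrative shapes of these maps (hypothetical SV DCIDs, sentences and scores):
#   query2sv2score = {'female phd': {'sv/example_1': 0.78}}
#   query2sv2sentence2score = {
#       'female phd': {'sv/example_1': {'women with doctorate degrees': 0.78}}}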
for i, hit in enumerate(hits):
q = queries[i]
query2sv2score[q] = {}
query2sv2sentence2score[q] = {}
for ent in hit:
score = ent['score']
for dcid in self.dcids[ent['corpus_id']].split(','):
# Prefer the top score.
if dcid not in query2sv2score[q]:
query2sv2score[q][dcid] = score
query2sv2sentence2score[q][dcid] = {}

if ent['corpus_id'] >= len(self.sentences):
continue
sentence = self.sentences[ent['corpus_id']]
query2sv2sentence2score[q][dcid][sentence] = score

query2result: Dict[str, VarCandidates] = {}

# Go over the map and prepare parallel lists of
# SVs and scores in query2result.
for q, sv2score in query2sv2score.items():
Contributor: Have a comment for this and the next large for loop on what it processes?
Contributor Author: Done

sv2score_sorted = [(k, v) for (
k,
v) in sorted(sv2score.items(), key=lambda item: item[1], reverse=True)
]
svs = [k for (k, _) in sv2score_sorted]
scores = [v for (_, v) in sv2score_sorted]
query2result[q] = VarCandidates(svs=svs, scores=scores, sv2sentences={})

# Go over the results and prepare the sv2sentences map in
# query2result.
for q, sv2sentence2score in query2sv2sentence2score.items():
query2result[q].sv2sentences = {}
for sv, sentence2score in sv2sentence2score.items():
query2result[q].sv2sentences[sv] = []
for sentence, score in sorted(sentence2score.items(),
key=lambda item: item[1],
reverse=True):
score = round(score, 4)
query2result[q].sv2sentences[sv].append(sentence + f' ({score})')

return query2result

#
# The main entry point to detect SVs.
#
def detect_svs(self, orig_query: str) -> Dict[str, Union[Dict, List]]:
# Remove all stop-words.
query_monovar = utils.remove_stop_words(orig_query,
query_util.ALL_STOP_WORDS)

# Search embeddings for a single SV.
result_monovar = self._search_embeddings([query_monovar])[query_monovar]

# Try to detect multiple SVs. Use the original query so that
# the logic can rely on stop-words like `vs`, `and`, etc as hints
# for SV delimiters.
result_multivar = self._detect_multiple_svs(orig_query)

# TODO: Rename SV_to_Sentences for consistency.
return {
'SV': svs_sorted,
'CosineScore': scores_sorted,
'EmbeddingIndex': sv_index_sorted,
'SV_to_Sentences': all_svs_sentences,
'SV': result_monovar.svs,
'CosineScore': result_monovar.scores,
'SV_to_Sentences': result_monovar.sv2sentences,
Contributor: Use CamelCase too?
Contributor Author: This is from before, and will touch a few unrelated places, so can change in a follow on PR. Added TODO.

'MultiSV': _multivar_candidates_to_dict(result_multivar)
}
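# Illustrative return value (hypothetical SV DCIDs, sentences and scores):
#   {
#     'SV': ['sv/example_1', 'sv/example_2'],
#     'CosineScore': [0.83, 0.79],
#     'SV_to_Sentences': {'sv/example_1': ['people living below poverty (0.83)']},
#     'MultiSV': {'Candidates': [...]}
#   }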

#
# Detects one or more SVs from the query.
# TODO: Fix the query upstream to ensure punctuation isn't stripped.
#
def _detect_multiple_svs(self, query: str) -> MultiVarCandidates:
#
# Prepare a combination of query-sets.
#
querysets = query_util.prepare_multivar_querysets(query)
Contributor: Reading the comments below about a queryset, wouldn't it be better to use the combinatorial function built into Python? https://www.geeksforgeeks.org/itertools-combinations-module-python-print-possible-combinations/
Contributor Author: Good idea, doing in the follow-on PR.

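For reference, the itertools-based splitting suggested above could roughly look like this (a hypothetical sketch with an assumed contiguous_splits helper; not the query_util implementation):

from itertools import combinations

def contiguous_splits(words: List[str], nsplits: int):
  # Choose nsplits-1 cut positions among the gaps between words; each choice
  # yields one combination of contiguous query parts.
  for cuts in combinations(range(1, len(words)), nsplits - 1):
    bounds = [0] + list(cuts) + [len(words)]
    yield [' '.join(words[bounds[i]:bounds[i + 1]]) for i in range(nsplits)]

# contiguous_splits(['hispanic', 'women', 'phd'], 2) yields
# ['hispanic', 'women phd'] and ['hispanic women', 'phd'].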

result = MultiVarCandidates(candidates=[])

# Make a unique list of query strings
all_queries = set()
for qs in querysets:
for c in qs.combinations:
for p in c.parts:
all_queries.add(p)
if not all_queries:
return result

query2result = self._search_embeddings(list(all_queries))

#
# A queryset is the set of all combinations of query
# splits of a given length. For example, a query like
# "hispanic women phd" may have a queryset for a
# 2-way split, like:
# QuerySet(nsplits=2,
# combinations=[
# ['hispanic women', 'phd'],
# ['hispanic', 'women phd'],
# ])
# The 3-way split has only one combination.
#
# For each combination, we take the average of the top-SV scores across its
# query-parts (dropping any combination where a part's best score falls below
# the threshold). Then sort all candidates by that average.
#
# TODO: Come up with a better ranking function.
#
for qs in querysets:
Contributor: Some examples of "querysets" would be helpful, since its fields are very deep.
Contributor Author: Done, PTAL.

candidates: List[MultiVarCandidate] = []
for c in qs.combinations:
if not c or not c.parts:
continue

total = 0
candidate = MultiVarCandidate(parts=[],
delim_based=qs.delim_based,
aggregate_score=-1)
lowest = _INIT_SCORE
for q in c.parts:
r = query2result.get(
q, VarCandidates(svs=[], scores=[], sv2sentences={}))
part = MultiVarCandidatePart(query_part=q, svs=[], scores=[])
score = 0
if r.svs:
# Pick the top-K SVs.
limit = _pick_top_k(r)
if limit > 0:
part.svs = r.svs[:limit]
part.scores = [round(s, 4) for s in r.scores[:limit]]
score = r.scores[0]

if score < lowest:
lowest = score
total += score
candidate.parts.append(part)

if lowest < _SV_SCORE_THRESHOLD:
# A query-part's best SV did not cross our score threshold,
# so drop this candidate.
continue

# The candidate level score is the average.
candidate.aggregate_score = total / len(c.parts)
candidates.append(candidate)

if candidates:
# Sort this split-size's candidates by aggregate score.
candidates.sort(key=lambda c: c.aggregate_score, reverse=True)
# Pick up to some number of candidates. Could be just 1
# eventually.
result.candidates.extend(candidates[:_NUM_CANDIDATES_PER_NSPLIT])

# Sort the results by score.
result.candidates.sort(key=lambda c: c.aggregate_score, reverse=True)
return result


#
# Given a query-part's var candidates, returns the number k of top SVs whose
# scores are within _MULTI_SV_SCORE_DIFFERENTIAL of the best score.
#
def _pick_top_k(candidates: VarCandidates) -> int:
  k = 1
  first = candidates.scores[0]
  for i in range(1, len(candidates.scores)):
    if first - candidates.scores[i] > _MULTI_SV_SCORE_DIFFERENTIAL:
      break
    k += 1
  return k
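# Example (illustrative scores): with scores [0.91, 0.89, 0.84], 0.91 - 0.89 =
# 0.02 is within the 0.05 differential (k becomes 2), but 0.91 - 0.84 = 0.07
# exceeds it, so k = 2 is returned.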


def _multivar_candidates_to_dict(candidates: MultiVarCandidates) -> Dict:
  result = {'Candidates': []}
  for c in candidates.candidates:
    c_dict = {
        'Parts': [],
        'AggCosineScore': round(c.aggregate_score, 4),
        'DelimBased': c.delim_based,
    }
    for p in c.parts:
      p_dict = {'QueryPart': p.query_part, 'SV': p.svs, 'CosineScore': p.scores}
      c_dict['Parts'].append(p_dict)
    result['Candidates'].append(c_dict)
  return result
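# Illustrative output (hypothetical SVs and scores):
#   {'Candidates': [{
#       'Parts': [{'QueryPart': 'hispanic women', 'SV': ['sv/example_1'], 'CosineScore': [0.81]},
#                 {'QueryPart': 'phd', 'SV': ['sv/example_2'], 'CosineScore': [0.77]}],
#       'AggCosineScore': 0.79, 'DelimBased': False}]}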