update search to work with torch/numpy (#355)
ad12 authored May 20, 2023
1 parent 9ba766e commit a2bfa55
Showing 1 changed file with 51 additions and 21 deletions.
72 changes: 51 additions & 21 deletions meerkat/ops/search.py
@@ -3,6 +3,8 @@
import numpy as np

from meerkat import DataFrame, NumPyTensorColumn, TensorColumn, TorchTensorColumn
from meerkat.env import is_torch_available
from meerkat.interactive.graph.reactivity import reactive
from meerkat.tools.lazy_loader import LazyLoader

torch = LazyLoader("torch")
@@ -11,28 +13,32 @@
import torch


@reactive
def search(
    data: DataFrame,
    query: np.ndarray,
    by: str = None,
    k: int = None,
    metric: str = "dot",
    score_column: str = None,
) -> DataFrame:
    """Search by a query in a DataFrame.

    Args:
        data: The DataFrame to search.
        query: The query to search with.
        by: The column to compare the query against.
        k: The number of results to return.
        metric: The metric to use for comparison.
        score_column: The name of the column to store the scores in.
            If ``None``, the scores will not be stored.

    Return:
        DataFrame: A view of ``data`` sorted by each row's similarity to the query.
    """
    if len(data) <= 1:
        raise ValueError("DataFrame must have at least 2 rows.")

    by = data[by]

    if not isinstance(by, TensorColumn):
@@ -44,37 +50,61 @@ def search(

        if not torch.is_tensor(query):
            query = torch.tensor(query)
        query = query.to(by.device)  # move the query onto the column's device

        fn = _torch_search

    elif isinstance(by, NumPyTensorColumn):
        # Accept torch tensors as queries even for NumPy-backed columns.
        if is_torch_available():
            import torch

            if torch.is_tensor(query):
                query = query.detach().cpu().numpy()
        elif not isinstance(query, np.ndarray):
            query = np.array(query)

        fn = _numpy_search
    else:
        raise ValueError("")

    scores, indices = fn(query=query, by=by.data, metric=metric, k=k)
    data = data[indices]
    if score_column is not None:
        data[score_column] = scores
    return data


def _torch_search(
    query: "torch.Tensor", by: "torch.Tensor", metric: str, k: int
) -> "torch.Tensor":
    with torch.no_grad():
        if len(query.shape) == 1:
            query = query.unsqueeze(0)

        if metric == "dot":
            scores = (by @ query.T).squeeze()
        else:
            raise ValueError("")

        scores, indices = torch.topk(scores, k=k)
    return scores.to("cpu").numpy(), indices.to("cpu")


def _numpy_search(query: np.ndarray, by: np.ndarray, metric: str, k: int) -> np.ndarray:
    if query.ndim == 1:
        query = query[np.newaxis, ...]

    if metric == "dot":
        scores = np.squeeze(by @ query.T)
    else:
        raise ValueError("")

    if k is not None:
        # Partial sort: take the k largest scores, then order them descending.
        indices = np.argpartition(scores, -k)[-k:]
        indices = indices[np.argsort(-scores[indices])]
        scores = scores[indices]
    else:
        indices = np.argsort(-scores)
        scores = scores[indices]

    return scores, indices
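
For reference, a minimal usage sketch (not part of the commit) showing the updated function called with a NumPy query. The column names, embedding size, and the direct import from meerkat.ops.search are illustrative assumptions, not taken from the diff:

import numpy as np

import meerkat as mk
from meerkat.ops.search import search

# Hypothetical DataFrame whose "embedding" column is backed by a NumPy array,
# so it should be stored as a NumPyTensorColumn and dispatch to _numpy_search.
df = mk.DataFrame({
    "id": np.arange(4),
    "embedding": np.random.rand(4, 8).astype(np.float32),
})

# A 1D query is fine; search() adds the batch dimension internally.
query = np.random.rand(8).astype(np.float32)

# Top-2 rows by dot-product similarity, with scores stored in a new column.
result = search(df, query=query, by="embedding", k=2, score_column="score")
print(result["id"], result["score"])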
