Skip to content

Commit

Permalink
chore: Semantic operations - support non-string types, fix flaky top_…
Browse files Browse the repository at this point in the history
…k doctests (#1099)
  • Loading branch information
chelsea-lin authored Oct 22, 2024
1 parent 9aff171 commit 2a0ffac
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 62 deletions.
85 changes: 41 additions & 44 deletions bigframes/operations/semantics.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,17 +98,20 @@ def agg(
ValueError: when the instruction refers to a non-existing column, or when
more than one columns are referred to.
"""
self._validate_model(model)
import bigframes.bigquery as bbq
import bigframes.dataframe
import bigframes.series

self._validate_model(model)
columns = self._parse_columns(instruction)

df: bigframes.dataframe.DataFrame = self._df.copy()
for column in columns:
if column not in self._df.columns:
raise ValueError(f"Column {column} not found.")
if self._df[column].dtype != dtypes.STRING_DTYPE:
raise TypeError(
"Semantics aggregated column must be a string type, not "
f"{type(self._df[column])}"
)

if df[column].dtype != dtypes.STRING_DTYPE:
df[column] = df[column].astype(dtypes.STRING_DTYPE)

if len(columns) > 1:
raise NotImplementedError(
Expand All @@ -122,11 +125,6 @@ def agg(
"It must be greater than 1."
)

import bigframes.bigquery as bbq
import bigframes.dataframe
import bigframes.series

df: bigframes.dataframe.DataFrame = self._df.copy()
user_instruction = self._format_instruction(instruction, columns)

num_cluster = 1
Expand Down Expand Up @@ -325,26 +323,27 @@ def filter(self, instruction: str, model):
ValueError: when the instruction refers to a non-existing column, or when no
columns are referred to.
"""
import bigframes.dataframe
import bigframes.series

self._validate_model(model)
columns = self._parse_columns(instruction)
for column in columns:
if column not in self._df.columns:
raise ValueError(f"Column {column} not found.")
if self._df[column].dtype != dtypes.STRING_DTYPE:
raise TypeError(
"Semantics aggregated column must be a string type, not "
f"{type(self._df[column])}"
)

df: bigframes.dataframe.DataFrame = self._df[columns].copy()
for column in columns:
if df[column].dtype != dtypes.STRING_DTYPE:
df[column] = df[column].astype(dtypes.STRING_DTYPE)

user_instruction = self._format_instruction(instruction, columns)
output_instruction = "Based on the provided context, reply to the following claim by only True or False:"

from bigframes.dataframe import DataFrame

results = typing.cast(
DataFrame,
bigframes.dataframe.DataFrame,
model.predict(
self._make_prompt(columns, user_instruction, output_instruction),
self._make_prompt(df, columns, user_instruction, output_instruction),
temperature=0.0,
),
)
Expand Down Expand Up @@ -398,28 +397,29 @@ def map(self, instruction: str, output_column: str, model):
ValueError: when the instruction refers to a non-existing column, or when no
columns are referred to.
"""
import bigframes.dataframe
import bigframes.series

self._validate_model(model)
columns = self._parse_columns(instruction)
for column in columns:
if column not in self._df.columns:
raise ValueError(f"Column {column} not found.")
if self._df[column].dtype != dtypes.STRING_DTYPE:
raise TypeError(
"Semantics aggregated column must be a string type, not "
f"{type(self._df[column])}"
)

df: bigframes.dataframe.DataFrame = self._df[columns].copy()
for column in columns:
if df[column].dtype != dtypes.STRING_DTYPE:
df[column] = df[column].astype(dtypes.STRING_DTYPE)

user_instruction = self._format_instruction(instruction, columns)
output_instruction = (
"Based on the provided contenxt, answer the following instruction:"
)

from bigframes.series import Series

results = typing.cast(
Series,
bigframes.series.Series,
model.predict(
self._make_prompt(columns, user_instruction, output_instruction),
self._make_prompt(df, columns, user_instruction, output_instruction),
temperature=0.0,
)["ml_generate_text_llm_result"],
)
Expand Down Expand Up @@ -683,6 +683,9 @@ def top_k(self, instruction: str, model, k=10):
ValueError: when the instruction refers to a non-existing column, or when no
columns are referred to.
"""
import bigframes.dataframe
import bigframes.series

self._validate_model(model)
columns = self._parse_columns(instruction)
for column in columns:
Expand All @@ -692,12 +695,12 @@ def top_k(self, instruction: str, model, k=10):
raise NotImplementedError(
"Semantic aggregations are limited to a single column."
)

df: bigframes.dataframe.DataFrame = self._df[columns].copy()
column = columns[0]
if self._df[column].dtype != dtypes.STRING_DTYPE:
raise TypeError(
"Referred column must be a string type, not "
f"{type(self._df[column])}"
)
if df[column].dtype != dtypes.STRING_DTYPE:
df[column] = df[column].astype(dtypes.STRING_DTYPE)

# `index` is reserved for the `reset_index` below.
if column == "index":
raise ValueError(
Expand All @@ -709,12 +712,7 @@ def top_k(self, instruction: str, model, k=10):

user_instruction = self._format_instruction(instruction, columns)

import bigframes.dataframe
import bigframes.series

df: bigframes.dataframe.DataFrame = self._df[columns].copy()
n = df.shape[0]

if k >= n:
return df

Expand Down Expand Up @@ -762,17 +760,17 @@ def _topk_partition(

# Random pivot selection for improved average quickselect performance.
pending_df = df[df[status_column].isna()]
pivot_iloc = np.random.randint(0, pending_df.shape[0] - 1)
pivot_iloc = np.random.randint(0, pending_df.shape[0])
pivot_index = pending_df.iloc[pivot_iloc]["index"]
pivot_df = pending_df[pending_df["index"] == pivot_index]

# Build a prompt to compare the pivot item's relevance to other pending items.
prompt_s = pending_df[pending_df["index"] != pivot_index][column]
prompt_s = (
f"{output_instruction}\n\nQuestion: {user_instruction}\n"
+ "\nDocument 1: "
+ f"\nDocument 1: {column} "
+ pivot_df.iloc[0][column]
+ "\nDocument 2: "
+ f"\nDocument 2: {column} "
+ prompt_s # type:ignore
)

Expand Down Expand Up @@ -920,9 +918,8 @@ def _attach_embedding(dataframe, source_column: str, embedding_column: str, mode
return result_df

def _make_prompt(
self, columns: List[str], user_instruction: str, output_instruction: str
self, prompt_df, columns, user_instruction: str, output_instruction: str
):
prompt_df = self._df[columns].copy()
prompt_df["prompt"] = f"{output_instruction}\n{user_instruction}\nContext: "

# Combine context from multiple columns.
Expand Down
57 changes: 39 additions & 18 deletions tests/system/large/operations/test_semantics.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def test_semantics_experiment_off_raise_error():
pytest.param(2, None, id="two"),
pytest.param(3, None, id="three"),
pytest.param(4, None, id="four"),
pytest.param(5, "Year", id="two_w_cluster_column"),
pytest.param(6, "Year", id="three_w_cluster_column"),
pytest.param(7, "Year", id="four_w_cluster_column"),
pytest.param(5, "Years", id="two_w_cluster_column"),
pytest.param(6, "Years", id="three_w_cluster_column"),
pytest.param(7, "Years", id="four_w_cluster_column"),
],
)
def test_agg(session, gemini_flash_model, max_agg_rows, cluster_column):
Expand All @@ -56,7 +56,7 @@ def test_agg(session, gemini_flash_model, max_agg_rows, cluster_column):
"Shuttle Island",
"The Great Gatsby",
],
"Year": [1997, 2013, 2023, 2015, 2010, 2010, 2013],
"Years": [1997, 2013, 2023, 2015, 2010, 2010, 2013],
},
session=session,
)
Expand All @@ -73,6 +73,29 @@ def test_agg(session, gemini_flash_model, max_agg_rows, cluster_column):
pandas.testing.assert_series_equal(actual_s, expected_s, check_index_type=False)


def test_agg_w_int_column(session, gemini_flash_model):
bigframes.options.experiments.semantic_operators = True
df = dataframe.DataFrame(
data={
"Movies": [
"Killers of the Flower Moon",
"The Great Gatsby",
],
"Years": [2023, 2013],
},
session=session,
)
instruction = "Find the {Years} Leonardo DiCaprio acted in the most movies. Answer with the year only."
actual_s = df.semantics.agg(
instruction,
model=gemini_flash_model,
).to_pandas()

expected_s = pd.Series(["2013 \n"], dtype=dtypes.STRING_DTYPE)
expected_s.name = "Years"
pandas.testing.assert_series_equal(actual_s, expected_s, check_index_type=False)


@pytest.mark.parametrize(
"instruction",
[
Expand All @@ -91,11 +114,6 @@ def test_agg(session, gemini_flash_model, max_agg_rows, cluster_column):
id="two_columns",
marks=pytest.mark.xfail(raises=NotImplementedError),
),
pytest.param(
"{Year}",
id="invalid_type",
marks=pytest.mark.xfail(raises=TypeError),
),
],
)
def test_agg_invalid_instruction_raise_error(instruction, gemini_flash_model):
Expand Down Expand Up @@ -207,15 +225,21 @@ def test_cluster_by_invalid_model(session, gemini_flash_model):
def test_filter(session, gemini_flash_model):
bigframes.options.experiments.semantic_operators = True
df = dataframe.DataFrame(
data={"country": ["USA", "Germany"], "city": ["Seattle", "Berlin"]},
data={
"country": ["USA", "Germany"],
"city": ["Seattle", "Berlin"],
"year": [2023, 2024],
},
session=session,
)

actual_df = df.semantics.filter(
"{city} is the capital of {country}", gemini_flash_model
"{city} is the capital of {country} in {year}", gemini_flash_model
).to_pandas()

expected_df = pd.DataFrame({"country": ["Germany"], "city": ["Berlin"]}, index=[1])
expected_df = pd.DataFrame(
{"country": ["Germany"], "city": ["Berlin"], "year": [2024]}, index=[1]
)
pandas.testing.assert_frame_equal(
actual_df, expected_df, check_dtype=False, check_index_type=False
)
Expand Down Expand Up @@ -282,12 +306,13 @@ def test_map(session, gemini_flash_model):
data={
"ingredient_1": ["Burger Bun", "Soy Bean"],
"ingredient_2": ["Beef Patty", "Bittern"],
"gluten-free": [True, True],
},
session=session,
)

actual_df = df.semantics.map(
"What is the food made from {ingredient_1} and {ingredient_2}? One word only.",
"What is the {gluten-free} food made from {ingredient_1} and {ingredient_2}? One word only.",
"food",
gemini_flash_model,
).to_pandas()
Expand All @@ -298,6 +323,7 @@ def test_map(session, gemini_flash_model):
{
"ingredient_1": ["Burger Bun", "Soy Bean"],
"ingredient_2": ["Beef Patty", "Bittern"],
"gluten-free": [True, True],
"food": ["burger", "tofu"],
}
)
Expand Down Expand Up @@ -724,11 +750,6 @@ def test_sim_join_data_too_large_raises_error(session, text_embedding_generator)
id="two_columns",
marks=pytest.mark.xfail(raises=NotImplementedError),
),
pytest.param(
"{ID}",
id="invalid_dtypes",
marks=pytest.mark.xfail(raises=TypeError),
),
pytest.param(
"{index}",
id="preserved",
Expand Down

0 comments on commit 2a0ffac

Please sign in to comment.