diff --git a/bigframes/operations/semantics.py b/bigframes/operations/semantics.py index 7e8e5a8093..cff3fc724d 100644 --- a/bigframes/operations/semantics.py +++ b/bigframes/operations/semantics.py @@ -98,17 +98,20 @@ def agg( ValueError: when the instruction refers to a non-existing column, or when more than one columns are referred to. """ - self._validate_model(model) + import bigframes.bigquery as bbq + import bigframes.dataframe + import bigframes.series + self._validate_model(model) columns = self._parse_columns(instruction) + + df: bigframes.dataframe.DataFrame = self._df.copy() for column in columns: if column not in self._df.columns: raise ValueError(f"Column {column} not found.") - if self._df[column].dtype != dtypes.STRING_DTYPE: - raise TypeError( - "Semantics aggregated column must be a string type, not " - f"{type(self._df[column])}" - ) + + if df[column].dtype != dtypes.STRING_DTYPE: + df[column] = df[column].astype(dtypes.STRING_DTYPE) if len(columns) > 1: raise NotImplementedError( @@ -122,11 +125,6 @@ def agg( "It must be greater than 1." ) - import bigframes.bigquery as bbq - import bigframes.dataframe - import bigframes.series - - df: bigframes.dataframe.DataFrame = self._df.copy() user_instruction = self._format_instruction(instruction, columns) num_cluster = 1 @@ -325,26 +323,27 @@ def filter(self, instruction: str, model): ValueError: when the instruction refers to a non-existing column, or when no columns are referred to. """ + import bigframes.dataframe + import bigframes.series + self._validate_model(model) columns = self._parse_columns(instruction) for column in columns: if column not in self._df.columns: raise ValueError(f"Column {column} not found.") - if self._df[column].dtype != dtypes.STRING_DTYPE: - raise TypeError( - "Semantics aggregated column must be a string type, not " - f"{type(self._df[column])}" - ) + + df: bigframes.dataframe.DataFrame = self._df[columns].copy() + for column in columns: + if df[column].dtype != dtypes.STRING_DTYPE: + df[column] = df[column].astype(dtypes.STRING_DTYPE) user_instruction = self._format_instruction(instruction, columns) output_instruction = "Based on the provided context, reply to the following claim by only True or False:" - from bigframes.dataframe import DataFrame - results = typing.cast( - DataFrame, + bigframes.dataframe.DataFrame, model.predict( - self._make_prompt(columns, user_instruction, output_instruction), + self._make_prompt(df, columns, user_instruction, output_instruction), temperature=0.0, ), ) @@ -398,28 +397,29 @@ def map(self, instruction: str, output_column: str, model): ValueError: when the instruction refers to a non-existing column, or when no columns are referred to. """ + import bigframes.dataframe + import bigframes.series + self._validate_model(model) columns = self._parse_columns(instruction) for column in columns: if column not in self._df.columns: raise ValueError(f"Column {column} not found.") - if self._df[column].dtype != dtypes.STRING_DTYPE: - raise TypeError( - "Semantics aggregated column must be a string type, not " - f"{type(self._df[column])}" - ) + + df: bigframes.dataframe.DataFrame = self._df[columns].copy() + for column in columns: + if df[column].dtype != dtypes.STRING_DTYPE: + df[column] = df[column].astype(dtypes.STRING_DTYPE) user_instruction = self._format_instruction(instruction, columns) output_instruction = ( "Based on the provided contenxt, answer the following instruction:" ) - from bigframes.series import Series - results = typing.cast( - Series, + bigframes.series.Series, model.predict( - self._make_prompt(columns, user_instruction, output_instruction), + self._make_prompt(df, columns, user_instruction, output_instruction), temperature=0.0, )["ml_generate_text_llm_result"], ) @@ -683,6 +683,9 @@ def top_k(self, instruction: str, model, k=10): ValueError: when the instruction refers to a non-existing column, or when no columns are referred to. """ + import bigframes.dataframe + import bigframes.series + self._validate_model(model) columns = self._parse_columns(instruction) for column in columns: @@ -692,12 +695,12 @@ def top_k(self, instruction: str, model, k=10): raise NotImplementedError( "Semantic aggregations are limited to a single column." ) + + df: bigframes.dataframe.DataFrame = self._df[columns].copy() column = columns[0] - if self._df[column].dtype != dtypes.STRING_DTYPE: - raise TypeError( - "Referred column must be a string type, not " - f"{type(self._df[column])}" - ) + if df[column].dtype != dtypes.STRING_DTYPE: + df[column] = df[column].astype(dtypes.STRING_DTYPE) + # `index` is reserved for the `reset_index` below. if column == "index": raise ValueError( @@ -709,12 +712,7 @@ def top_k(self, instruction: str, model, k=10): user_instruction = self._format_instruction(instruction, columns) - import bigframes.dataframe - import bigframes.series - - df: bigframes.dataframe.DataFrame = self._df[columns].copy() n = df.shape[0] - if k >= n: return df @@ -762,7 +760,7 @@ def _topk_partition( # Random pivot selection for improved average quickselect performance. pending_df = df[df[status_column].isna()] - pivot_iloc = np.random.randint(0, pending_df.shape[0] - 1) + pivot_iloc = np.random.randint(0, pending_df.shape[0]) pivot_index = pending_df.iloc[pivot_iloc]["index"] pivot_df = pending_df[pending_df["index"] == pivot_index] @@ -770,9 +768,9 @@ def _topk_partition( prompt_s = pending_df[pending_df["index"] != pivot_index][column] prompt_s = ( f"{output_instruction}\n\nQuestion: {user_instruction}\n" - + "\nDocument 1: " + + f"\nDocument 1: {column} " + pivot_df.iloc[0][column] - + "\nDocument 2: " + + f"\nDocument 2: {column} " + prompt_s # type:ignore ) @@ -920,9 +918,8 @@ def _attach_embedding(dataframe, source_column: str, embedding_column: str, mode return result_df def _make_prompt( - self, columns: List[str], user_instruction: str, output_instruction: str + self, prompt_df, columns, user_instruction: str, output_instruction: str ): - prompt_df = self._df[columns].copy() prompt_df["prompt"] = f"{output_instruction}\n{user_instruction}\nContext: " # Combine context from multiple columns. diff --git a/tests/system/large/operations/test_semantics.py b/tests/system/large/operations/test_semantics.py index 72f2897211..3b9cfcf4c3 100644 --- a/tests/system/large/operations/test_semantics.py +++ b/tests/system/large/operations/test_semantics.py @@ -38,9 +38,9 @@ def test_semantics_experiment_off_raise_error(): pytest.param(2, None, id="two"), pytest.param(3, None, id="three"), pytest.param(4, None, id="four"), - pytest.param(5, "Year", id="two_w_cluster_column"), - pytest.param(6, "Year", id="three_w_cluster_column"), - pytest.param(7, "Year", id="four_w_cluster_column"), + pytest.param(5, "Years", id="two_w_cluster_column"), + pytest.param(6, "Years", id="three_w_cluster_column"), + pytest.param(7, "Years", id="four_w_cluster_column"), ], ) def test_agg(session, gemini_flash_model, max_agg_rows, cluster_column): @@ -56,7 +56,7 @@ def test_agg(session, gemini_flash_model, max_agg_rows, cluster_column): "Shuttle Island", "The Great Gatsby", ], - "Year": [1997, 2013, 2023, 2015, 2010, 2010, 2013], + "Years": [1997, 2013, 2023, 2015, 2010, 2010, 2013], }, session=session, ) @@ -73,6 +73,29 @@ def test_agg(session, gemini_flash_model, max_agg_rows, cluster_column): pandas.testing.assert_series_equal(actual_s, expected_s, check_index_type=False) +def test_agg_w_int_column(session, gemini_flash_model): + bigframes.options.experiments.semantic_operators = True + df = dataframe.DataFrame( + data={ + "Movies": [ + "Killers of the Flower Moon", + "The Great Gatsby", + ], + "Years": [2023, 2013], + }, + session=session, + ) + instruction = "Find the {Years} Leonardo DiCaprio acted in the most movies. Answer with the year only." + actual_s = df.semantics.agg( + instruction, + model=gemini_flash_model, + ).to_pandas() + + expected_s = pd.Series(["2013 \n"], dtype=dtypes.STRING_DTYPE) + expected_s.name = "Years" + pandas.testing.assert_series_equal(actual_s, expected_s, check_index_type=False) + + @pytest.mark.parametrize( "instruction", [ @@ -91,11 +114,6 @@ def test_agg(session, gemini_flash_model, max_agg_rows, cluster_column): id="two_columns", marks=pytest.mark.xfail(raises=NotImplementedError), ), - pytest.param( - "{Year}", - id="invalid_type", - marks=pytest.mark.xfail(raises=TypeError), - ), ], ) def test_agg_invalid_instruction_raise_error(instruction, gemini_flash_model): @@ -207,15 +225,21 @@ def test_cluster_by_invalid_model(session, gemini_flash_model): def test_filter(session, gemini_flash_model): bigframes.options.experiments.semantic_operators = True df = dataframe.DataFrame( - data={"country": ["USA", "Germany"], "city": ["Seattle", "Berlin"]}, + data={ + "country": ["USA", "Germany"], + "city": ["Seattle", "Berlin"], + "year": [2023, 2024], + }, session=session, ) actual_df = df.semantics.filter( - "{city} is the capital of {country}", gemini_flash_model + "{city} is the capital of {country} in {year}", gemini_flash_model ).to_pandas() - expected_df = pd.DataFrame({"country": ["Germany"], "city": ["Berlin"]}, index=[1]) + expected_df = pd.DataFrame( + {"country": ["Germany"], "city": ["Berlin"], "year": [2024]}, index=[1] + ) pandas.testing.assert_frame_equal( actual_df, expected_df, check_dtype=False, check_index_type=False ) @@ -282,12 +306,13 @@ def test_map(session, gemini_flash_model): data={ "ingredient_1": ["Burger Bun", "Soy Bean"], "ingredient_2": ["Beef Patty", "Bittern"], + "gluten-free": [True, True], }, session=session, ) actual_df = df.semantics.map( - "What is the food made from {ingredient_1} and {ingredient_2}? One word only.", + "What is the {gluten-free} food made from {ingredient_1} and {ingredient_2}? One word only.", "food", gemini_flash_model, ).to_pandas() @@ -298,6 +323,7 @@ def test_map(session, gemini_flash_model): { "ingredient_1": ["Burger Bun", "Soy Bean"], "ingredient_2": ["Beef Patty", "Bittern"], + "gluten-free": [True, True], "food": ["burger", "tofu"], } ) @@ -724,11 +750,6 @@ def test_sim_join_data_too_large_raises_error(session, text_embedding_generator) id="two_columns", marks=pytest.mark.xfail(raises=NotImplementedError), ), - pytest.param( - "{ID}", - id="invalid_dtypes", - marks=pytest.mark.xfail(raises=TypeError), - ), pytest.param( "{index}", id="preserved",