From 15664bc37f1c3517df99b1e922e06b754563764f Mon Sep 17 00:00:00 2001 From: Julio Perez <37191411+jperez999@users.noreply.github.com> Date: Tue, 12 Apr 2022 18:00:31 -0400 Subject: [PATCH] added category name to domain for column properties (#1508) * added category name to domain for column properties * refactored low cardinality op * fix typo in categorify clear_stats def --- nvtabular/ops/categorify.py | 50 ++++++++++---------- nvtabular/ops/drop_low_cardinality.py | 68 +++++++++++++++++++++------ tests/unit/ops/test_categorify.py | 37 +++++++++++++++ 3 files changed, 117 insertions(+), 38 deletions(-) diff --git a/nvtabular/ops/categorify.py b/nvtabular/ops/categorify.py index 5f8ced620b0..62b706c6c17 100644 --- a/nvtabular/ops/categorify.py +++ b/nvtabular/ops/categorify.py @@ -363,23 +363,25 @@ def fit(self, col_selector: ColumnSelector, ddf: dd.DataFrame): def fit_finalize(self, categories): idx_count = 0 - for col in categories: + + for cat in categories: # this is a path - self.categories[col] = categories[col] + self.categories[cat] = categories[cat] # check the argument if self.single_table: - cat_file_path = self.categories[col] + cat_file_path = self.categories[cat] idx_count, new_cat_file_path = run_on_worker( - _reset_df_index, col, cat_file_path, idx_count + _reset_df_index, cat, cat_file_path, idx_count ) - self.categories[col] = new_cat_file_path + self.categories[cat] = new_cat_file_path def clear(self): + """Clear the internal state of the operator's stats.""" self.categories = deepcopy(self.vocabs) def process_vocabs(self, vocabs): + """Process vocabs passed in by the user.""" categories = {} - if isinstance(vocabs, dict) and all(dispatch.is_series_object(v) for v in vocabs.values()): fit_options = self._create_fit_options_from_columns(list(vocabs.keys())) base_path = os.path.join(self.out_path, fit_options.stat_name) @@ -503,26 +505,26 @@ def _compute_properties(self, col_schema, input_schema): new_schema = super()._compute_properties(col_schema, input_schema) col_name = col_schema.name - target_column_path = self.categories.get(col_name, None) + category_name = self.storage_name.get(col_name, col_name) + target_category_path = self.categories.get(category_name, None) + cardinality, dimensions = self.get_embedding_sizes([col_name])[col_name] - to_add = {} - if target_column_path: - to_add = { - "num_buckets": self.num_buckets[col_name] - if isinstance(self.num_buckets, dict) - else self.num_buckets, - "freq_threshold": self.freq_threshold[col_name] - if isinstance(self.freq_threshold, dict) - else self.freq_threshold, - "max_size": self.max_size[col_name] - if isinstance(self.max_size, dict) - else self.max_size, - "start_index": self.start_index, - "cat_path": target_column_path, - "domain": {"min": 0, "max": cardinality}, - "embedding_sizes": {"cardinality": cardinality, "dimension": dimensions}, - } + to_add = { + "num_buckets": self.num_buckets[col_name] + if isinstance(self.num_buckets, dict) + else self.num_buckets, + "freq_threshold": self.freq_threshold[col_name] + if isinstance(self.freq_threshold, dict) + else self.freq_threshold, + "max_size": self.max_size[col_name] + if isinstance(self.max_size, dict) + else self.max_size, + "start_index": self.start_index, + "cat_path": target_category_path, + "domain": {"min": 0, "max": cardinality, "name": category_name}, + "embedding_sizes": {"cardinality": cardinality, "dimension": dimensions}, + } return col_schema.with_properties({**new_schema.properties, **to_add}) diff --git a/nvtabular/ops/drop_low_cardinality.py b/nvtabular/ops/drop_low_cardinality.py index cf5ac13f931..4a21945736f 100644 --- a/nvtabular/ops/drop_low_cardinality.py +++ b/nvtabular/ops/drop_low_cardinality.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from merlin.core.dispatch import DataFrameType, annotate +from merlin.core.dispatch import DataFrameType from merlin.schema import Schema, Tags from .operator import ColumnSelector, Operator @@ -29,22 +29,62 @@ class DropLowCardinality(Operator): def __init__(self, min_cardinality=2): super().__init__() self.min_cardinality = min_cardinality - self.to_drop = [] - @annotate("drop_low_cardinality", color="darkgreen", domain="nvt_python") def transform(self, col_selector: ColumnSelector, df: DataFrameType) -> DataFrameType: - return df.drop(self.to_drop, axis=1) + """ + Selects all non-categorical columns and any categorical columns + of at least the minimum cardinality from the dataframe. + + Parameters + ---------- + col_selector : ColumnSelector + The columns to select. + df : DataFrameType + The dataframe to transform + + Returns + ------- + DataFrameType + Dataframe with only the selected columns. + """ + return super()._get_columns(df, col_selector) + + def compute_selector( + self, + input_schema: Schema, + selector: ColumnSelector, + parents_selector: ColumnSelector, + dependencies_selector: ColumnSelector, + ) -> ColumnSelector: + """ + Checks the cardinality of the input columns and drops any categorical + columns with cardinality less than the specified minimum. + + Parameters + ---------- + input_schema : Schema + The current node's input schema + selector : ColumnSelector + The current node's selector + parents_selector : ColumnSelector + A selector for the output columns of the current node's parents + dependencies_selector : ColumnSelector + A selector for the output columns of the current node's dependencies + + Returns + ------- + ColumnSelector + Selector that contains all non-categorical columns and any categorical columns + of at least the minimum cardinality. + """ + self._validate_matching_cols(input_schema, selector, self.compute_selector.__name__) + + cols_to_keep = [col for col in input_schema if Tags.CATEGORICAL not in col.tags] - def compute_output_schema(self, input_schema, selector, prev_output_schema=None): - output_columns = [] for col in input_schema: if Tags.CATEGORICAL in col.tags: domain = col.int_domain - if domain and domain.max <= self.min_cardinality: - self.to_drop.append(col.name) - continue - output_columns.append(col) - return Schema(output_columns) - - transform.__doc__ = Operator.transform.__doc__ - compute_output_schema.__doc__ = Operator.compute_output_schema.__doc__ + if not domain or domain.max > self.min_cardinality: + cols_to_keep.append(col.name) + + return ColumnSelector(cols_to_keep) diff --git a/tests/unit/ops/test_categorify.py b/tests/unit/ops/test_categorify.py index c0b19fbe935..936fbeff9c4 100644 --- a/tests/unit/ops/test_categorify.py +++ b/tests/unit/ops/test_categorify.py @@ -577,3 +577,40 @@ def test_categorify_no_nulls(): df = pd.read_parquet("./categories/unique.user_id.parquet") assert df["user_id"].iloc[:1].isnull().any() assert df["user_id_size"][0] == 0 + + +@pytest.mark.parametrize("cat_names", [[["Author", "Engaging User"]], ["Author", "Engaging User"]]) +@pytest.mark.parametrize("kind", ["joint", "combo"]) +@pytest.mark.parametrize("cpu", _CPU) +def test_categorify_domain_name(tmpdir, cat_names, kind, cpu): + df = pd.DataFrame( + { + "Author": ["User_A", "User_E", "User_B", "User_C"], + "Engaging User": ["User_B", "User_B", "User_A", "User_D"], + "Post": [1, 2, 3, 4], + } + ) + cats = cat_names >> ops.Categorify(out_path=str(tmpdir), encode_type=kind) + + workflow = nvt.Workflow(cats) + workflow.fit_transform(nvt.Dataset(df, cpu=cpu)).to_ddf().compute(scheduler="synchronous") + + domain_names = [] + for col_name in workflow.output_schema.column_names: + domain_names.append(workflow.output_schema[col_name].properties["domain"]["name"]) + + assert workflow.output_schema[col_name].properties != {} + assert "domain" in workflow.output_schema[col_name].properties + assert "name" in workflow.output_schema[col_name].properties["domain"] + + if len(cat_names) == 1 and kind == "combo": + # Columns are encoded in combination, so there's only one domain name + assert len(domain_names) == 1 + assert domain_names[0] == "Author_Engaging User" + else: + if len(cat_names) == 1 and kind == "joint": + # Columns are encoded jointly, so the domain names are the same + assert len(set(domain_names)) == 1 + else: + # Columns are encoded independently, so the domain names are different + assert len(set(domain_names)) > 1