Clean up the way shapes are computed and specified #1760
Changes from all commits: 29b4790, 1fd0f6f, 8b4ed58, 00f1820, 60367e8, 37bf594, ff01a15, 4b0e26b, f6efae2, 4d63c85, acc9256, 0aed473
*Diff in the dataloader's `_augment_schema` helper:*

```diff
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 
+from merlin.dtypes.shape import Shape
 from merlin.schema import ColumnSchema, Tags
 
 
@@ -22,9 +23,10 @@ def _augment_schema(
     cats=None,
     conts=None,
     labels=None,
-    sparse_names=None,
-    sparse_max=None,
-    sparse_as_dense=False,
+    padded_cols=None,
+    padded_lengths=None,
+    pad=False,
+    batch_size=0,
 ):
     labels = [labels] if isinstance(labels, str) else labels
     for label in labels or []:
@@ -34,21 +36,20 @@ def _augment_schema(
     for label in conts or []:
         schema[label] = schema[label].with_tags(Tags.CONTINUOUS)
 
-    # Set the appropriate properties for the sparse_names/sparse_max/sparse_as_dense
-    for col in sparse_names or []:
+    for col in padded_cols or []:
         cs = schema[col]
-        properties = cs.properties
-        if sparse_max and col in sparse_max:
-            properties["value_count"] = {"max": sparse_max[col]}
-        if sparse_as_dense:
-            properties["value_count"]["min"] = properties["value_count"]["max"]
+        dims = Shape(((1, batch_size), None))
+
+        if not cs.shape.dims[1].is_unknown:
+            dims = dims.with_dim(1, cs.shape.dims[1])
+
+        if pad:
+            dims = dims.with_dim_min(1, padded_lengths[col])
+        if padded_lengths and col in padded_lengths:
+            dims = dims.with_dim_max(1, padded_lengths[col])
 
         schema[col] = ColumnSchema(
-            name=cs.name,
-            tags=cs.tags,
-            dtype=cs.dtype,
-            is_list=True,
-            is_ragged=not sparse_as_dense,
-            properties=properties,
+            name=cs.name, tags=cs.tags, dtype=cs.dtype, properties=cs.properties, dims=dims
         )
 
     return schema
```

Review thread on `dims = Shape(((1, batch_size), None))`:

**Q:** What is the …

**A:** Not required for anything specific, just following the principle that the schema should always accurately reflect the data to the greatest extent possible. Here we have shape information since we know the batch size, so we fill it in in case that helps something downstream. I don't know if it actually will, but it seemed like the right thing to do.
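The padded-column shape logic in `_augment_schema` can be sketched without Merlin's `Shape` class. Below is a simplified stand-in (the function name `padded_dims` and the plain nested `(min, max)` tuple representation are mine, not part of the PR), assuming a batch dimension of `(1, batch_size)` and a row-length dimension that stays ragged unless padded:

```python
# Simplified stand-in for the Shape-building logic in _augment_schema.
# `padded_dims` and the nested (min, max) tuples are hypothetical; the
# real code builds a merlin.dtypes.shape.Shape object instead.

def padded_dims(batch_size, padded_lengths, col, pad=False):
    # Dim 0 is the batch dimension: between 1 and batch_size rows.
    dim0 = (1, batch_size)

    # Dim 1 is the per-row list length: unknown upper bound by default.
    min_count, max_count = 0, None
    if padded_lengths and col in padded_lengths:
        max_count = padded_lengths[col]
        if pad:
            # Padding to a fixed length makes the dim's min equal its max.
            min_count = padded_lengths[col]

    return (dim0, (min_count, max_count))

# A batch of up to 1024 rows, each padded to exactly 20 items:
print(padded_dims(1024, {"item_ids": 20}, "item_ids", pad=True))
# => ((1, 1024), (20, 20))
```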
*Diff at the `_augment_schema` call site, passing the batch size through:*

```diff
@@ -251,6 +251,7 @@ def __init__(
             sparse_names,
             sparse_max,
             sparse_as_dense,
+            batch_size,
         )
 
         super().__init__(
```
*Diff in the `Groupby` operator:*

```diff
@@ -18,6 +18,7 @@
 from dask.dataframe.utils import meta_nonempty
 
 from merlin.core.dispatch import DataFrameType, annotate
+from merlin.dtypes.shape import DefaultShapes
 from merlin.schema import Schema
 from nvtabular.ops.operator import ColumnSelector, Operator
 
@@ -186,10 +187,7 @@ def dependencies(self):
     def _compute_dtype(self, col_schema, input_schema):
         col_schema = super()._compute_dtype(col_schema, input_schema)
 
-        dtype = col_schema.dtype
-        is_list = col_schema.is_list
-
-        dtypes = {
+        agg_dtypes = {
             "count": numpy.int32,
             "nunique": numpy.int32,
             "mean": numpy.float32,
@@ -199,18 +197,26 @@ def _compute_dtype(self, col_schema, input_schema):
             "sum": numpy.float32,
         }
 
-        is_lists = {"list": True}
+        agg = self._find_agg(col_schema, input_schema)
+        dtype = agg_dtypes.get(agg, col_schema.dtype)
+
+        return col_schema.with_dtype(dtype)
+
+    def _compute_shape(self, col_schema, input_schema):
+        agg_is_lists = {"list": True}
+
+        agg = self._find_agg(col_schema, input_schema)
+        is_list = agg_is_lists.get(agg, col_schema.is_list)
 
-        for col_name in input_schema.column_names:
-            combined_aggs = _aggs_for_column(col_name, self.conv_aggs)
-            combined_aggs += _aggs_for_column(col_name, self.list_aggs)
-            for agg in combined_aggs:
-                if col_schema.name.endswith(f"{self.name_sep}{agg}"):
-                    dtype = dtypes.get(agg, dtype)
-                    is_list = is_lists.get(agg, is_list)
-                    break
+        shape = DefaultShapes.LIST if is_list else DefaultShapes.SCALAR
+        return col_schema.with_shape(shape)
 
-        return col_schema.with_dtype(dtype, is_list=is_list, is_ragged=is_list)
+    def _find_agg(self, col_schema, input_schema):
+        input_selector = ColumnSelector(input_schema.column_names)
+        column_mapping = self.column_mapping(input_selector)
+        input_column_name = column_mapping[col_schema.name][0]
+        agg = col_schema.name.replace(input_column_name, "").lstrip(self.name_sep)
+        return agg
 
 
 def _aggs_for_column(col_name, agg_dict):
```

Review thread on the `DefaultShapes` import:

**Q:** Is this going to be a new feature of dtypes in core?

**A:** Shapes are implemented as a subfield of dtypes in Core, but this isn't really intended to be a feature of dtypes in particular. We might want to hide that implementation detail a bit more thoroughly by adjusting the imports. As far as the defaults go, we just thought it was easier to read and understand than needing to remember which shapes mean what.

**Q:** The …

**A:** Expected, since there is (or was) an outstanding Core PR that adds it.

**A:** Looks like it's still open: NVIDIA-Merlin/core#215

**A:** Merged now.

Review thread on `is_list = agg_is_lists.get(agg, col_schema.is_list)`:

**Q:** In the case where we have fallback to the second argument of the …

**A:** I'm not actually sure. Are there …

**A:** From looking at the list of aggregations, I think everything changes the shape, either from list to scalar or from scalar to list.

**Q:** I wonder if the default there should actually be …

**A:** In practice it seems that it may not matter whether it's …

```python
from merlin.io import Dataset
import nvtabular as nvt
import cudf

df = cudf.DataFrame({"a": [1, 1, 2], "b": [[10], [20], [20]]})
workflow = nvt.Workflow(["a", "b"] >> nvt.ops.Groupby(groupby_cols=["a"], aggs=["sum"]))
workflow.fit_transform(Dataset(df)).compute()
# => Raises DataError: All requested aggregations are unsupported.
```

Some libraries do handle these aggs; Pandas, for example, handles `sum` across lists as concatenation:

```python
import pandas as pd

df = pd.DataFrame({"a": [1, 1, 2], "b": [[10], [20], [20]]})
df.groupby("a").sum()
# =>
#           b
# a
# 1  [10, 20]
# 2      [20]
```

or, if the values are numpy arrays, as an element-wise sum:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"a": [1, 1, 2], "b": [np.array([10]), np.array([20]), np.array([20])]})
df.groupby("a").sum()
# =>
#       b
# a
# 1  [30]
# 2  [20]
```

since …

```python
import cudf

df = cudf.DataFrame({"a": [1, 1, 2], "b": [[10], [20], [20]]})
df.groupby("a").sum()
# => Raises DataError: All requested aggregations are unsupported.

df["b"].sum()
# => Raises TypeError: cannot perform sum with type list
```

**A:** So I think … If we need groupby aggregations for list columns as a future feature of NVTabular, this will need to be revisited. I suppose even if cudf and pandas don't natively support this, we could implement it ourselves by extracting the cupy/numpy arrays from the series in our own agg function to handle list column aggregations.

**A:** Currently it seems that the only agg that is supported for list columns is …

```python
from merlin.io import Dataset
import nvtabular as nvt
import cudf

df = cudf.DataFrame({"a": [1, 1, 2], "b": [[10], [20], [20]]})
workflow = nvt.Workflow(["a", "b"] >> nvt.ops.Groupby(groupby_cols=["a"], aggs=["list"]))
workflow.fit_transform(Dataset(df)).compute()
# =>
#    a        b_list
# 0  1  [[10], [20]]
# 1  2        [[20]]
```

**A:** That's great to know; I really appreciate your thoroughness in testing this out. This probably warrants a further update to the shapes here; I'll open a separate issue.

**A:** Tracked here: #1763
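The `_find_agg` helper above recovers the aggregation name by stripping the input column name and the separator from the output column name. A standalone sketch of that approach (the free function `find_agg` is illustrative, not the PR's method):

```python
# Illustrative stand-in for Groupby._find_agg: derive the aggregation
# name from an output column name like "b_sum", given the input column
# ("b") it maps back to and the operator's name separator.

def find_agg(output_col_name, input_col_name, name_sep="_"):
    return output_col_name.replace(input_col_name, "").lstrip(name_sep)

print(find_agg("b_sum", "b"))   # => sum
print(find_agg("b_list", "b"))  # => list
print(find_agg("b", "b"))       # => (empty string: no aggregation suffix)
```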
*Diff in a list-slicing/padding operator (`pad`, `max_elements`):*

```diff
@@ -129,7 +129,7 @@ def transform(self, col_selector: ColumnSelector, df: DataFrameType) -> DataFrameType:
 
     def _compute_dtype(self, col_schema, input_schema):
         col_schema = super()._compute_dtype(col_schema, input_schema)
-        return col_schema.with_dtype(col_schema.dtype, is_list=True, is_ragged=not self.pad)
+        return col_schema.with_dtype(col_schema.dtype)
 
     def _compute_properties(self, col_schema, input_schema):
         col_schema = super()._compute_properties(col_schema, input_schema)
@@ -140,6 +140,17 @@ def _compute_properties(self, col_schema, input_schema):
             properties["value_count"]["min"] = self.max_elements
         return col_schema.with_properties(properties)
 
+    def _compute_shape(self, col_schema, input_schema):
+        col_schema = super()._compute_shape(col_schema, input_schema)
+
+        min_count, max_count = (0, None)
+        if self.max_elements != np.iinfo(np.int64).max:
+            max_count = self.max_elements
+        if self.pad:
+            min_count = self.max_elements
+
+        return col_schema.with_shape((None, (min_count, max_count)))
+
     @property
     def output_tags(self):
         return [Tags.LIST]
```

Review thread on `_compute_dtype`:

**Q:** Is this method required to be overridden with this change, or could it be removed?

**A:** It could be removed, but that breaks some of the tests even though the functionality works without it. (There are actually two places like this.)
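The bounds logic in the new `_compute_shape` can be exercised in isolation. A minimal sketch (the helper name `slice_counts` is mine; the real method wraps the result in `col_schema.with_shape`):

```python
import numpy as np

# Stand-in for the (min_count, max_count) computation in _compute_shape:
# max_elements == int64 max is treated as "no upper bound", and padding
# forces every row to exactly max_elements items (min == max).

def slice_counts(max_elements, pad=False):
    min_count, max_count = 0, None
    if max_elements != np.iinfo(np.int64).max:
        max_count = max_elements
    if pad:
        min_count = max_elements
    return (min_count, max_count)

print(slice_counts(10))                      # => (0, 10)
print(slice_counts(10, pad=True))            # => (10, 10)
print(slice_counts(np.iinfo(np.int64).max))  # => (0, None)
```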
Review thread on the `_augment_schema` argument renames:

**Q:** Renaming these arguments could be out of scope for this PR, since it may be a breaking change for something that uses this function. It may be clearer to separate this into a different PR.

**A:** We couldn't really figure out what this function was supposed to do without the renames, so we went for it. The function name starts with an `_`, so we've appropriately signaled that external code shouldn't depend on its stability. The two places in NVTabular that use it call it via argument order instead of specifying the names, so I understand the caution, but I think we're okay.

**Q:** Looks like it will be a safe change, and the names make things clearer :). We may consider removing this function along with the other loader code to follow up on the promise here in one of the upcoming releases.

**A:** And by that point we'll hopefully have updated the dataloader API to a point where this `_augment_schema` function is no longer required, either here or in the copies in Transformers4Rec and Merlin Models. The existence of this function suggests to me that there's something missing from the dataloader API as it currently exists.

**A:** I think the something that's missing is "transforms implemented as operators that provide schema tracking" 😺
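To make that last comment concrete, here is a toy (non-Merlin) illustration of a transform implemented as an operator that provides schema tracking: the operator can describe its output schema without touching any data. All names here are hypothetical and the "schema" is just a dict:

```python
# Toy illustration only: an operator pairs a data transform with a
# schema transform, so output schemas are computable without running
# the data through the pipeline.

class PadListOp:
    def __init__(self, length):
        self.length = length

    def transform(self, column):
        # Pad or truncate every list to a fixed length.
        return [(row + [0] * self.length)[: self.length] for row in column]

    def compute_output_schema(self, input_schema):
        # Schema tracking: output is fixed-length regardless of input
        # raggedness, so min == max in the value count.
        return {**input_schema, "value_count": {"min": self.length, "max": self.length}}

op = PadListOp(3)
print(op.transform([[1], [1, 2, 3, 4]]))
# => [[1, 0, 0], [1, 2, 3]]
print(op.compute_output_schema({"dtype": "int64"}))
# => {'dtype': 'int64', 'value_count': {'min': 3, 'max': 3}}
```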