Skip to content

Commit

Permalink
SoftmaxSampling should produce the same output dtype as the input (#154)
Browse files Browse the repository at this point in the history
* softmax sampling should keep input col dtype

* rename old variable
  • Loading branch information
nv-alaiacano authored Jul 29, 2022
1 parent f1e2bf5 commit 572dbf1
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 4 deletions.
17 changes: 13 additions & 4 deletions merlin/systems/dag/ops/softmax_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,16 @@ def compute_output_schema(
self, input_schema: Schema, col_selector: ColumnSelector, prev_output_schema: Schema = None
) -> Schema:
"""Describe the operator's outputs"""
return Schema([ColumnSchema("ordered_ids", dtype=np.int32, is_list=True, is_ragged=True)])
return Schema(
[
ColumnSchema(
"ordered_ids",
dtype=input_schema.get(self._input_col_name).dtype,
is_list=True,
is_ragged=True,
)
]
)

def transform(self, df: InferenceDataFrame) -> InferenceDataFrame:
"""Transform the dataframe by applying this operator to the set of input columns"""
Expand Down Expand Up @@ -121,7 +130,7 @@ def transform(self, df: InferenceDataFrame) -> InferenceDataFrame:

# This is just bookkeeping to produce the final ordered list of recs
sorted_indices = np.argsort(exponentials)
topk_movie_ids = candidate_ids[sorted_indices][: self.topk]
ordered_movie_ids = topk_movie_ids.reshape(1, -1).T
topk_item_ids = candidate_ids[sorted_indices][: self.topk]
ordered_item_ids = topk_item_ids.reshape(1, -1).T

return InferenceDataFrame({"ordered_ids": ordered_movie_ids})
return InferenceDataFrame({"ordered_ids": ordered_item_ids})
Empty file.
29 changes: 29 additions & 0 deletions tests/unit/systems/dag/ops/test_softmax_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pandas as pd

from merlin.schema import ColumnSchema, Schema
from merlin.systems.dag.ops.softmax_sampling import SoftmaxSampling
from nvtabular import ColumnSelector


def test_softmax_output_dtype_keeps_input_dtype():
    """Check the schema SoftmaxSampling reports for its output column.

    Given an input schema holding ``rel_col`` and a non-list ``input_col``,
    ``compute_output_schema`` is expected to return a single column that:

    * is named ``ordered_ids``
    * has ``is_list`` set to True
    * has ``is_ragged`` set to True
    * keeps the dtype that ``input_col`` had in the input schema
    """
    operator = SoftmaxSampling("rel_col", _input_col="input_col")

    relevance_column = ColumnSchema(name="rel_col")
    # Scalar, non-ragged on the way in; only the dtype should carry over.
    id_column = ColumnSchema(
        name="input_col", dtype=pd.StringDtype, is_list=False, is_ragged=False
    )
    schema_in = Schema([relevance_column, id_column])

    wanted_column = ColumnSchema(
        name="ordered_ids", dtype=pd.StringDtype, is_list=True, is_ragged=True
    )
    wanted_schema = Schema([wanted_column])

    selector = ColumnSelector(["ordered_ids"])
    produced_schema = operator.compute_output_schema(schema_in, selector)

    assert produced_schema == wanted_schema

0 comments on commit 572dbf1

Please sign in to comment.