Refactoring part 1: flags modification #543

Merged: 12 commits, Nov 22, 2022
@@ -259,7 +259,7 @@
")\n",
"\n",
"# Define Next item prediction-task \n",
"prediction_task = tr.NextItemPredictionTask(hf_format=True, weight_tying=True)\n",
"prediction_task = tr.NextItemPredictionTask(weight_tying=True)\n",
"\n",
"# Define the config of the XLNet Transformer architecture\n",
"transformer_config = tr.XLNetConfig.build(\n",
@@ -150,7 +150,7 @@
")\n",
"\n",
"# Define Next item prediction-task \n",
"prediction_task = tr.NextItemPredictionTask(hf_format=True, weight_tying=True)\n",
"prediction_task = tr.NextItemPredictionTask(weight_tying=True)\n",
"\n",
"# Define the config of the XLNet Transformer architecture\n",
"transformer_config = tr.XLNetConfig.build(\n",
@@ -306,7 +306,7 @@
"# Define a head related to next item prediction task \n",
"head = tr.Head(\n",
" body,\n",
" tr.NextItemPredictionTask(weight_tying=True, hf_format=True, \n",
" tr.NextItemPredictionTask(weight_tying=True, \n",
" metrics=metrics),\n",
" inputs=inputs,\n",
")\n",
@@ -137,7 +137,6 @@ def main():

# Configures the next-item prediction-task
prediction_task = t4r.NextItemPredictionTask(
hf_format=True,
weight_tying=model_args.mf_constrained_embeddings,
softmax_temperature=model_args.softmax_temperature,
metrics=metrics,
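Taken together, the example and script updates above simply drop the `hf_format` argument when constructing `NextItemPredictionTask`. A minimal sketch of the updated setup, assuming the same XLNet example used in the notebooks (the dimension values are illustrative and schema loading is not shown):

```python
import transformers4rec.torch as tr


def build_model(schema):
    # Sequential input features built from a dataset schema (loading the schema is assumed elsewhere)
    inputs = tr.TabularSequenceFeatures.from_schema(
        schema,
        max_sequence_length=20,
        d_output=64,
        masking="causal",
    )

    # hf_format is gone: the task is configured with weight_tying (and optionally metrics) only
    prediction_task = tr.NextItemPredictionTask(weight_tying=True)

    # Same XLNet config/build call as in the notebooks above
    transformer_config = tr.XLNetConfig.build(
        d_model=64, n_head=8, n_layer=2, total_seq_length=20
    )
    return transformer_config.to_torch_model(inputs, prediction_task)
```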
10 changes: 5 additions & 5 deletions examples/tutorial/03-Session-based-recsys.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/torch/conftest.py
@@ -151,7 +151,7 @@ def torch_yoochoose_next_item_prediction_model(torch_yoochoose_tabular_transform
tr.MLPBlock([64]),
tr.TransformerBlock(transformer_config, masking=inputs.masking),
)
model = tr.NextItemPredictionTask(weight_tying=True, hf_format=True).to_model(body, inputs)
model = tr.NextItemPredictionTask(weight_tying=True).to_model(body, inputs)
return model


10 changes: 2 additions & 8 deletions tests/torch/features/test_sequential.py
@@ -119,16 +119,10 @@ def test_sequential_tabular_features_ignore_masking(yoochoose_schema, torch_yooc
input_module._masking = CausalLanguageModeling(hidden_size=100)

output_ignore_masking = (
input_module(torch_yoochoose_like, training=False, ignore_masking=True)
.detach()
.cpu()
.numpy()
input_module(torch_yoochoose_like, training=False, testing=False).detach().cpu().numpy()
)
output_masking = (
input_module(torch_yoochoose_like, training=False, ignore_masking=False)
.detach()
.cpu()
.numpy()
input_module(torch_yoochoose_like, training=False, testing=True).detach().cpu().numpy()
)

assert np.allclose(output_wo_masking, output_ignore_masking, rtol=1e-04, atol=1e-08)
17 changes: 9 additions & 8 deletions tests/torch/model/test_head.py
@@ -134,7 +134,7 @@ def test_item_prediction_loss_and_metrics(
body = tr.SequentialBlock(input_module, tr.MLPBlock([64]))
head = tr.Head(body, tr.NextItemPredictionTask(weight_tying=weight_tying), inputs=input_module)

body_outputs = body(torch_yoochoose_like, ignore_masking=False)
body_outputs = body(torch_yoochoose_like, testing=True)

trg_flat = input_module.masking.masked_targets.flatten()
non_pad_mask = trg_flat != input_module.masking.padding_idx
@@ -143,10 +143,11 @@
loss = head.prediction_task_dict["next-item"].compute_loss(
inputs=body_outputs,
targets=labels_all,
testing=True,
)

metrics = head.prediction_task_dict["next-item"].calculate_metrics(
predictions=body_outputs, targets=labels_all
predictions=body_outputs, targets=labels_all, testing=True
)
assert all(len(m) == 2 for m in metrics.values())
assert loss != 0
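
With this change, `compute_loss` and `calculate_metrics` take an explicit `testing` flag instead of relying on `hf_format`/`ignore_masking`. A hedged sketch of evaluating a head outside the Trainer, reusing `body_outputs` and `labels_all` exactly as prepared in the test above:

```python
task = head.prediction_task_dict["next-item"]

# testing=True makes the task use the masked targets prepared by the masking module
loss = task.compute_loss(inputs=body_outputs, targets=labels_all, testing=True)
metrics = task.calculate_metrics(predictions=body_outputs, targets=labels_all, testing=True)
```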
@@ -161,11 +162,11 @@ def test_item_prediction_HF_output(
body = tr.SequentialBlock(input_module, tr.MLPBlock([64]))
head = tr.Head(
body,
tr.NextItemPredictionTask(weight_tying=True, hf_format=True),
tr.NextItemPredictionTask(weight_tying=True),
inputs=input_module,
)

outputs = head(body(torch_yoochoose_like))
outputs = head(body(torch_yoochoose_like, training=True), training=True)

assert isinstance(outputs, dict)
assert [
@@ -209,11 +210,11 @@ def test_item_prediction_head_with_input_size(
)
head = tr.Head(
body,
tr.NextItemPredictionTask(weight_tying=True, hf_format=True),
tr.NextItemPredictionTask(weight_tying=True),
inputs=input_module,
)

outputs = head(body(torch_yoochoose_like))
outputs = head(body(torch_yoochoose_like, training=True), training=True)

assert outputs

@@ -231,11 +232,11 @@ def test_item_prediction_with_rnn(
)
head = tr.Head(
body,
tr.NextItemPredictionTask(weight_tying=True, hf_format=True),
tr.NextItemPredictionTask(weight_tying=True),
inputs=input_module,
)

outputs = head(body(torch_yoochoose_like))
outputs = head(body(torch_yoochoose_like, training=True), training=True)

assert isinstance(outputs, dict)
assert list(outputs.keys()) == [
30 changes: 16 additions & 14 deletions tests/torch/model/test_model.py
@@ -70,13 +70,13 @@ def test_sequential_prediction_model(

head_1 = tr.Head(
body,
tr.NextItemPredictionTask(weight_tying=True, hf_format=True),
tr.NextItemPredictionTask(weight_tying=True),
inputs=inputs,
)
head_2 = task("target", summary_type="mean").to_head(body, inputs)

model = tr.Model(head_1, head_2)
output = model(torch_yoochoose_like)
output = model(torch_yoochoose_like, training=True)

assert isinstance(output, dict)
assert len(list(output.keys())) == 2
@@ -248,7 +248,7 @@ def test_set_model_to_device(
assert next(model.parameters()).device.type == device

inputs = {k: v.to(device) for k, v in torch_yoochoose_like.items()}
assert model(inputs)
assert model(inputs, training=True)


@pytest.mark.parametrize("masking", ["causal", "mlm", "plm", "rtd"])
@@ -261,11 +261,15 @@ def test_eval_metrics_with_masking(torch_yoochoose_like, yoochoose_schema, maski
d_output=64,
masking=masking,
)
task = tr.NextItemPredictionTask(hf_format=True)
task = tr.NextItemPredictionTask()
model = transformer_config.to_torch_model(input_module, task)
out = model(torch_yoochoose_like)
out = model(torch_yoochoose_like, training=True)
result = model.calculate_metrics(
inputs=out["predictions"], targets=out["labels"], call_body=False, forward=False
inputs=out["predictions"],
targets=out["labels"],
call_body=False,
forward=False,
testing=True,
)
assert result is not None

@@ -280,9 +284,9 @@ def test_with_d_model_different_from_item_dim(torch_yoochoose_like, yoochoose_sc
d_output=d_model,
masking="mlm",
)
task = tr.NextItemPredictionTask(hf_format=True, weight_tying=True)
task = tr.NextItemPredictionTask(weight_tying=True)
model = transformer_config.to_torch_model(input_module, task)
assert model(torch_yoochoose_like)
assert model(torch_yoochoose_like, training=True)


@pytest.mark.parametrize("masking", ["causal", "mlm", "plm", "rtd"])
@@ -293,13 +297,13 @@ def test_output_shape_mode_eval(torch_yoochoose_like, yoochoose_schema, masking)
d_output=64,
masking=masking,
)
prediction_task = tr.NextItemPredictionTask(hf_format=True, weight_tying=True)
prediction_task = tr.NextItemPredictionTask(weight_tying=True)
transformer_config = tconf.XLNetConfig.build(
d_model=64, n_head=8, n_layer=2, total_seq_length=20
)
model = transformer_config.to_torch_model(input_module, prediction_task)

out = model(torch_yoochoose_like, training=False)
out = model(torch_yoochoose_like, training=False, testing=True)
assert out["predictions"].shape[0] == torch_yoochoose_like["item_id/list"].size(0)
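
Read together, these model tests suggest the output format is now driven by the `training`/`testing` flags rather than an `hf_format` attribute. A hedged summary of the calling conventions exercised in this diff, with `model` and `batch` standing in for the test fixtures:

```python
out_train = model(batch, training=True)                 # dict, e.g. out_train["predictions"], out_train["labels"]
out_eval = model(batch, training=False, testing=True)   # dict as well, used for loss/metric computation
out_infer = model(batch, training=False)                # plain torch.Tensor of predictions (per the save/load and torchscript tests below)
```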


@@ -308,18 +312,16 @@ def test_save_next_item_prediction_model(
):
inputs = torch_yoochoose_tabular_transformer_features
transformer_config = tconf.XLNetConfig.build(100, 4, 2, 20)
task = tr.NextItemPredictionTask(hf_format=True, weight_tying=True)
task = tr.NextItemPredictionTask(weight_tying=True)
model = transformer_config.to_torch_model(inputs, task)
output = model(torch_yoochoose_like, training=False)
output = model(torch_yoochoose_like, training=False, testing=True)
assert isinstance(output, dict)

with tempfile.TemporaryDirectory() as tmpdir:
model.save(tmpdir)
assert "t4rec_model_class.pkl" in os.listdir(tmpdir)
loaded_model = model.load(tmpdir)
# deactivate the hf_format model to get the tensor of predictions as output
# instead of a dictionary of three tensors `loss, labels, and predictions`
loaded_model.hf_format = False

output = loaded_model(torch_yoochoose_like, training=False)
assert isinstance(output, torch.Tensor)
8 changes: 5 additions & 3 deletions tests/torch/test_losses.py
@@ -19,12 +19,12 @@ def test_item_prediction_with_label_smoothing_ce_loss(
body, tr.NextItemPredictionTask(weight_tying=True, loss=custom_loss), inputs=input_module
)

body_outputs = body(torch_yoochoose_like, ignore_masking=False)
body_outputs = body(torch_yoochoose_like, testing=True)

trg_flat = input_module.masking.masked_targets.flatten()
non_pad_mask = trg_flat != input_module.masking.padding_idx
labels_all = pytorch.masked_select(trg_flat, non_pad_mask)
predictions = head(body_outputs, ignore_masking=False)
head_output = head(body_outputs, testing=True)

loss = head.prediction_task_dict["next-item"].compute_loss(
inputs=body_outputs,
Expand All @@ -34,6 +34,8 @@ def test_item_prediction_with_label_smoothing_ce_loss(
n_classes = 51997
manuall_loss = pytorch.nn.NLLLoss(reduction="mean")
target_with_smoothing = labels_all * (1 - label_smoothing) + label_smoothing / n_classes
manual_output_loss = manuall_loss(predictions, target_with_smoothing.to(pytorch.long))
manual_output_loss = manuall_loss(
head_output["predictions"], target_with_smoothing.to(pytorch.long)
)

assert np.allclose(manual_output_loss.detach().numpy(), loss.detach().numpy(), rtol=1e-3)
6 changes: 3 additions & 3 deletions tests/torch/test_torchscript.py
@@ -14,7 +14,7 @@ def test_torchsciprt_not_strict(torch_yoochoose_like, yoochoose_schema):
d_output=64,
masking="causal",
)
prediction_task = tr.NextItemPredictionTask(hf_format=True, weight_tying=True)
prediction_task = tr.NextItemPredictionTask(weight_tying=True)
transformer_config = tconf.XLNetConfig.build(
d_model=64, n_head=8, n_layer=2, total_seq_length=20
)
@@ -27,6 +27,6 @@ def test_torchsciprt_not_strict(torch_yoochoose_like, yoochoose_schema):
traced_model = pytorch.jit.trace(model, torch_yoochoose_like, strict=False)
assert isinstance(traced_model, pytorch.jit.TopLevelTracedModule)
assert pytorch.allclose(
model(torch_yoochoose_like)["predictions"],
traced_model(torch_yoochoose_like)["predictions"],
model(torch_yoochoose_like),
traced_model(torch_yoochoose_like),
)
6 changes: 3 additions & 3 deletions transformers4rec/torch/block/base.py
@@ -139,7 +139,7 @@ def __rshift__(self, other):
# pylint: disable=arguments-out-of-order
return right_shift_block(other, self)

def forward(self, input, training=True, ignore_masking=True, **kwargs):
def forward(self, input, training=False, testing=False, **kwargs):
# from transformers4rec.torch import TabularSequenceFeatures

for i, module in enumerate(self):
@@ -148,8 +148,8 @@ def forward(self, input, training=True, ignore_masking=True, **kwargs):
input = module(input, **filtered_kwargs)

elif "training" in inspect.signature(module.forward).parameters:
if "ignore_masking" in inspect.signature(module.forward).parameters:
input = module(input, training=training, ignore_masking=ignore_masking)
if "testing" in inspect.signature(module.forward).parameters:
input = module(input, training=training, testing=testing)
else:
input = module(input, training=training)
else:
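`SequentialBlock.forward` above now defaults both flags to `False` and forwards `testing` only to children that declare it in their signature. A hedged sketch of a custom block that opts in to both flags (the class name and body are illustrative):

```python
import torch


class MyCustomBlock(torch.nn.Module):
    def forward(self, inputs, training=False, testing=False):
        # Because both flags appear in the signature, SequentialBlock passes them through.
        # A module declaring only `training` would receive just that flag,
        # and a module declaring neither is called as module(inputs).
        return inputs
```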
6 changes: 2 additions & 4 deletions transformers4rec/torch/experimental.py
@@ -78,10 +78,8 @@ def __init__(
elif fusion_aggregation == "concat":
self.last_dim = hidden_dim + post_context_last_dim

def forward(self, inputs, training=False, ignore_masking=True, **kwargs):
seq_rep = self.sequential_module(
inputs, training=training, ignore_masking=ignore_masking, **kwargs
)
def forward(self, inputs, training=False, testing=False, **kwargs):
seq_rep = self.sequential_module(inputs, training=training, testing=testing, **kwargs)
context_rep = self.post_context_module(inputs, training=training)

if len(context_rep.size()) == 2:
4 changes: 2 additions & 2 deletions transformers4rec/torch/features/sequence.py
@@ -247,7 +247,7 @@ def item_embedding_table(self) -> Optional[torch.nn.Module]:

return None

def forward(self, inputs, training=True, ignore_masking=True, **kwargs):
def forward(self, inputs, training=False, testing=False, **kwargs):
outputs = super(TabularSequenceFeatures, self).forward(inputs)

if self.masking or self.projection_module:
@@ -256,7 +256,7 @@ def forward(self, inputs, training=True, ignore_masking=True, **kwargs):
if self.projection_module:
outputs = self.projection_module(outputs)

if self.masking and (not ignore_masking or training):
if self.masking and (training or testing):
outputs = self.masking(
outputs, item_ids=self.to_merge["categorical_module"].item_seq, training=training
)
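
The condition above replaces `not ignore_masking or training` with `training or testing`, so masking (and target preparation) now runs for both training and evaluation and is skipped for plain inference. A hedged illustration, assuming `input_module` is a `TabularSequenceFeatures` instance and `batch` a dictionary of input tensors:

```python
# Masking applied: targets are prepared for loss/metric computation
train_out = input_module(batch, training=True)
eval_out = input_module(batch, training=False, testing=True)

# Masking skipped: raw sequence representations, e.g. for serving
infer_out = input_module(batch, training=False, testing=False)
```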
4 changes: 3 additions & 1 deletion transformers4rec/torch/masking.py
@@ -195,7 +195,9 @@ def predict_all(self, item_ids: torch.Tensor) -> MaskingInfo:

return MaskingInfo(mask_labels, labels)

def forward(self, inputs: torch.Tensor, item_ids: torch.Tensor, training=False) -> torch.Tensor:
def forward(
self, inputs: torch.Tensor, item_ids: torch.Tensor, training=False, testing=False
) -> torch.Tensor:
_ = self.compute_masked_targets(item_ids=item_ids, training=training)
if self.mask_schema is None:
raise ValueError("`mask_schema must be set.`")
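
The masking block's `forward` now accepts a `testing` flag alongside `training`, matching the convention introduced elsewhere in this PR; in the visible hunk, target computation is still driven by `training` alone. A hedged sketch of a direct call, where `masking`, `seq_embeddings`, and `item_ids` are illustrative names:

```python
# The testing flag is accepted by the new signature; masked targets are computed from `training`
masked_embeddings = masking(seq_embeddings, item_ids=item_ids, training=False, testing=True)
```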