T4rec refactor(Part 1) #528

Closed
wants to merge 48 commits into from
Changes from 42 commits
Commits
48 commits
623a2bc
add save/load T4Rec model with merlin schema
sararb Oct 21, 2022
ad37cb1
add support of list outputs
sararb Oct 21, 2022
6dea3fe
add shape property and fix pr comment
sararb Oct 26, 2022
e6f6e58
update shape property with the convention used in systems
sararb Nov 1, 2022
3fb91d2
remove max_sequence_length from in/out schema methods
sararb Nov 1, 2022
c76b416
fix PR comments
sararb Nov 2, 2022
0aaca10
changed outputs and flags
Nov 4, 2022
dc8f254
debugged training
Nov 6, 2022
4fc2c68
eval runs okay
Nov 6, 2022
fa12ad0
separating T4Rec refactoring into 2 parts
Nov 7, 2022
a8a732e
made changes to accommodate dict as metric output
Nov 8, 2022
ed40331
all calculate_metrics are okay now
Nov 9, 2022
7984b79
fixed head and model output when 1 task
Nov 10, 2022
5fb0cd4
No need to pass hf_format flag anymore
Nov 11, 2022
18f3a42
fixed flake8 errors
Nov 11, 2022
f30cfc2
update ignore_masking in sequence.py
nzarif Nov 14, 2022
25fcfa9
call loss instead of forward in PredictionTask.forward
nzarif Nov 14, 2022
83e2b95
don't set default value of testing and training for prediction_step
nzarif Nov 14, 2022
a2d2318
add testing to Trainer model call
nzarif Nov 14, 2022
2e2508d
add save/load T4Rec model with merlin schema
sararb Oct 21, 2022
9e2bc89
add support of list outputs
sararb Oct 21, 2022
161a162
add shape property and fix pr comment
sararb Oct 26, 2022
1337e94
update shape property with the convention used in systems
sararb Nov 1, 2022
1d662b7
remove max_sequence_length from in/out schema methods
sararb Nov 1, 2022
9b3294d
fix PR comments
sararb Nov 2, 2022
796ff3b
changed outputs and flags
Nov 4, 2022
90628f2
debugged training
Nov 6, 2022
4eb6e9f
eval runs okay
Nov 6, 2022
3586563
separating T4Rec refactoring into 2 parts
Nov 7, 2022
a4f822a
made changes to accommodate dict as metric output
Nov 8, 2022
77ac385
all calculate_metrics are okay now
Nov 9, 2022
c89d1d4
fixed head and model output when 1 task
Nov 10, 2022
a21b8c5
No need to pass hf_format flag anymore
Nov 11, 2022
bfd58bf
fixed flake8 errors
Nov 11, 2022
9b11898
update ignore_masking in sequence.py
nzarif Nov 14, 2022
0ecdae0
call loss instead of forward in PredictionTask.forward
nzarif Nov 14, 2022
2ef23da
don't set default value of testing and training for prediction_step
nzarif Nov 14, 2022
dea6115
add testing to Trainer model call
nzarif Nov 14, 2022
be133d4
Merge branch 't4rec_refactor' of github.com:NVIDIA-Merlin/Transformer…
Nov 14, 2022
fbc8685
addressed Sara's comments
Nov 14, 2022
161004d
Merge branch 'main' into t4rec_refactor
sararb Nov 14, 2022
5fe81db
updated labels when calling calculate_metrics
Nov 14, 2022
4217a74
minor changes in response to Sara comments
Nov 15, 2022
c63461c
checked with flake8 and black
Nov 15, 2022
e9ec023
10 out of 27 unit tests are okay
Nov 16, 2022
8e1e681
working on 1st test
Nov 16, 2022
f30395f
could not fix first test
Nov 16, 2022
c3e2df9
drafting for accessing later
Nov 16, 2022
@@ -137,7 +137,6 @@ def main():

# Configures the next-item prediction-task
prediction_task = t4r.NextItemPredictionTask(
hf_format=True,
weight_tying=model_args.mf_constrained_embeddings,
softmax_temperature=model_args.softmax_temperature,
metrics=metrics,
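The hunk above drops `hf_format=True` from the example script, since the refactored `NextItemPredictionTask` no longer accepts that flag. A minimal sketch of the updated construction, with metric choices and argument values assumed for illustration rather than taken from this PR:

import transformers4rec.torch as t4r
from transformers4rec.torch.ranking_metric import NDCGAt, RecallAt

# Assumed metrics for a next-item prediction setup.
metrics = [NDCGAt(top_ks=[10, 20], labels_onehot=True),
           RecallAt(top_ks=[10, 20], labels_onehot=True)]

# hf_format is gone: the output format is now driven by the training/testing
# flags passed when the model is called, not by a constructor argument.
prediction_task = t4r.NextItemPredictionTask(
    weight_tying=True,         # stands in for model_args.mf_constrained_embeddings
    softmax_temperature=0.35,  # stands in for model_args.softmax_temperature
    metrics=metrics,
)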
2 changes: 1 addition & 1 deletion transformers4rec/torch/features/sequence.py
@@ -256,7 +256,7 @@ def forward(self, inputs, training=True, ignore_masking=True, **kwargs):
if self.projection_module:
outputs = self.projection_module(outputs)

if self.masking and (not ignore_masking or training):
if self.masking and (testing or training):
outputs = self.masking(
outputs, item_ids=self.to_merge["categorical_module"].item_seq, training=training
)
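The change above gates masking on the new `testing`/`training` flags instead of `ignore_masking`: masking is applied while training or evaluating and skipped at inference when both flags are False. A standalone sketch of that semantics (illustrative names, not the library code):

def apply_masking(outputs, masking_module, item_ids, training=False, testing=False):
    # Mask sequence embeddings only in training/testing mode.
    if masking_module and (testing or training):
        outputs = masking_module(outputs, item_ids=item_ids, training=training)
    return outputs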
224 changes: 108 additions & 116 deletions transformers4rec/torch/model/base.py
@@ -145,7 +145,15 @@ def build(
metric.to(device)
self.built = True

def forward(self, inputs, **kwargs):
def forward(self, inputs, targets=None, training=False, testing=False, **kwargs):
if isinstance(targets, dict):
if self.target_name:
targets = targets[self.target_name]
else:
raise ValueError(
"target_name can not be None, specify a name."
)

x = inputs

if len(x.size()) == 3 and self.summary_type:
@@ -157,6 +165,12 @@ def forward(self, inputs, **kwargs):
if self.pre:
x = self.pre(x)

if training or testing:
# add support of computing the loss inside the forward
# and return a dictionary as standard output
loss = self.loss(x, targets)
return {"loss": loss, "labels": targets, "predictions": x}

return x

@property
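With this hunk a `PredictionTask` computes its loss inside `forward` whenever `training` or `testing` is set, and returns a dictionary of loss, labels, and predictions instead of a bare tensor. A hedged usage sketch, where `task`, `body_outputs`, and `targets` are placeholders:

import torch

body_outputs = torch.randn(16, 64)            # assumed output of the model body
targets = torch.randint(0, 2, (16,)).float()  # assumed binary targets

# Training/evaluation: the loss is computed inside forward and a dict comes back.
out = task(body_outputs, targets=targets, training=True)
loss, predictions, labels = out["loss"], out["predictions"], out["labels"]

# Inference: both flags default to False, so only the prediction scores are returned.
scores = task(body_outputs)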
@@ -174,45 +188,29 @@ def child_name(self, name):
def set_metrics(self, metrics):
self.metrics = torch.nn.ModuleList(metrics)

def compute_loss(
self,
inputs: Union[torch.Tensor, TabularData],
targets: Union[torch.Tensor, TabularData],
compute_metrics: bool = True,
training: bool = False,
ignore_masking: bool = False,
**kwargs,
) -> torch.Tensor:
if isinstance(targets, dict) and self.target_name:
targets = targets[self.target_name]

predictions = self(inputs, training=training, ignore_masking=ignore_masking)
loss = self.loss(predictions, targets)

if compute_metrics:
self.calculate_metrics(
predictions, targets, mode="train", ignore_masking=ignore_masking, forward=False
)

return loss

return loss

def calculate_metrics( # type: ignore
self,
predictions: Union[torch.Tensor, TabularData],
targets: Union[torch.Tensor, TabularData],
mode: str = "val",
forward: bool = True,
ignore_masking: bool = False,
training: bool = False,
testing: bool = True,
**kwargs,
) -> Dict[str, torch.Tensor]:
if isinstance(targets, dict) and self.target_name:
targets = targets[self.target_name]
if isinstance(targets, dict):
if self.target_name:
targets = targets[self.target_name]
else:
raise ValueError(
"target_name can not be None, specify a name."
)

outputs = {}
if forward:
predictions = self(predictions, ignore_masking=ignore_masking)
fwd_output = self(predictions, training=training, testing=testing)
predictions = fwd_output["predictions"]
targets = fwd_output["labels"]
predictions = self.forward_to_prediction_fn(cast(torch.Tensor, predictions))

from .prediction_task import BinaryClassificationTask
@@ -386,70 +384,59 @@ def pop_labels(self, inputs: TabularData) -> TabularData:
def forward(
self,
body_outputs: Union[torch.Tensor, TabularData],
training: bool = True,
targets: Union[torch.Tensor, TabularData] = None,
training: bool = False,
testing: bool = False,
call_body: bool = False,
always_output_dict: bool = False,
ignore_masking: bool = True,
**kwargs,
) -> Union[torch.Tensor, TabularData]:
outputs = {}

if call_body:
body_outputs = self.body(body_outputs, training=training, ignore_masking=ignore_masking)

for name, task in self.prediction_task_dict.items():
outputs[name] = task(
body_outputs, ignore_masking=ignore_masking, training=training, **kwargs
body_outputs = self.body(
body_outputs, training=training, testing=testing, targets=targets
)

if len(outputs) == 1 and not always_output_dict:
return outputs[list(outputs.keys())[0]]
if training or testing:
losses = []
labels = {}
predictions = {}
for name, task in self.prediction_task_dict.items():
task_output = task(
body_outputs, training=training, testing=testing, targets=targets, **kwargs
)
labels[name] = task_output["labels"]
predictions[name] = task_output["predictions"]
losses.append(task_output["loss"] * self._task_weights[name])
loss_tensor = torch.stack(losses)
loss = getattr(loss_tensor, self.loss_reduction)()
""" if len(labels)==1:
labels=list(labels.values())[0]
predictions=list(predictions.values())[0] """
outputs = {"loss": loss, "labels": labels, "predictions": predictions}
else:
for name, task in self.prediction_task_dict.items():
outputs[name] = task(
body_outputs, training=training, testing=testing, targets=targets, **kwargs
)

return outputs

def compute_loss( # type: ignore
self,
body_outputs: Union[torch.Tensor, TabularData],
targets: Union[torch.Tensor, TabularData],
training: bool = True,
compute_metrics: bool = True,
call_body: bool = False,
ignore_masking: bool = True,
**kwargs,
) -> torch.Tensor:
losses = []

if call_body:
body_outputs = self.body(body_outputs, training=training, ignore_masking=ignore_masking)

for name, task in self.prediction_task_dict.items():
loss = task.compute_loss(
body_outputs,
targets,
compute_metrics=compute_metrics,
ignore_masking=ignore_masking,
**kwargs,
)
losses.append(loss * self._task_weights[name])

loss_tensor = torch.stack(losses)

return getattr(loss_tensor, self.loss_reduction)()

def calculate_metrics( # type: ignore
self,
body_outputs: Union[torch.Tensor, TabularData],
targets: Union[torch.Tensor, TabularData],
mode: str = "val",
forward=True,
call_body=False,
ignore_masking=True,
training=False,
testing=True,
**kwargs,
) -> Dict[str, Union[Dict[str, torch.Tensor], torch.Tensor]]:
metrics = {}

if call_body:
body_outputs = self.body(body_outputs, training=False, ignore_masking=ignore_masking)
body_outputs = self.body(body_outputs, training=training, testing=testing)

for name, task in self.prediction_task_dict.items():
metrics.update(
@@ -458,7 +445,8 @@ def calculate_metrics( # type: ignore
targets,
mode=mode,
forward=forward,
ignore_masking=ignore_masking,
training=training,
testing=testing,
**kwargs,
)
)
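During training or testing the refactored `Head.forward` collects one output dict per task, weights each task loss by `_task_weights`, and reduces the stacked losses with the configured `loss_reduction`. A minimal sketch of that aggregation with assumed values:

import torch

task_outputs = {                                  # assumed per-task outputs
    "next-item": {"loss": torch.tensor(2.3)},
    "click": {"loss": torch.tensor(0.7)},
}
task_weights = {"next-item": 1.0, "click": 0.5}   # assumed task weights
loss_reduction = "mean"                           # the Head's reduction method

losses = [out["loss"] * task_weights[name] for name, out in task_outputs.items()]
loss = getattr(torch.stack(losses), loss_reduction)()  # tensor(1.3250)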
@@ -504,7 +492,6 @@ def __init__(
head_reduction: str = "mean",
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
name: str = None,
hf_format: bool = True,
):
"""Model class that can aggregate one or multiple heads.

@@ -520,13 +507,6 @@
Optimizer-class to use during fitting
name: str, optional
Name of the model.
hf_format: bool, optional
This parameter is specific to NextItemPredictionTask class and controls the format of
the output returned by the task. If `True`, the task returns a dictionary
with three tensors: loss, predictions, labels. Otherwise, it returns the tensor of
`predictions` scores.
Usually, hf_format is set to True during training and False during inference
By default True.
"""
if head_weights:
if not isinstance(head_weights, list):
@@ -543,52 +523,64 @@
self.head_weights = head_weights or [1.0] * len(head)
self.head_reduction = head_reduction
self.optimizer = optimizer
self.hf_format = hf_format

def forward(self, inputs: TensorOrTabularData, training=True, **kwargs):
def forward(
self, inputs: TensorOrTabularData, training=True, testing=False, targets=None, **kwargs
):
# TODO: Optimize this
outputs = {}
for head in self.heads:
outputs.update(
head(inputs, call_body=True, training=training, always_output_dict=True, **kwargs)
)

if len(outputs) == 1:
outputs = outputs[list(outputs.keys())[0]]
if isinstance(outputs, dict):
# next-item prediction task with `hf_format = True`` returns a dictionary
# with three tensors: loss, predictions, labels. This needed for training
# and evaluation using the Trainer class.
# At inference, we just need the predictions tensors.
# TODO: We are simplifying the logic around `hf_format` in the multi-gpu
# support work.
if not self.hf_format:
return outputs["predictions"]
return outputs

return outputs

def compute_loss(self, inputs, targets, compute_metrics=True, **kwargs) -> torch.Tensor:
losses = []

for i, head in enumerate(self.heads):
loss = head.compute_loss(
inputs, targets, call_body=True, compute_metrics=compute_metrics, **kwargs
)
losses.append(loss * self.head_weights[i])
if training or testing:
losses = []
labels = {}
predictions = {}
for i, head in enumerate(self.heads):
head_output = head(
inputs, call_body=True, training=training, testing=testing, **kwargs
)
labels.update(head_output["labels"])
predictions.update(head_output["predictions"])
losses.append(head_output["loss"] * self.head_weights[i])
loss_tensor = torch.stack(losses)
loss = getattr(loss_tensor, self.head_reduction)()
if len(labels) == 1:
labels = list(labels.values())[0]
predictions = list(predictions.values())[0]
outputs = {"loss": loss, "labels": labels, "predictions": predictions}

loss_tensor = torch.stack(losses)
else:
for head in self.heads:
outputs.update(
head(inputs, call_body=True, training=training, testing=testing, **kwargs)
)
if len(outputs) == 1:
return outputs[list(outputs.keys())[0]]

return getattr(loss_tensor, self.head_reduction)()
return outputs

def calculate_metrics( # type: ignore
self, inputs, targets, mode="val", call_body=True, forward=True, **kwargs
self,
inputs,
targets,
mode="val",
call_body=True,
training=False,
testing=True,
forward=True,
**kwargs,
) -> Dict[str, Union[Dict[str, torch.Tensor], torch.Tensor]]:
outputs = {}
for head in self.heads:
outputs.update(
head.calculate_metrics(
inputs, targets, mode=mode, call_body=call_body, forward=forward, **kwargs
inputs,
targets,
mode=mode,
call_body=call_body,
training=training,
testing=testing,
forward=forward,
**kwargs,
)
)
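At the `Model` level the same contract holds across heads: a single forward call returns loss, labels, and predictions when `training` or `testing` is True, and only prediction scores otherwise. A hedged usage sketch in which `model`, `inputs`, and `targets` are placeholders:

# Training step: one forward pass yields everything the trainer needs.
outputs = model(inputs, targets=targets, training=True, testing=False)
outputs["loss"].backward()

# Evaluation: metrics are computed from the same dict-shaped output.
eval_outputs = model(inputs, targets=targets, training=False, testing=True)

# Inference: both flags are False and only prediction scores are returned.
scores = model(inputs, training=False, testing=False)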

@@ -619,7 +611,7 @@ def forward(self, inputs, *args, **kwargs):
return self.parent(inputs, *args, **kwargs)

def training_step(self, batch, batch_idx):
loss = self.parent.compute_loss(*batch)
loss = self.parent(batch[0], targets=batch[1], training=True, testing=False)["loss"]
self.log("train_loss", loss)

return loss
@@ -660,9 +652,9 @@ def fit(
for batch_idx, (x, y) in batch_iterator:
if amp:
with torch.cuda.amp.autocast():
loss = self.compute_loss(x, y)
loss = self(x, targets=y, training=True, testing=False)["loss"]
else:
loss = self.compute_loss(x, y)
loss = self(x, targets=y, training=True, testing=False)["loss"]

losses.append(float(loss))

@@ -689,7 +681,7 @@ def evaluate(self, dataloader, verbose=True, mode="eval"):
batch_iterator = tqdm(batch_iterator)
self.reset_metrics()
for batch_idx, (x, y) in batch_iterator:
self.calculate_metrics(x, y, mode=mode)
self.calculate_metrics(x, y, mode=mode, training=False, testing=True)

return self.compute_metrics(mode=mode)
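Taken together, `fit` and `evaluate` now drive the model exclusively through the `training`/`testing` flags rather than the removed `compute_loss`. A compact, hedged sketch of the resulting loop, assuming `model`, `optimizer`, `train_loader`, and `eval_loader` exist:

for x, y in train_loader:
    loss = model(x, targets=y, training=True, testing=False)["loss"]
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

model.reset_metrics()
for x, y in eval_loader:
    model.calculate_metrics(x, y, mode="eval", training=False, testing=True)
metrics = model.compute_metrics(mode="eval")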
