Skip to content

Commit

Permalink
small fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ibro45 committed Feb 8, 2023
1 parent d9573e9 commit 24319ef
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions lighter/callbacks/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pytorch_lightning import Callback, Trainer

from lighter import LighterSystem
from lighter.callbacks.utils import concatenate, parse_data, preprocess_image
from lighter.callbacks.utils import parse_data, preprocess_image, structure_preserving_concatenate


class LighterBaseWriter(ABC, Callback):
Expand Down Expand Up @@ -112,7 +112,7 @@ def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem, outpu
outputs = outputs[0]
# Concatenate/flatten so that each output corresponds to its index.
indices = list(itertools.chain(*indices))
outputs = concatenate(outputs)
outputs = structure_preserving_concatenate(outputs)
self._on_batch_or_epoch_end(outputs, indices)

def _on_batch_or_epoch_end(self, outputs, indices):
Expand Down

0 comments on commit 24319ef

Please sign in to comment.