diff --git a/plugins/instruct-lab/src/fms_acceleration_ilab/framework_plugin_padding_free.py b/plugins/instruct-lab/src/fms_acceleration_ilab/framework_plugin_padding_free.py index c5725959..33a592ee 100644 --- a/plugins/instruct-lab/src/fms_acceleration_ilab/framework_plugin_padding_free.py +++ b/plugins/instruct-lab/src/fms_acceleration_ilab/framework_plugin_padding_free.py @@ -207,7 +207,7 @@ def prepare(self, *args, device_placement=None): # Replace the collate_fn in dataloader dataloader.collate_fn = DataCollatorWithFlattening() - return dataloader + return _old_prepare(dataloader) accelerator.prepare = MethodType(prepare, accelerator)