diff --git a/discord_tron_client/classes/app_config.py b/discord_tron_client/classes/app_config.py index 04e9551d..f79e0b45 100644 --- a/discord_tron_client/classes/app_config.py +++ b/discord_tron_client/classes/app_config.py @@ -282,4 +282,8 @@ def bark_subsystem_type(self): def enable_compel(self): return self.config.get("use_compel_prompt_weighting", True) def enable_compile(self): - return self.config.get('enable_torch_compile', True) \ No newline at end of file + return self.config.get('enable_torch_compile', True) + def enable_cpu_offload(self): + return self.config.get('enable_cpu_offload', True) + def maximum_batch_size(self): + return max(self.config.get('maximum_batch_size', 4), 1) \ No newline at end of file diff --git a/discord_tron_client/classes/image_manipulation/pipeline.py b/discord_tron_client/classes/image_manipulation/pipeline.py index 58ec6451..dfae77e9 100644 --- a/discord_tron_client/classes/image_manipulation/pipeline.py +++ b/discord_tron_client/classes/image_manipulation/pipeline.py @@ -203,11 +203,7 @@ def _run_pipeline( ): original_stderr = sys.stderr sys.stderr = self.tqdm_capture - batch_size = 4 - if hardware.should_offload(): - batch_size = 2 - if hardware.should_sequential_offload(): - batch_size = 1 + batch_size = self.config.maximum_batch_size() try: alt_weight_algorithm = user_config.get("alt_weight_algorithm", False) use_latent_result = user_config.get('latent_refiner', True)