From 5a0a5ba99020ac4b163807de287e048c435b614d Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Fri, 13 Oct 2023 16:56:28 +0800 Subject: [PATCH] olss update --- diffusion/olss_scheduler/README.md | 2 +- diffusion/olss_scheduler/olss.py | 29 ++++++++++++++++------------- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/diffusion/olss_scheduler/README.md b/diffusion/olss_scheduler/README.md index 061fedb..7c39340 100644 --- a/diffusion/olss_scheduler/README.md +++ b/diffusion/olss_scheduler/README.md @@ -35,7 +35,7 @@ Seed: 0 ## Usage 使用 ``` -pip install diffusers torch +pip install diffusers==0.21.3 torch ``` We provide a demo here: diff --git a/diffusion/olss_scheduler/olss.py b/diffusion/olss_scheduler/olss.py index 1d7eafd..a9d3fab 100644 --- a/diffusion/olss_scheduler/olss.py +++ b/diffusion/olss_scheduler/olss.py @@ -1,17 +1,7 @@ -from typing import Union, Optional import torch -from dataclasses import dataclass -from diffusers.utils import BaseOutput from tqdm import tqdm -@dataclass -class OLSSSchedulerOutput(BaseOutput): - - prev_sample: torch.FloatTensor - pred_original_sample: Optional[torch.FloatTensor] = None - - class OLSSSchedulerModel(torch.nn.Module): def __init__(self, wx, we): @@ -38,11 +28,22 @@ def __init__(self, timesteps, model): self.init_noise_sigma = 1.0 self.order = 1 - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + @staticmethod + def load(path): + timesteps, wx, we = torch.load(path, map_location="cpu") + model = OLSSSchedulerModel(wx, we) + return OLSSScheduler(timesteps, model) + + def save(self, path): + timesteps, wx, we = self.timesteps, self.model.wx, self.model.we + torch.save((timesteps, wx, we), path) + + def set_timesteps(self, num_inference_steps, device = "cuda"): self.xT = None self.e_prev = [] self.t_prev = -1 self.model = self.model.to(device) + self.timesteps = self.timesteps.to(device) def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs): return sample @@ -53,6 +54,7 @@ def step( model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, + *args, **kwargs ): t = self.timesteps.tolist().index(timestep) assert self.t_prev==-1 or t==self.t_prev+1 @@ -66,7 +68,7 @@ def step( self.t_prev = -1 else: self.t_prev = t - return OLSSSchedulerOutput(prev_sample=x, pred_original_sample=None) + return (x,) class OLSSSolver: @@ -213,7 +215,7 @@ def step(self, model_output, timestep, sample, **kwargs): self.catch_x_[timestep] = [] self.catch_x[timestep].append(sample.clone().detach().cpu()) self.catch_e[timestep].append(model_output.clone().detach().cpu()) - self.catch_x_[timestep].append(result.prev_sample.clone().detach().cpu()) + self.catch_x_[timestep].append(result[0].clone().detach().cpu()) return result else: result = self.olss_scheduler.step(model_output, timestep, sample, **kwargs) @@ -249,3 +251,4 @@ def prepare_olss(self, num_accelerate_steps): num_accelerate_steps, t_path, x_path, e_path) self.olss_model = OLSSSchedulerModel(wx, we) self.olss_scheduler = OLSSScheduler(timesteps, self.olss_model) + \ No newline at end of file