Skip to content

Commit

Permalink
Merge pull request #341 from alibaba/olss_update
Browse files Browse the repository at this point in the history
olss update
  • Loading branch information
chywang authored Oct 20, 2023
2 parents 6808d47 + 5a0a5ba commit cd91f44
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 14 deletions.
2 changes: 1 addition & 1 deletion diffusion/olss_scheduler/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ Seed: 0
## Usage 使用

```
pip install diffusers torch
pip install diffusers==0.21.3 torch
```

We provide a demo here:
Expand Down
29 changes: 16 additions & 13 deletions diffusion/olss_scheduler/olss.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,7 @@
from typing import Union, Optional
import torch
from dataclasses import dataclass
from diffusers.utils import BaseOutput
from tqdm import tqdm


@dataclass
class OLSSSchedulerOutput(BaseOutput):

prev_sample: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None


class OLSSSchedulerModel(torch.nn.Module):

def __init__(self, wx, we):
Expand All @@ -38,11 +28,22 @@ def __init__(self, timesteps, model):
self.init_noise_sigma = 1.0
self.order = 1

def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
@staticmethod
def load(path):
timesteps, wx, we = torch.load(path, map_location="cpu")
model = OLSSSchedulerModel(wx, we)
return OLSSScheduler(timesteps, model)

def save(self, path):
timesteps, wx, we = self.timesteps, self.model.wx, self.model.we
torch.save((timesteps, wx, we), path)

def set_timesteps(self, num_inference_steps, device = "cuda"):
self.xT = None
self.e_prev = []
self.t_prev = -1
self.model = self.model.to(device)
self.timesteps = self.timesteps.to(device)

def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs):
return sample
Expand All @@ -53,6 +54,7 @@ def step(
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
*args, **kwargs
):
t = self.timesteps.tolist().index(timestep)
assert self.t_prev==-1 or t==self.t_prev+1
Expand All @@ -66,7 +68,7 @@ def step(
self.t_prev = -1
else:
self.t_prev = t
return OLSSSchedulerOutput(prev_sample=x, pred_original_sample=None)
return (x,)


class OLSSSolver:
Expand Down Expand Up @@ -213,7 +215,7 @@ def step(self, model_output, timestep, sample, **kwargs):
self.catch_x_[timestep] = []
self.catch_x[timestep].append(sample.clone().detach().cpu())
self.catch_e[timestep].append(model_output.clone().detach().cpu())
self.catch_x_[timestep].append(result.prev_sample.clone().detach().cpu())
self.catch_x_[timestep].append(result[0].clone().detach().cpu())
return result
else:
result = self.olss_scheduler.step(model_output, timestep, sample, **kwargs)
Expand Down Expand Up @@ -249,3 +251,4 @@ def prepare_olss(self, num_accelerate_steps):
num_accelerate_steps, t_path, x_path, e_path)
self.olss_model = OLSSSchedulerModel(wx, we)
self.olss_scheduler = OLSSScheduler(timesteps, self.olss_model)

0 comments on commit cd91f44

Please sign in to comment.