Commit 9623de3

Add BetaForAlphaBar, ConvertModelOutput, SetTimesteps, and constructor for DPMSolverMultistepScheduler

joey12300 committed Nov 28, 2022
1 parent d030719
Showing 7 changed files with 326 additions and 6 deletions.
25 changes: 25 additions & 0 deletions examples/multimodal/stable_diffusion/cpp/CMakeLists.txt
@@ -0,0 +1,25 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

cmake_minimum_required(VERSION 3.10)
project(main C CXX)

# FASTDEPLOY_INSTALL_DIR holds a path, so declare it as a cache variable
# rather than a boolean option().
set(FASTDEPLOY_INSTALL_DIR "" CACHE PATH "Path of downloaded fastdeploy sdk.")
set(THIRD_LIBS "")
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)

include_directories(${FASTDEPLOY_INCS})

add_executable(main ${PROJECT_SOURCE_DIR}/main.cc ${PROJECT_SOURCE_DIR}/dpm_solver_multistep_scheduler.cc)
target_link_libraries(main ${FASTDEPLOY_LIBS} ${THIRD_LIBS})
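The FASTDEPLOY_INSTALL_DIR cache variable above is the only knob. A typical configure-and-build, standard CMake usage rather than anything this commit prescribes, would be: cmake .. -DFASTDEPLOY_INSTALL_DIR=/path/to/fastdeploy-sdk followed by make, where the SDK path stands in for wherever the downloaded SDK was unpacked.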
177 changes: 177 additions & 0 deletions examples/multimodal/stable_diffusion/cpp/dpm_solver_multistep_scheduler.cc
@@ -0,0 +1,177 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "dpm_solver_multistep_scheduler.h"
#include "fastdeploy/core/fd_scalar.h"
#include "fastdeploy/function/functions.h"
#include <algorithm>
#include <cmath>

namespace fastdeploy {

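// "squaredcos_cap_v2": the cosine noise schedule of Nichol & Dhariwal,
// "Improved Denoising Diffusion Probabilistic Models" (2021). Each beta is
// 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta.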
void DPMSolverMultistepScheduler::BetaForAlphaBar(FDTensor* out,
                                                  int num_diffusion_timesteps,
                                                  float max_beta) {
  auto alpha_bar = [](float time_step) -> float {
    constexpr float pi = 3.14159265358979323846;
    return std::pow(std::cos((time_step + 0.008) / 1.008 * pi / 2), 2);
  };
  std::vector<FDTensor> betas;
  for (int i = 0; i < num_diffusion_timesteps; ++i) {
    // Divide in float: integer division would truncate both ratios to zero.
    float t1 = static_cast<float>(i) / num_diffusion_timesteps;
    float t2 = static_cast<float>(i + 1) / num_diffusion_timesteps;
    // alpha_bar decreases with t, so this ratio keeps each beta in (0, 1).
    float beta_val = (std::min)(1 - alpha_bar(t2) / alpha_bar(t1), max_beta);
    betas.emplace_back(Scalar(beta_val));
  }
  function::Concat(betas, out);
}

DPMSolverMultistepScheduler::DPMSolverMultistepScheduler(
    int num_train_timesteps, float beta_start, float beta_end,
    const std::string& beta_schedule, const std::vector<float>& trained_betas,
    int solver_order, bool predict_epsilon, bool thresholding,
    float dynamic_thresholding_ratio, float sample_max_value,
    const std::string& algorithm_type, const std::string& solver_type,
    bool lower_order_final)
    : num_train_timesteps_(num_train_timesteps), beta_start_(beta_start),
      beta_end_(beta_end), beta_schedule_(beta_schedule),
      solver_order_(solver_order), predict_epsilon_(predict_epsilon),
      thresholding_(thresholding),
      dynamic_thresholding_ratio_(dynamic_thresholding_ratio),
      sample_max_value_(sample_max_value), algorithm_type_(algorithm_type),
      solver_type_(solver_type), lower_order_final_(lower_order_final) {
  int beta_size = trained_betas.size();
  if (beta_size > 0) {
    betas_.Allocate({beta_size}, FDDataType::FP32);
    std::copy(trained_betas.data(), trained_betas.data() + beta_size,
              reinterpret_cast<float*>(betas_.Data()));
  } else if (beta_schedule == "linear") {
    function::Linspace(beta_start, beta_end, num_train_timesteps, &betas_,
                       FDDataType::FP32);
  } else if (beta_schedule == "scaled_linear") {
    // "scaled_linear" spaces sqrt(beta) linearly, then squares it back.
    function::Linspace(std::sqrt(beta_start), std::sqrt(beta_end),
                       num_train_timesteps, &betas_, FDDataType::FP32);
    betas_ = betas_ * betas_;
  } else if (beta_schedule == "squaredcos_cap_v2") {
    BetaForAlphaBar(&betas_, num_train_timesteps);
  } else {
    FDASSERT(false, "%s is not implemented for DPMSolverMultistepScheduler",
             beta_schedule.c_str());
  }

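  // DPM-Solver notation (Lu et al., 2022): alpha_t = sqrt(alphas_cumprod),
  // sigma_t = sqrt(1 - alphas_cumprod), and lambda_t = log(alpha_t) -
  // log(sigma_t), the half-log-SNR the solver integrates over.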
  Scalar one = static_cast<float>(1.0);
  alphas_ = FDTensor(one) - betas_;
  function::Cumprod(alphas_, &alphas_cumprod_);
  function::Sqrt(alphas_cumprod_, &alpha_t_);
  function::Sqrt(FDTensor(one) - alphas_cumprod_, &sigma_t_);
  FDTensor alpha_t_log, sigma_t_log;
  function::Log(alpha_t_, &alpha_t_log);
  function::Log(sigma_t_, &sigma_t_log);
  lambda_t_ = alpha_t_log - sigma_t_log;

  FDASSERT(algorithm_type_ == "dpmsolver" || algorithm_type_ == "dpmsolver++",
           "%s is not implemented for DPMSolverMultistepScheduler",
           algorithm_type_.c_str());
  FDASSERT(solver_type_ == "midpoint" || solver_type_ == "heun",
           "%s is not implemented for DPMSolverMultistepScheduler",
           solver_type_.c_str());
  num_inference_steps_ = -1;

  function::Linspace(0, num_train_timesteps_ - 1, num_train_timesteps_,
                     &timesteps_);
  // Reverse to descending order: sampling steps from t = T-1 down to 0.
  float* timesteps_data = reinterpret_cast<float*>(timesteps_.Data());
  std::reverse(timesteps_data, timesteps_data + timesteps_.Numel());

  model_outputs_.resize(solver_order_);
  lower_order_nums_ = 0;
}

void DPMSolverMultistepScheduler::ConvertModelOutput(
    const FDTensor& model_output, int timestep, const FDTensor& sample,
    FDTensor* out) {
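  // Convert the raw network output into the quantity the solver integrates:
  // DPM-Solver++ works in data-prediction (x0) space, DPM-Solver in
  // noise-prediction (epsilon) space.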
  if (algorithm_type_ == "dpmsolver++") {
    FDTensor x0_pred;
    if (predict_epsilon_) {
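      // Epsilon-prediction model: x0 = (x_t - sigma_t * eps) / alpha_t.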
      FDTensor alpha_t, sigma_t;
      function::Slice(alpha_t_, {0}, {timestep}, {timestep + 1}, &alpha_t);
      function::Slice(sigma_t_, {0}, {timestep}, {timestep + 1}, &sigma_t);
      alpha_t.Squeeze();
      sigma_t.Squeeze();
      x0_pred = (sample - sigma_t * model_output) / alpha_t;
    } else {
      x0_pred = model_output;
    }
    if (thresholding_) {
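      // Dynamic thresholding (Imagen; Saharia et al., 2022): clip x0 to the
      // dynamic_thresholding_ratio_ quantile of |x0| per sample, floored at
      // sample_max_value_, then rescale.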
      FDTensor dynamic_max_val, x0_pred_abs;
      function::Abs(x0_pred, &x0_pred_abs);
      x0_pred_abs.Reshape({x0_pred_abs.Shape()[0], -1});
      function::Quantile(x0_pred_abs, {dynamic_thresholding_ratio_}, {1},
                         &dynamic_max_val);

      FDTensor max_value, dy_max_val;
      function::FullLike(dynamic_max_val, sample_max_value_, &max_value,
                         dynamic_max_val.Dtype());
      function::Maximum(dynamic_max_val, max_value, &dy_max_val);
      int expand_dims = x0_pred.Shape().size() - 1;
      for (int i = 0; i < expand_dims; ++i) {
        dy_max_val.ExpandDim(dy_max_val.Shape().size());
      }
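      // NOTE: Clip takes scalar bounds, so the first sample's threshold is
      // applied to the whole batch; exact per-sample clipping would need
      // broadcasted bounds.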
      float clip_max = reinterpret_cast<float*>(dy_max_val.Data())[0];
      function::Clip(x0_pred, -clip_max, clip_max, &x0_pred);
      x0_pred = x0_pred / dy_max_val;
    }
    *out = std::move(x0_pred);
  } else if (algorithm_type_ == "dpmsolver") {
    if (predict_epsilon_) {
      *out = model_output;
    } else {
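      // x0-prediction model: eps = (x_t - alpha_t * x0) / sigma_t.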
      FDTensor alpha_t, sigma_t;
      function::Slice(alpha_t_, {0}, {timestep}, {timestep + 1}, &alpha_t);
      function::Slice(sigma_t_, {0}, {timestep}, {timestep + 1}, &sigma_t);
      alpha_t.Squeeze();
      sigma_t.Squeeze();
      *out = (sample - alpha_t * model_output) / sigma_t;
    }
  }
}

void DPMSolverMultistepScheduler::SetTimesteps(int num_inference_steps) {
  num_inference_steps_ = num_inference_steps;
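  // Evenly place num_inference_steps + 1 values over [0, num_train_timesteps - 1],
  // round, reverse to descending order, and drop the trailing t = 0 entry.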
  function::Linspace(0, num_train_timesteps_ - 1, num_inference_steps + 1,
                     &timesteps_);
  function::Round(timesteps_, &timesteps_);
  // Reverse timesteps.
  float* timesteps_data = reinterpret_cast<float*>(timesteps_.Data());
  std::reverse(timesteps_data, timesteps_data + timesteps_.Numel());
  // Keep the first num_inference_steps entries of the reversed schedule.
  FDTensor timestep_tmp;
  timestep_tmp.Allocate({num_inference_steps}, timesteps_.Dtype());
  float* timestep_tmp_data = reinterpret_cast<float*>(timestep_tmp.Data());
  std::copy(timesteps_data, timesteps_data + num_inference_steps,
            timestep_tmp_data);
  timesteps_ = std::move(timestep_tmp);

  function::Cast(timesteps_, &timesteps_, FDDataType::INT64);

  model_outputs_.clear();
  model_outputs_.resize(solver_order_);

  lower_order_nums_ = 0;
}

void DPMSolverMultistepScheduler::Step(const FDTensor& model_output,
                                       int timestep, const FDTensor& sample,
                                       FDTensor* prev_sample) {
  // Stub: the multistep update itself is not implemented in this commit.
}

} // namespace fastdeploy
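For orientation, a minimal sketch of how the new scheduler is meant to be driven. This is not part of the commit; the UNet call is elided, and Step() is still a stub above, so the loop is left as comments.

#include "dpm_solver_multistep_scheduler.h"

int main() {
  // Defaults mirror the constructor: 1000 training steps, "linear" betas,
  // order-2 "dpmsolver++" with "midpoint" updates.
  fastdeploy::DPMSolverMultistepScheduler scheduler;
  scheduler.SetTimesteps(20);  // 20 denoising steps spread over [0, 999]

  fastdeploy::FDTensor sample;        // latent, e.g. drawn from N(0, I)
  fastdeploy::FDTensor model_output;  // UNet output for (sample, t)
  // For each scheduled timestep t, in descending order:
  //   scheduler.Step(model_output, t, sample, &sample);
  return 0;
}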
70 changes: 70 additions & 0 deletions examples/multimodal/stable_diffusion/cpp/dpm_solver_multistep_scheduler.h
@@ -0,0 +1,70 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "fastdeploy/core/fd_tensor.h"
#include "scheduler.h"

namespace fastdeploy {

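// Multistep DPM-Solver / DPM-Solver++ scheduler (Lu et al., 2022); the
// constructor arguments follow the diffusers scheduler of the same name.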
class DPMSolverMultistepScheduler : public Scheduler {
 public:
  DPMSolverMultistepScheduler(int num_train_timesteps = 1000,
                              float beta_start = 0.0001, float beta_end = 0.02,
                              const std::string& beta_schedule = "linear",
                              const std::vector<float>& trained_betas = {},
                              int solver_order = 2, bool predict_epsilon = true,
                              bool thresholding = false,
                              float dynamic_thresholding_ratio = 0.995,
                              float sample_max_value = 1.0,
                              const std::string& algorithm_type = "dpmsolver++",
                              const std::string& solver_type = "midpoint",
                              bool lower_order_final = true);
  void BetaForAlphaBar(FDTensor* out, int num_diffusion_timesteps,
                       float max_beta = 0.999);
  void ConvertModelOutput(const FDTensor& model_output, int timestep,
                          const FDTensor& sample, FDTensor* out);

  void SetTimesteps(int num_inference_steps) override;
  void Step(const FDTensor& model_output, int timestep, const FDTensor& sample,
            FDTensor* prev_sample) override;

 private:
  FDTensor betas_;
  FDTensor alphas_;
  FDTensor alphas_cumprod_;
  FDTensor alpha_t_;
  FDTensor sigma_t_;
  FDTensor lambda_t_;
  int num_inference_steps_;
  FDTensor timesteps_;
  int lower_order_nums_;
  std::vector<FDTensor> model_outputs_;

  int num_train_timesteps_;
  float beta_start_;
  float beta_end_;
  std::string beta_schedule_;
  int solver_order_;
  bool predict_epsilon_;
  bool thresholding_;
  float dynamic_thresholding_ratio_;
  float sample_max_value_;
  std::string algorithm_type_;
  std::string solver_type_;
  bool lower_order_final_;
};

} // namespace fastdeploy
18 changes: 18 additions & 0 deletions examples/multimodal/stable_diffusion/cpp/main.cc
@@ -0,0 +1,18 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "dpm_solver_multistep_scheduler.h"
#include <iostream>

int main() { return 0; }
27 changes: 27 additions & 0 deletions examples/multimodal/stable_diffusion/cpp/scheduler.h
@@ -0,0 +1,27 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "fastdeploy/core/fd_tensor.h"

namespace fastdeploy {

// Abstract interface shared by the diffusion schedulers.
class Scheduler {
 public:
  virtual ~Scheduler() = default;
  virtual void SetTimesteps(int num_inference_steps) = 0;
  virtual void Step(const FDTensor& model_output, int timestep,
                    const FDTensor& sample, FDTensor* prev_sample) = 0;
};

} // namespace fastdeploy
7 changes: 4 additions & 3 deletions fastdeploy/function/clip.cc
@@ -39,14 +39,15 @@ void ClipKernel(const FDTensor& x, double min, double max, FDTensor* out) {
            "max should be greater than or equal to min. But received min = %f, "
            "max = %f",
            static_cast<float>(min_), static_cast<float>(max_));
-
-  out->Allocate(x.Shape(), x.Dtype());
+  FDTensor tmp;
+  tmp.Allocate(x.Shape(), x.Dtype());
   const T* x_data = reinterpret_cast<const T*>(x.Data());
 
   int64_t numel = x.Numel();
-  T* out_data = reinterpret_cast<T*>(out->Data());
+  T* out_data = reinterpret_cast<T*>(tmp.Data());
 
   std::transform(x_data, x_data + numel, out_data, ClipFunctor<T>(min_, max_));
+  *out = std::move(tmp);
 }
 
 void Clip(const FDTensor& x, double min, double max, FDTensor* out) {
8 changes: 5 additions & 3 deletions fastdeploy/function/elementwise_base.h
@@ -213,10 +213,12 @@ void CommonElementwiseBroadcastForward(const FDTensor& x, const FDTensor& y,
   GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
                          y_dims_array.data(), out_dims_array.data(), max_dim,
                          axis);
-  z->Allocate(out_dims_array, TypeToDataType<OutType>::dtype);
+  FDTensor tmp;
+  tmp.Allocate(out_dims_array, TypeToDataType<OutType>::dtype);
   CommonForwardBroadcastCPU<Functor, T, OutType>(
-      x, y, z, x_dims_array.data(), y_dims_array.data(), out_dims_array.data(),
-      max_dim, func, is_xsize_larger);
+      x, y, &tmp, x_dims_array.data(), y_dims_array.data(),
+      out_dims_array.data(), max_dim, func, is_xsize_larger);
+  *z = std::move(tmp);
 }
 
 template <typename Functor, typename T, typename OutType = T>
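Both hunks apply the same copy-then-move pattern: the result is computed into a temporary and moved into the output, so calls whose output aliases an input stay well-defined. The scheduler above depends on this twice; a standalone sketch of such aliased calls, for illustration only and not part of the commit:

#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/function/functions.h"

int main() {
  fastdeploy::FDTensor t;
  fastdeploy::function::Linspace(0, 9, 10, &t, fastdeploy::FDDataType::FP32);
  t = t * t;                                     // output aliases both inputs
  fastdeploy::function::Clip(t, 0.0, 25.0, &t);  // output aliases the input
  return 0;
}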
