Skip to content

Commit

Permalink
[Diffusion] Add StableDiffusionInpaint pipeline (#760)
Browse files Browse the repository at this point in the history
* Update Inpaint pipeline

* Update concat

* Add GaussianRandomKernel

* Update GaussianRandom

* Add vae endoder

* Add unet infer

* Add vae decoder predict

* add PrepareMaskAndMaskedImage

* Add imwrite

* Add time counter

* Fix pipeline

* use FDTensor move

* Fix scaled_linear dpm solver

* Add RGB2BGR
  • Loading branch information
joey12300 authored Dec 2, 2022
1 parent 9531e99 commit d74e120
Show file tree
Hide file tree
Showing 12 changed files with 639 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ DPMSolverMultistepScheduler::DPMSolverMultistepScheduler(
function::Linspace(beta_start, beta_end, num_train_timesteps, &betas_,
FDDataType::FP32);
} else if (beta_schedule == "scaled_linear") {
function::Linspace(beta_start, beta_end, num_train_timesteps, &betas_,
FDDataType::FP32);
function::Linspace(std::sqrt(beta_start), std::sqrt(beta_end),
num_train_timesteps, &betas_, FDDataType::FP32);
betas_ = betas_ * betas_;
} else if (beta_schedule == "squaredcos_cap_v2") {
BetaForAlphaBar(&betas_, num_train_timesteps);
Expand Down Expand Up @@ -96,6 +96,8 @@ DPMSolverMultistepScheduler::DPMSolverMultistepScheduler(
lower_order_nums_ = 0;
}

float DPMSolverMultistepScheduler::InitNoiseSigma() { return 1.0; }

void DPMSolverMultistepScheduler::ConvertModelOutput(
const FDTensor& model_output, int timestep, const FDTensor& sample,
FDTensor* out) {
Expand Down Expand Up @@ -314,7 +316,6 @@ void DPMSolverMultistepScheduler::Step(const FDTensor& model_output,
if (timesteps_iter - timesteps_data < timesteps_.Numel()) {
step_index = timesteps_iter - timesteps_data;
}

int64_t prev_timestep = 0;
if (step_index != timesteps_.Numel() - 1) {
prev_timestep = timesteps_data[step_index + 1];
Expand Down Expand Up @@ -392,4 +393,6 @@ void DPMSolverMultistepScheduler::AddNoise(const FDTensor& original_samples,
*out = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise;
}

FDTensor DPMSolverMultistepScheduler::GetTimesteps() { return timesteps_; }

} // namespace fastdeploy
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ class DPMSolverMultistepScheduler : public Scheduler {
const std::vector<FDTensor>& timesteps = {}) override;
void AddNoise(const FDTensor& original_samples, const FDTensor& noise,
const FDTensor& timesteps, FDTensor* out) override;
float InitNoiseSigma() override;
FDTensor GetTimesteps() override;
struct Config {
int num_train_timesteps_;
float beta_start_;
Expand Down
126 changes: 112 additions & 14 deletions examples/multimodal/stable_diffusion/cpp/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,121 @@
// limitations under the License.

#include "dpm_solver_multistep_scheduler.h"
#include "fastdeploy/vision/common/processors/mat.h"
#include "fastdeploy/utils/perf.h"
#include "opencv2/highgui/highgui.hpp"
#include "opencv2/imgproc/imgproc.hpp"
#include "pipeline_stable_diffusion_inpaint.h"
#include <iostream>
#include <memory>
#include <sstream>
#include <string>

template <typename T> std::string Str(const T* value, int size) {
std::ostringstream oss;
oss << "[ " << value[0];
for (int i = 1; i < size; ++i) {
oss << " ," << value[i];
}
oss << " ]";
return oss.str();
}

std::unique_ptr<fastdeploy::Runtime>
CreateRuntime(const std::string& model_file, const std::string& params_file,
bool use_paddle_backend = true) {
fastdeploy::RuntimeOption runtime_option;
runtime_option.SetModelPath(model_file, params_file,
fastdeploy::ModelFormat::PADDLE);
runtime_option.UseGpu();
if (use_paddle_backend) {
runtime_option.UsePaddleBackend();
} else {
runtime_option.UseOrtBackend();
}
std::unique_ptr<fastdeploy::Runtime> runtime =
std::unique_ptr<fastdeploy::Runtime>(new fastdeploy::Runtime());
if (!runtime->Init(runtime_option)) {
std::cerr << "--- Init FastDeploy Runitme Failed! "
<< "\n--- Model: " << model_file << std::endl;
return nullptr;
} else {
std::cout << "--- Init FastDeploy Runitme Done! "
<< "\n--- Model: " << model_file << std::endl;
}
return runtime;
}

int main() {
fastdeploy::DPMSolverMultistepScheduler dpm(
/* num_train_timesteps */ 1000,
/* beta_start = */ 0.00085,
/* beta_end = */ 0.012,
/* beta_schedule = */ "scaled_linear",
/* trained_betas = */ {},
/* solver_order = */ 2,
/* predict_epsilon = */ true,
/* thresholding = */ false,
/* dynamic_thresholding_ratio = */ 0.995,
/* sample_max_value = */ 1.0,
/* algorithm_type = */ "dpmsolver++",
/* solver_type = */ "midpoint",
/* lower_order_final = */ true);
// 1. Init scheduler
std::unique_ptr<fastdeploy::Scheduler> dpm(
new fastdeploy::DPMSolverMultistepScheduler(
/* num_train_timesteps */ 1000,
/* beta_start = */ 0.00085,
/* beta_end = */ 0.012,
/* beta_schedule = */ "scaled_linear",
/* trained_betas = */ {},
/* solver_order = */ 2,
/* predict_epsilon = */ true,
/* thresholding = */ false,
/* dynamic_thresholding_ratio = */ 0.995,
/* sample_max_value = */ 1.0,
/* algorithm_type = */ "dpmsolver++",
/* solver_type = */ "midpoint",
/* lower_order_final = */ true));

// 2. Init text encoder runtime
std::string text_model_file = "sd15_inpaint/text_encoder/inference.pdmodel";
std::string text_params_file =
"sd15_inpaint/text_encoder/inference.pdiparams";
std::unique_ptr<fastdeploy::Runtime> text_encoder_runtime =
CreateRuntime(text_model_file, text_params_file, false);

// 3. Init vae encoder runtime
std::string vae_encoder_model_file =
"sd15_inpaint/vae_encoder/inference.pdmodel";
std::string vae_encoder_params_file =
"sd15_inpaint/vae_encoder/inference.pdiparams";
std::unique_ptr<fastdeploy::Runtime> vae_encoder_runtime =
CreateRuntime(vae_encoder_model_file, vae_encoder_params_file);

// 4. Init vae decoder runtime
std::string vae_decoder_model_file =
"sd15_inpaint/vae_decoder/inference.pdmodel";
std::string vae_decoder_params_file =
"sd15_inpaint/vae_decoder/inference.pdiparams";
std::unique_ptr<fastdeploy::Runtime> vae_decoder_runtime =
CreateRuntime(vae_decoder_model_file, vae_decoder_params_file);

// 5. Init unet runtime
std::string unet_model_file = "sd15_inpaint/unet/inference.pdmodel";
std::string unet_params_file = "sd15_inpaint/unet/inference.pdiparams";
std::unique_ptr<fastdeploy::Runtime> unet_runtime =
CreateRuntime(unet_model_file, unet_params_file);

// 6. Init fast tokenizer
paddlenlp::fast_tokenizer::tokenizers_impl::ClipFastTokenizer tokenizer(
"clip/vocab.json", "clip/merges.txt", /* max_length = */ 77);
fastdeploy::StableDiffusionInpaintPipeline pipe(
std::move(vae_encoder_runtime), std::move(vae_decoder_runtime),
std::move(text_encoder_runtime), std::move(unet_runtime),
/* scheduler = */ std::move(dpm), tokenizer);

// 7. Read images
auto image = cv::imread("overture-creations.png");
auto mask_image = cv::imread("overture-creations-mask.png");

// 8. Predict
std::vector<std::string> prompts = {
"Face of a yellow cat, high resolution, sitting on a park bench"};
std::vector<fastdeploy::FDTensor> outputs;
fastdeploy::TimeCounter tc;
tc.Start();
pipe.Predict(prompts, image, mask_image, &outputs, /* height = */ 512,
/* width = */ 512, /* num_inference_steps = */ 50);
tc.End();
tc.PrintInfo();
fastdeploy::vision::FDMat mat = fastdeploy::vision::FDMat::Create(outputs[0]);
cv::imwrite("cat_on_bench_new.png", *mat.GetOpenCVMat());
return 0;
}
Loading

0 comments on commit d74e120

Please sign in to comment.