From d74e1209ae682427869008df3d0c1590aa1fdfa9 Mon Sep 17 00:00:00 2001 From: Jack Zhou Date: Fri, 2 Dec 2022 19:30:32 +0800 Subject: [PATCH] [Diffusion] Add StableDiffusionInpaint pipeline (#760) * Update Inpaint pipeline * Update concat * Add GaussianRandomKernel * Update GaussianRandom * Add vae endoder * Add unet infer * Add vae decoder predict * add PrepareMaskAndMaskedImage * Add imwrite * Add time counter * Fix pipeline * use FDTensor move * Fix scaled_linear dpm solver * Add RGB2BGR --- .../cpp/dpm_solver_multistep_scheduler.cc | 9 +- .../cpp/dpm_solver_multistep_scheduler.h | 2 + .../multimodal/stable_diffusion/cpp/main.cc | 126 ++++++- .../cpp/pipeline_stable_diffusion_inpaint.cc | 322 ++++++++++++++++++ .../cpp/pipeline_stable_diffusion_inpaint.h | 61 ++++ .../stable_diffusion/cpp/scheduler.h | 3 + fastdeploy/function/concat.cc | 8 +- fastdeploy/function/functions.h | 1 + fastdeploy/function/gaussian_random.cc | 46 +++ fastdeploy/function/gaussian_random.h | 36 ++ fastdeploy/function/tile.cc | 7 +- fastdeploy/utils/utils.h | 77 +++-- 12 files changed, 639 insertions(+), 59 deletions(-) create mode 100644 examples/multimodal/stable_diffusion/cpp/pipeline_stable_diffusion_inpaint.cc create mode 100644 examples/multimodal/stable_diffusion/cpp/pipeline_stable_diffusion_inpaint.h create mode 100644 fastdeploy/function/gaussian_random.cc create mode 100644 fastdeploy/function/gaussian_random.h diff --git a/examples/multimodal/stable_diffusion/cpp/dpm_solver_multistep_scheduler.cc b/examples/multimodal/stable_diffusion/cpp/dpm_solver_multistep_scheduler.cc index cb6cf970b7..b61c5b5db1 100644 --- a/examples/multimodal/stable_diffusion/cpp/dpm_solver_multistep_scheduler.cc +++ b/examples/multimodal/stable_diffusion/cpp/dpm_solver_multistep_scheduler.cc @@ -57,8 +57,8 @@ DPMSolverMultistepScheduler::DPMSolverMultistepScheduler( function::Linspace(beta_start, beta_end, num_train_timesteps, &betas_, FDDataType::FP32); } else if (beta_schedule == "scaled_linear") { - function::Linspace(beta_start, beta_end, num_train_timesteps, &betas_, - FDDataType::FP32); + function::Linspace(std::sqrt(beta_start), std::sqrt(beta_end), + num_train_timesteps, &betas_, FDDataType::FP32); betas_ = betas_ * betas_; } else if (beta_schedule == "squaredcos_cap_v2") { BetaForAlphaBar(&betas_, num_train_timesteps); @@ -96,6 +96,8 @@ DPMSolverMultistepScheduler::DPMSolverMultistepScheduler( lower_order_nums_ = 0; } +float DPMSolverMultistepScheduler::InitNoiseSigma() { return 1.0; } + void DPMSolverMultistepScheduler::ConvertModelOutput( const FDTensor& model_output, int timestep, const FDTensor& sample, FDTensor* out) { @@ -314,7 +316,6 @@ void DPMSolverMultistepScheduler::Step(const FDTensor& model_output, if (timesteps_iter - timesteps_data < timesteps_.Numel()) { step_index = timesteps_iter - timesteps_data; } - int64_t prev_timestep = 0; if (step_index != timesteps_.Numel() - 1) { prev_timestep = timesteps_data[step_index + 1]; @@ -392,4 +393,6 @@ void DPMSolverMultistepScheduler::AddNoise(const FDTensor& original_samples, *out = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise; } +FDTensor DPMSolverMultistepScheduler::GetTimesteps() { return timesteps_; } + } // namespace fastdeploy diff --git a/examples/multimodal/stable_diffusion/cpp/dpm_solver_multistep_scheduler.h b/examples/multimodal/stable_diffusion/cpp/dpm_solver_multistep_scheduler.h index c6f037feea..0775ba1ee7 100644 --- a/examples/multimodal/stable_diffusion/cpp/dpm_solver_multistep_scheduler.h +++ b/examples/multimodal/stable_diffusion/cpp/dpm_solver_multistep_scheduler.h @@ -54,6 +54,8 @@ class DPMSolverMultistepScheduler : public Scheduler { const std::vector& timesteps = {}) override; void AddNoise(const FDTensor& original_samples, const FDTensor& noise, const FDTensor& timesteps, FDTensor* out) override; + float InitNoiseSigma() override; + FDTensor GetTimesteps() override; struct Config { int num_train_timesteps_; float beta_start_; diff --git a/examples/multimodal/stable_diffusion/cpp/main.cc b/examples/multimodal/stable_diffusion/cpp/main.cc index 3c7d33029f..62bcdfd1db 100644 --- a/examples/multimodal/stable_diffusion/cpp/main.cc +++ b/examples/multimodal/stable_diffusion/cpp/main.cc @@ -13,23 +13,121 @@ // limitations under the License. #include "dpm_solver_multistep_scheduler.h" +#include "fastdeploy/vision/common/processors/mat.h" +#include "fastdeploy/utils/perf.h" +#include "opencv2/highgui/highgui.hpp" +#include "opencv2/imgproc/imgproc.hpp" +#include "pipeline_stable_diffusion_inpaint.h" #include +#include +#include +#include + +template std::string Str(const T* value, int size) { + std::ostringstream oss; + oss << "[ " << value[0]; + for (int i = 1; i < size; ++i) { + oss << " ," << value[i]; + } + oss << " ]"; + return oss.str(); +} + +std::unique_ptr +CreateRuntime(const std::string& model_file, const std::string& params_file, + bool use_paddle_backend = true) { + fastdeploy::RuntimeOption runtime_option; + runtime_option.SetModelPath(model_file, params_file, + fastdeploy::ModelFormat::PADDLE); + runtime_option.UseGpu(); + if (use_paddle_backend) { + runtime_option.UsePaddleBackend(); + } else { + runtime_option.UseOrtBackend(); + } + std::unique_ptr runtime = + std::unique_ptr(new fastdeploy::Runtime()); + if (!runtime->Init(runtime_option)) { + std::cerr << "--- Init FastDeploy Runitme Failed! " + << "\n--- Model: " << model_file << std::endl; + return nullptr; + } else { + std::cout << "--- Init FastDeploy Runitme Done! " + << "\n--- Model: " << model_file << std::endl; + } + return runtime; +} int main() { - fastdeploy::DPMSolverMultistepScheduler dpm( - /* num_train_timesteps */ 1000, - /* beta_start = */ 0.00085, - /* beta_end = */ 0.012, - /* beta_schedule = */ "scaled_linear", - /* trained_betas = */ {}, - /* solver_order = */ 2, - /* predict_epsilon = */ true, - /* thresholding = */ false, - /* dynamic_thresholding_ratio = */ 0.995, - /* sample_max_value = */ 1.0, - /* algorithm_type = */ "dpmsolver++", - /* solver_type = */ "midpoint", - /* lower_order_final = */ true); + // 1. Init scheduler + std::unique_ptr dpm( + new fastdeploy::DPMSolverMultistepScheduler( + /* num_train_timesteps */ 1000, + /* beta_start = */ 0.00085, + /* beta_end = */ 0.012, + /* beta_schedule = */ "scaled_linear", + /* trained_betas = */ {}, + /* solver_order = */ 2, + /* predict_epsilon = */ true, + /* thresholding = */ false, + /* dynamic_thresholding_ratio = */ 0.995, + /* sample_max_value = */ 1.0, + /* algorithm_type = */ "dpmsolver++", + /* solver_type = */ "midpoint", + /* lower_order_final = */ true)); + + // 2. Init text encoder runtime + std::string text_model_file = "sd15_inpaint/text_encoder/inference.pdmodel"; + std::string text_params_file = + "sd15_inpaint/text_encoder/inference.pdiparams"; + std::unique_ptr text_encoder_runtime = + CreateRuntime(text_model_file, text_params_file, false); + + // 3. Init vae encoder runtime + std::string vae_encoder_model_file = + "sd15_inpaint/vae_encoder/inference.pdmodel"; + std::string vae_encoder_params_file = + "sd15_inpaint/vae_encoder/inference.pdiparams"; + std::unique_ptr vae_encoder_runtime = + CreateRuntime(vae_encoder_model_file, vae_encoder_params_file); + + // 4. Init vae decoder runtime + std::string vae_decoder_model_file = + "sd15_inpaint/vae_decoder/inference.pdmodel"; + std::string vae_decoder_params_file = + "sd15_inpaint/vae_decoder/inference.pdiparams"; + std::unique_ptr vae_decoder_runtime = + CreateRuntime(vae_decoder_model_file, vae_decoder_params_file); + + // 5. Init unet runtime + std::string unet_model_file = "sd15_inpaint/unet/inference.pdmodel"; + std::string unet_params_file = "sd15_inpaint/unet/inference.pdiparams"; + std::unique_ptr unet_runtime = + CreateRuntime(unet_model_file, unet_params_file); + + // 6. Init fast tokenizer + paddlenlp::fast_tokenizer::tokenizers_impl::ClipFastTokenizer tokenizer( + "clip/vocab.json", "clip/merges.txt", /* max_length = */ 77); + fastdeploy::StableDiffusionInpaintPipeline pipe( + std::move(vae_encoder_runtime), std::move(vae_decoder_runtime), + std::move(text_encoder_runtime), std::move(unet_runtime), + /* scheduler = */ std::move(dpm), tokenizer); + + // 7. Read images + auto image = cv::imread("overture-creations.png"); + auto mask_image = cv::imread("overture-creations-mask.png"); + // 8. Predict + std::vector prompts = { + "Face of a yellow cat, high resolution, sitting on a park bench"}; + std::vector outputs; + fastdeploy::TimeCounter tc; + tc.Start(); + pipe.Predict(prompts, image, mask_image, &outputs, /* height = */ 512, + /* width = */ 512, /* num_inference_steps = */ 50); + tc.End(); + tc.PrintInfo(); + fastdeploy::vision::FDMat mat = fastdeploy::vision::FDMat::Create(outputs[0]); + cv::imwrite("cat_on_bench_new.png", *mat.GetOpenCVMat()); return 0; } \ No newline at end of file diff --git a/examples/multimodal/stable_diffusion/cpp/pipeline_stable_diffusion_inpaint.cc b/examples/multimodal/stable_diffusion/cpp/pipeline_stable_diffusion_inpaint.cc new file mode 100644 index 0000000000..cbe352a710 --- /dev/null +++ b/examples/multimodal/stable_diffusion/cpp/pipeline_stable_diffusion_inpaint.cc @@ -0,0 +1,322 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "pipeline_stable_diffusion_inpaint.h" +#include "fastdeploy/function/functions.h" +#include "fastdeploy/vision/common/processors/color_space_convert.h" +#include "fastdeploy/vision/common/processors/mat.h" +#include "fastdeploy/vision/common/processors/resize.h" +#include + +using namespace paddlenlp; + +namespace fastdeploy { + +static constexpr int NUM_LATENT_CHANNELS = 4; +static constexpr int NUM_UNET_INPUT_CHANNELS = 9; + +void StableDiffusionInpaintPipeline::PrepareMaskAndMaskedImage( + const cv::Mat& image, const cv::Mat& mask_mat, + const std::vector& shape, FDTensor* mask, FDTensor* mask_image) { + vision::FDMat image_fdmat(image); + vision::BGR2RGB::Run(&image_fdmat, vision::ProcLib::OPENCV); + vision::Resize::Run(&image_fdmat, shape[1] * 8, shape[0] * 8, -1.0f, -1.0f, + cv::INTER_NEAREST, false, vision::ProcLib::OPENCV); + image_fdmat.ShareWithTensor(mask_image); + + vision::FDMat mask_fdmat(mask_mat); + vision::BGR2GRAY::Run(&mask_fdmat, vision::ProcLib::OPENCV); + vision::Resize::Run(&mask_fdmat, shape[1] * 8, shape[0] * 8, -1.0f, -1.0f, + cv::INTER_NEAREST, false, vision::ProcLib::OPENCV); + FDTensor image_mask; + mask_fdmat.ShareWithTensor(&image_mask); + function::Cast(image_mask, &image_mask, FDDataType::FP32); + std::vector float_mask(image_mask.Numel(), 0); + float* image_mask_ptr = reinterpret_cast(image_mask.Data()); + for (int i = 0; i < image_mask.Numel(); ++i) { + if (image_mask_ptr[i] < 127.5) { + float_mask[i] = 1; + } + } + image_mask.SetExternalData({1, 1, shape[1] * 8, shape[0] * 8}, + FDDataType::FP32, float_mask.data()); + + // Set mask_image + mask_image->ExpandDim(); + function::Transpose(*mask_image, mask_image, {0, 3, 1, 2}); + function::Cast(*mask_image, mask_image, FDDataType::FP32); + *mask_image = *mask_image / 127.5f - 1.0f; + *mask_image = *mask_image * image_mask; + + // Set mask + vision::FDMat mask_fdmat_t(mask_mat); + vision::BGR2GRAY::Run(&mask_fdmat_t, vision::ProcLib::OPENCV); + vision::Resize::Run(&mask_fdmat_t, shape[1], shape[0], -1.0f, -1.0f, + cv::INTER_NEAREST, false, vision::ProcLib::OPENCV); + mask_fdmat_t.ShareWithTensor(mask); + function::Cast(*mask, mask, FDDataType::FP32); + *mask = *mask / 255.0f; + mask->ExpandDim(); + function::Transpose(*mask, mask, {0, 3, 1, 2}); + float* mask_data = reinterpret_cast(mask->Data()); + for (int i = 0; i < mask->Numel(); ++i) { + if (mask_data[i] < 0.5) { + mask_data[i] = 0; + } else { + mask_data[i] = 1; + } + } +} + +StableDiffusionInpaintPipeline::StableDiffusionInpaintPipeline( + std::unique_ptr vae_encoder, std::unique_ptr vae_decoder, + std::unique_ptr text_encoder, std::unique_ptr unet, + std::unique_ptr scheduler, + const paddlenlp::fast_tokenizer::tokenizers_impl::ClipFastTokenizer& + tokenizer) + : vae_encoder_(std::move(vae_encoder)), + vae_decoder_(std::move(vae_decoder)), + text_encoder_(std::move(text_encoder)), unet_(std::move(unet)), + scheduler_(std::move(scheduler)), tokenizer_(tokenizer) {} + +void StableDiffusionInpaintPipeline::Predict( + const std::vector& prompts, const cv::Mat& image, + const cv::Mat& mask_image, std::vector* output_images, int height, + int width, int num_inference_steps, float guidance_scale, + const std::vector& negative_prompt, int num_images_per_prompt, + float eta, uint32_t max_length, const FDTensor* latents, bool output_cv_mat, + callback_ptr callback, int callback_steps) { + int batch_size = prompts.size(); + FDASSERT(batch_size >= 1, "prompts should not be empty"); + FDASSERT( + height % 8 == 0 && width % 8 == 0, + "`height` and `width` have to be divisible by 8 but are {%d} and {%d}.", + height, width); + FDASSERT(callback_steps > 0, + "`callback_steps` has to be a positive integer but is {%d}", + callback_steps); + + // Setting tokenizer attr + if (max_length == 0) { + tokenizer_.EnablePadMethod(fast_tokenizer::core::RIGHT, + tokenizer_.GetPadTokenId(), 0, + tokenizer_.GetPadToken(), nullptr, nullptr); + tokenizer_.DisableTruncMethod(); + } else { + tokenizer_.EnablePadMethod(fast_tokenizer::core::RIGHT, + tokenizer_.GetPadTokenId(), 0, + tokenizer_.GetPadToken(), &max_length, nullptr); + tokenizer_.EnableTruncMethod(max_length, 0, fast_tokenizer::core::RIGHT, + fast_tokenizer::core::LONGEST_FIRST); + } + std::vector encodings; + tokenizer_.EncodeBatchStrings(prompts, &encodings); + + std::vector input_ids; + for (auto& encoding : encodings) { + auto curr_ids = encoding.GetIds(); + input_ids.insert(input_ids.end(), curr_ids.begin(), curr_ids.end()); + } + encodings.clear(); + // Get text encoder output + FDTensor text_intput_ids; + std::vector inputs(1); + inputs[0].SetExternalData({batch_size, max_length}, FDDataType::INT64, + input_ids.data()); + + TensorInfo text_info = text_encoder_->GetInputInfo(0); + inputs[0].name = text_info.name; + int output_size = text_encoder_->GetOutputInfos().size(); + std::vector outputs(output_size); + text_encoder_->Infer(inputs, &outputs); + + FDTensor text_embeddings; + function::Tile(outputs[0], {num_images_per_prompt, 1, 1}, &text_embeddings); + + // here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + // of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + // corresponds to doing no classifier free guidance. + bool do_classifier_free_guidance = guidance_scale > 1.0; + if (do_classifier_free_guidance) { + std::vector uncond_tokens; + if (negative_prompt.size() == 0) { + uncond_tokens = {""}; + } else if (negative_prompt.size() != batch_size) { + FDASSERT(false, + "negative_prompt has batch size %d, but prompt has batch size " + "%d. Please make sure that passed `negative_prompt` matches the " + "batch size of `prompt`.", + prompts.size(), negative_prompt.size()); + } else { + uncond_tokens = negative_prompt; + } + tokenizer_.EncodeBatchStrings(uncond_tokens, &encodings); + input_ids.clear(); + for (auto& encoding : encodings) { + auto curr_ids = encoding.GetIds(); + input_ids.insert(input_ids.end(), curr_ids.begin(), curr_ids.end()); + } + inputs[0].SetExternalData({batch_size, max_length}, FDDataType::INT64, + input_ids.data()); + text_encoder_->Infer(inputs, &outputs); + FDTensor uncond_embeddings; + function::Tile(outputs[0], {num_images_per_prompt, 1, 1}, + &uncond_embeddings); + function::Concat({uncond_embeddings, text_embeddings}, &text_embeddings); + } + std::vector latents_shape = {batch_size * num_images_per_prompt, + NUM_LATENT_CHANNELS, height / 8, + width / 8}; + auto latents_dtype = text_embeddings.Dtype(); + FDTensor actual_latents; + if (latents == nullptr) { + function::GaussianRandom(latents_shape, &actual_latents, latents_dtype); + } else { + bool result = std::equal(latents_shape.begin(), latents_shape.end(), + latents->Shape().begin()); + FDASSERT(result, "Unexpected latents shape, got %s, expected %s", + Str(latents_shape).c_str(), Str(latents->Shape()).c_str()); + actual_latents = *latents; + } + FDTensor mask_t, mask_image_t; + PrepareMaskAndMaskedImage(image, mask_image, {height / 8, width / 8}, &mask_t, + &mask_image_t); + function::Cast(mask_t, &mask_t, actual_latents.Dtype()); + function::Cast(mask_image_t, &mask_image_t, actual_latents.Dtype()); + + // Get vae encoder output + TensorInfo vae_encoder_info = vae_encoder_->GetInputInfo(0); + mask_image_t.name = vae_encoder_info.name; + outputs.resize(vae_encoder_->GetOutputInfos().size()); + inputs = {mask_image_t}; + vae_encoder_->Infer(inputs, &outputs); + FDTensor masked_image_latents = 0.18215f * outputs[0]; + + std::vector mask_shape(mask_t.Shape().size(), 1); + mask_shape[0] = batch_size * num_images_per_prompt; + function::Tile(mask_t, mask_shape, &mask_t); + + std::vector mask_image_shape(masked_image_latents.Shape().size(), 1); + mask_image_shape[0] = batch_size * num_images_per_prompt; + function::Tile(masked_image_latents, mask_image_shape, &masked_image_latents); + + if (do_classifier_free_guidance) { + function::Concat({mask_t, mask_t}, &mask_t); + function::Concat({masked_image_latents, masked_image_latents}, + &masked_image_latents); + } + int num_channels_mask = mask_t.Shape()[1]; + int num_channels_masked_image = masked_image_latents.Shape()[1]; + FDASSERT( + NUM_LATENT_CHANNELS + num_channels_mask + num_channels_masked_image == + NUM_UNET_INPUT_CHANNELS, + "Incorrect configuration settings! The config of `pipeline.unet` expects" + " %d but received `num_channels_latents`: %d + `num_channels_mask`: %d " + "+ `num_channels_masked_image`: %d" + " = %d. Please verify the config of `pipeline.unet` or your `mask_image` " + "or `image` input.", + NUM_UNET_INPUT_CHANNELS, NUM_LATENT_CHANNELS, num_channels_mask, + num_channels_masked_image, + NUM_LATENT_CHANNELS + num_channels_mask + num_channels_masked_image); + + // set timesteps + scheduler_->SetTimesteps(num_inference_steps); + + // scale the initial noise by the standard deviation required by the scheduler + actual_latents = actual_latents * scheduler_->InitNoiseSigma(); + + auto timestep = scheduler_->GetTimesteps(); + int64_t* timestep_data = reinterpret_cast(timestep.Data()); + outputs.resize(unet_->GetOutputInfos().size()); + inputs.resize(unet_->GetInputInfos().size()); + inputs[2] = std::move(text_embeddings); + auto unet_infos = unet_->GetInputInfos(); + for (int i = 0; i < timestep.Numel(); ++i) { + FDTensor t; + function::Slice(timestep, {0}, {i}, &t); + inputs[1] = t; + // expand the latents if we are doing classifier free guidance + FDTensor latent_model_input; + if (do_classifier_free_guidance) { + function::Concat({actual_latents, actual_latents}, &latent_model_input); + } else { + latent_model_input = actual_latents; + } + // concat latents, mask, masked_image_latnets in the channel dimension + function::Concat({latent_model_input, mask_t, masked_image_latents}, + &latent_model_input, 1); + scheduler_->ScaleModelInput(latent_model_input, &latent_model_input, {t}); + inputs[0] = std::move(latent_model_input); + // predict the noise residual + for (int i = 0; i < unet_infos.size(); ++i) { + inputs[i].name = unet_infos[i].name; + } + unet_->Infer(inputs, &outputs); + FDTensor noise_pred = std::move(outputs[0]); + // perform guidance + if (do_classifier_free_guidance) { + std::vector noise_preds; + int dim0 = noise_pred.Shape()[0]; + function::Split(noise_pred, {dim0 - dim0 / 2, dim0 / 2}, &noise_preds); + noise_pred = + noise_preds[0] + guidance_scale * (noise_preds[1] - noise_preds[0]); + } + + // compute the previous noisy sample x_t -> x_t-1 + int64_t time = reinterpret_cast(t.Data())[0]; + scheduler_->Step(noise_pred, time, actual_latents, &actual_latents); + + // call the callback, if provided + if (callback != nullptr && i % callback_steps == 0) { + callback(i, time, &actual_latents); + } + } + actual_latents = (1.0f / 0.18215f) * actual_latents; + + // Get vae decoder output + int actual_latents_bs = actual_latents.Shape()[0]; + TensorInfo vae_decoder_info = vae_decoder_->GetInputInfo(0); + inputs.resize(1); + outputs.resize(vae_decoder_->GetOutputInfos().size()); + std::vector decoder_reuslt; + for (int i = 0; i < actual_latents_bs; ++i) { + function::Slice(actual_latents, {0}, {i}, {i + 1}, &inputs[0]); + inputs[0].name = vae_decoder_info.name; + vae_decoder_->Infer(inputs, &outputs); + decoder_reuslt.emplace_back(std::move(outputs[0])); + } + FDTensor output_image; + function::Concat(decoder_reuslt, &output_image); + + function::Clip(output_image / 2.0f + 0.5f, 0, 1, &output_image); + function::Transpose(output_image, &output_image, {0, 2, 3, 1}); + + if (output_cv_mat) { + output_image = output_image * 255.0f; + function::Round(output_image, &output_image); + function::Cast(output_image, &output_image, FDDataType::UINT8); + } + int output_batch_size = output_image.Shape()[0]; + output_images->resize(output_batch_size); + for (int i = 0; i < output_batch_size; ++i) { + function::Slice(output_image, {0}, {i}, &(*output_images)[i]); + vision::FDMat mask_fdmat_t = vision::FDMat::Create((*output_images)[i]); + vision::RGB2BGR::Run(&mask_fdmat_t, vision::ProcLib::OPENCV); + mask_fdmat_t.CopyToTensor(&(*output_images)[i]); + FDTensor sum; + function::Sum((*output_images)[i], &sum, {}, false, true); + FDINFO << "sum = " << ((float*)sum.Data())[0] << std::endl; + } +} +} // namespace fastdeploy diff --git a/examples/multimodal/stable_diffusion/cpp/pipeline_stable_diffusion_inpaint.h b/examples/multimodal/stable_diffusion/cpp/pipeline_stable_diffusion_inpaint.h new file mode 100644 index 0000000000..063c370f35 --- /dev/null +++ b/examples/multimodal/stable_diffusion/cpp/pipeline_stable_diffusion_inpaint.h @@ -0,0 +1,61 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "./scheduler.h" +#include "fast_tokenizer/tokenizers/clip_fast_tokenizer.h" +#include "fastdeploy/core/fd_tensor.h" +#include "fastdeploy/runtime.h" +#include "opencv2/core/core.hpp" +#include +#include +#include + +namespace fastdeploy { + +class StableDiffusionInpaintPipeline { + public: + typedef void (*callback_ptr)(int, int, FDTensor*); + + StableDiffusionInpaintPipeline( + std::unique_ptr vae_encoder, + std::unique_ptr vae_decoder, + std::unique_ptr text_encoder, std::unique_ptr unet, + std::unique_ptr scheduler, + const paddlenlp::fast_tokenizer::tokenizers_impl::ClipFastTokenizer& + tokenizer); + void Predict(const std::vector& prompts, const cv::Mat& image, + const cv::Mat& mask_image, std::vector* output_images, + int height = 512, int width = 512, int num_inference_steps = 50, + float guidance_scale = 7.5, + const std::vector& negative_prompt = {}, + int num_images_per_prompt = 1, float eta = 0.0, + uint32_t max_length = 77, const FDTensor* latents = nullptr, + bool output_cv_mat = true, callback_ptr callback = nullptr, + int callback_steps = 1); + + private: + void PrepareMaskAndMaskedImage(const cv::Mat& image, const cv::Mat& mask_mat, + const std::vector& shape, + FDTensor* mask, FDTensor* mask_image); + std::unique_ptr vae_encoder_; + std::unique_ptr vae_decoder_; + std::unique_ptr text_encoder_; + std::unique_ptr unet_; + std::unique_ptr scheduler_; + paddlenlp::fast_tokenizer::tokenizers_impl::ClipFastTokenizer tokenizer_; +}; + +} // namespace fastdeploy diff --git a/examples/multimodal/stable_diffusion/cpp/scheduler.h b/examples/multimodal/stable_diffusion/cpp/scheduler.h index 6a5cd2fed9..e4dc452def 100644 --- a/examples/multimodal/stable_diffusion/cpp/scheduler.h +++ b/examples/multimodal/stable_diffusion/cpp/scheduler.h @@ -19,13 +19,16 @@ namespace fastdeploy { class Scheduler { + public: virtual void SetTimesteps(int num_inference_steps) = 0; + virtual FDTensor GetTimesteps() = 0; virtual void Step(const FDTensor& model_output, int timestep, const FDTensor& sample, FDTensor* prev_sample) = 0; virtual void ScaleModelInput(const FDTensor& sample, FDTensor* out, const std::vector& timesteps = {}) = 0; virtual void AddNoise(const FDTensor& original_samples, const FDTensor& noise, const FDTensor& timesteps, FDTensor* out) = 0; + virtual float InitNoiseSigma() = 0; }; } // namespace fastdeploy diff --git a/fastdeploy/function/concat.cc b/fastdeploy/function/concat.cc index 295c3c25a4..4f07743942 100644 --- a/fastdeploy/function/concat.cc +++ b/fastdeploy/function/concat.cc @@ -88,11 +88,13 @@ template void ConcatKernel(const std::vector& input, FDTensor* output, int axis) { auto output_shape = ComputeAndCheckConcatOutputShape(input, axis); - output->Resize(output_shape, TypeToDataType::dtype, output->name, - input[0].device); + FDTensor output_tmp; + output_tmp.Resize(output_shape, TypeToDataType::dtype, output->name, + input[0].device); ConcatFunctor functor; - functor(input, axis, output); + functor(input, axis, &output_tmp); + *output = std::move(output_tmp); } void Concat(const std::vector& x, FDTensor* out, int axis) { diff --git a/fastdeploy/function/functions.h b/fastdeploy/function/functions.h index d2ffe6a0c1..a43407839f 100644 --- a/fastdeploy/function/functions.h +++ b/fastdeploy/function/functions.h @@ -21,6 +21,7 @@ #include "fastdeploy/function/elementwise.h" #include "fastdeploy/function/full.h" #include "fastdeploy/function/gather_scatter_along_axis.h" +#include "fastdeploy/function/gaussian_random.h" #include "fastdeploy/function/isfinite.h" #include "fastdeploy/function/linspace.h" #include "fastdeploy/function/math.h" diff --git a/fastdeploy/function/gaussian_random.cc b/fastdeploy/function/gaussian_random.cc new file mode 100644 index 0000000000..18657c4f2a --- /dev/null +++ b/fastdeploy/function/gaussian_random.cc @@ -0,0 +1,46 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/function/gaussian_random.h" +#include +#include +#include + +namespace fastdeploy { +namespace function { + +template +void GaussianRandomKernel(const std::vector& shape, float mean, + float std, int seed, FDTensor* out) { + std::normal_distribution dist(mean, std); + + out->Allocate(shape, TypeToDataType::dtype); + int64_t size = out->Numel(); + T* data = reinterpret_cast(out->Data()); + std::mt19937_64 engine; + engine.seed(seed); + for (int64_t i = 0; i < size; ++i) { + data[i] = dist(engine); + } +} + +void GaussianRandom(const std::vector& shape, FDTensor* out, + FDDataType dtype, float mean, float std, int seed) { + FD_VISIT_FLOAT_TYPES(dtype, "GaussianRandomKernel", [&]() { + GaussianRandomKernel(shape, mean, std, seed, out); + }); +} + +} // namespace function +} // namespace fastdeploy \ No newline at end of file diff --git a/fastdeploy/function/gaussian_random.h b/fastdeploy/function/gaussian_random.h new file mode 100644 index 0000000000..85a4ff8a63 --- /dev/null +++ b/fastdeploy/function/gaussian_random.h @@ -0,0 +1,36 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "fastdeploy/core/fd_tensor.h" + +namespace fastdeploy { +namespace function { + +/** Output is obtained by gathering entries of axis of x indexed by index and + * concatenate them together. + @param shape The output tensor shape. + @param out the output tensor. + @param mean mean value of gaussian random + @param std standard value of gaussian random + @param seed The seed of random generator. + @param dtype The data type of the output Tensor. +*/ +void GaussianRandom(const std::vector& shape, FDTensor* out, + FDDataType dtype = FDDataType::FP32, float mean = 0.0f, + float std = 1.0f, int seed = 0); + +} // namespace function +} // namespace fastdeploy diff --git a/fastdeploy/function/tile.cc b/fastdeploy/function/tile.cc index 6437b4ec60..c6e3095c6f 100644 --- a/fastdeploy/function/tile.cc +++ b/fastdeploy/function/tile.cc @@ -49,6 +49,7 @@ void TileFunctor(const FDTensor& x, return; } + FDTensor out_tmp; Eigen::DSizes bcast_dims; for (size_t i = 0; i < repeat_times.size(); ++i) { bcast_dims[i] = repeat_times[i]; @@ -59,12 +60,14 @@ void TileFunctor(const FDTensor& x, out_shape[i] *= repeat_times[i]; } - out->Allocate(out_shape, x.Dtype()); + out_tmp.Allocate(out_shape, x.Dtype()); auto eigen_x = EigenTensor::From(x, x_shape); - auto eigen_out = EigenTensor::From(*out, out_shape); + auto eigen_out = EigenTensor::From(out_tmp, out_shape); const auto& dev = *EigenDeviceWrapper::GetInstance()->GetDevice(); eigen_out.device(dev) = eigen_x.broadcast(bcast_dims); + + *out = std::move(out_tmp); } template diff --git a/fastdeploy/utils/utils.h b/fastdeploy/utils/utils.h index 9b2a0fe201..0dff28f8be 100644 --- a/fastdeploy/utils/utils.h +++ b/fastdeploy/utils/utils.h @@ -66,7 +66,8 @@ class FASTDEPLOY_DECL FDLogger { if (!verbose_ && line_ != "") { std::cout << line_ << std::endl; #ifdef __ANDROID__ - __android_log_print(ANDROID_LOG_INFO, prefix_.c_str(), "%s", line_.c_str()); + __android_log_print(ANDROID_LOG_INFO, prefix_.c_str(), "%s", + line_.c_str()); #endif } } @@ -122,6 +123,8 @@ FASTDEPLOY_DECL bool ReadBinaryFromFile(const std::string& file, [&] { \ const auto& __dtype__ = TYPE; \ switch (__dtype__) { \ + FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::UINT8, uint8_t, \ + __VA_ARGS__) \ FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::BOOL, bool, \ __VA_ARGS__) \ FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT32, int32_t, \ @@ -141,26 +144,26 @@ FASTDEPLOY_DECL bool ReadBinaryFromFile(const std::string& file, } \ }() -#define FD_VISIT_INT_FLOAT_TYPES(TYPE, NAME, ...) \ - [&] { \ - const auto& __dtype__ = TYPE; \ - switch (__dtype__) { \ - FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT32, int32_t, \ - __VA_ARGS__) \ - FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT64, int64_t, \ - __VA_ARGS__) \ - FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP32, float, \ - __VA_ARGS__) \ - FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP64, double, \ - __VA_ARGS__) \ - FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::UINT8, uint8_t, \ - __VA_ARGS__) \ - default: \ - FDASSERT(false, \ - "Invalid enum data type. Expect to accept data type INT32, " \ - "INT64, FP32, FP64, UINT8 but receive type %s.", \ - Str(__dtype__).c_str()); \ - } \ +#define FD_VISIT_INT_FLOAT_TYPES(TYPE, NAME, ...) \ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT32, int32_t, \ + __VA_ARGS__) \ + FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT64, int64_t, \ + __VA_ARGS__) \ + FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP32, float, \ + __VA_ARGS__) \ + FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP64, double, \ + __VA_ARGS__) \ + FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::UINT8, uint8_t, \ + __VA_ARGS__) \ + default: \ + FDASSERT(false, \ + "Invalid enum data type. Expect to accept data type INT32, " \ + "INT64, FP32, FP64, UINT8 but receive type %s.", \ + Str(__dtype__).c_str()); \ + } \ }() #define FD_VISIT_FLOAT_TYPES(TYPE, NAME, ...) \ @@ -179,22 +182,22 @@ FASTDEPLOY_DECL bool ReadBinaryFromFile(const std::string& file, } \ }() -#define FD_VISIT_INT_TYPES(TYPE, NAME, ...) \ - [&] { \ - const auto& __dtype__ = TYPE; \ - switch (__dtype__) { \ - FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT32, int32_t, \ - __VA_ARGS__) \ - FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT64, int64_t, \ - __VA_ARGS__) \ - FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::UINT8, uint8_t, \ - __VA_ARGS__) \ - default: \ - FDASSERT(false, \ - "Invalid enum data type. Expect to accept data type INT32, " \ - "INT64, UINT8 but receive type %s.", \ - Str(__dtype__).c_str()); \ - } \ +#define FD_VISIT_INT_TYPES(TYPE, NAME, ...) \ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT32, int32_t, \ + __VA_ARGS__) \ + FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT64, int64_t, \ + __VA_ARGS__) \ + FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::UINT8, uint8_t, \ + __VA_ARGS__) \ + default: \ + FDASSERT(false, \ + "Invalid enum data type. Expect to accept data type INT32, " \ + "INT64, UINT8 but receive type %s.", \ + Str(__dtype__).c_str()); \ + } \ }() FASTDEPLOY_DECL std::vector