diff --git a/README.md b/README.md
index 4892c86f10..680bc9bc65 100644
--- a/README.md
+++ b/README.md
@@ -194,12 +194,7 @@ import openvino_genai
 device = 'CPU' # GPU can be used as well
 pipe = openvino_genai.Text2ImagePipeline("./dreamlike_anime_1_0_ov/INT8", device)
-image_tensor = pipe.generate(
-    "cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting",
-    width=512,
-    height=512,
-    num_inference_steps=20
-)
+image_tensor = pipe.generate("cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting")
 
 image = Image.fromarray(image_tensor.data[0])
 image.save("image.bmp")
@@ -218,10 +213,7 @@ int main(int argc, char* argv[]) {
     const std::string device = "CPU"; // GPU can be used as well
     ov::genai::Text2ImagePipeline pipe(models_path, device);
-    ov::Tensor image = pipe.generate(prompt,
-        ov::genai::width(512),
-        ov::genai::height(512),
-        ov::genai::num_inference_steps(20));
+    ov::Tensor image = pipe.generate(prompt);
 
     imwrite("image.bmp", image, true);
 }
diff --git a/samples/cpp/image_generation/README.md b/samples/cpp/image_generation/README.md
index 8a5cc5aa19..f8dc21cc39 100644
--- a/samples/cpp/image_generation/README.md
+++ b/samples/cpp/image_generation/README.md
@@ -20,6 +20,10 @@ Users can change the sample code and play with the following generation paramete
 - Apply multiple different LoRA adapters and mix them with different blending coefficients
 - (Image to image and inpainting) Play with `strength` parameter to control how initial image is noised and reduce number of inference steps
+
+> [!NOTE]
+> An image generated with HuggingFace / Optimum Intel is not the same as the one generated by this C++ sample: C++ random generation with MT19937 differs from `numpy.random.randn()` and `diffusers.utils.randn_tensor` (which uses `torch.Generator` inside). So it is expected that the Diffusers and C++ versions produce different images, because the latent images are initialized differently.
+
 ## Download and convert the models and tokenizers
 
 The `--upgrade-strategy eager` option is needed to ensure `optimum-intel` is upgraded to the latest version.
@@ -88,13 +92,6 @@ With adapter | Without adapter
 :---:|:---:
 ![](./lora.bmp) | ![](./baseline.bmp)
-
-## Note
-
-- Image generated with HuggingFace / Optimum Intel is not the same generated by this C++ sample:
-
-C++ random generation with MT19937 results differ from `numpy.random.randn()` and `diffusers.utils.randn_tensor`. So, it's expected that image generated by Python and C++ versions provide different images, because latent images are initialize differently. Users can implement their own random generator derived from `ov::genai::Generator` and pass it to `Text2ImagePipeline::generate` method.
-
 ## Run text to image with multiple devices
 
 The `heterogeneous_stable_diffusion` sample demonstrates how a Text2ImagePipeline object can be created from individual subcomponents - scheduler, text encoder, unet, & vae decoder. This approach gives fine-grained control over the devices used to execute each stage of the stable diffusion pipeline.
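The note above only rules out cross-framework parity; within OpenVINO GenAI itself, results are still reproducible via the new `rng_seed` parameter introduced by this patch. A minimal Python sketch of that guarantee (the model path is taken from the README snippet above; the equality check is illustrative, not part of the sample):

```python
# Minimal sketch: the same rng_seed gives reproducible images within
# OpenVINO GenAI, even though they will not match Diffusers output.
import openvino_genai
from PIL import Image

pipe = openvino_genai.Text2ImagePipeline("./dreamlike_anime_1_0_ov/INT8", "CPU")
prompt = "cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting"

first = pipe.generate(prompt, rng_seed=42)
second = pipe.generate(prompt, rng_seed=42)
assert (first.data == second.data).all()  # same seed -> same latents -> same image

Image.fromarray(first.data[0]).save("image.bmp")
```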
diff --git a/samples/cpp/image_generation/lora_text2image.cpp b/samples/cpp/image_generation/lora_text2image.cpp
index 3fe4b74ff6..c1e6461db9 100644
--- a/samples/cpp/image_generation/lora_text2image.cpp
+++ b/samples/cpp/image_generation/lora_text2image.cpp
@@ -24,19 +24,19 @@ int32_t main(int32_t argc, char* argv[]) try {
 
     std::cout << "Generating image with LoRA adapters applied, resulting image will be in lora.bmp\n";
     ov::Tensor image = pipe.generate(prompt,
-        ov::genai::generator(std::make_shared<ov::genai::CppStdGenerator>(42)),
         ov::genai::width(512),
         ov::genai::height(896),
-        ov::genai::num_inference_steps(20));
+        ov::genai::num_inference_steps(20),
+        ov::genai::rng_seed(42));
     imwrite("lora.bmp", image, true);
 
     std::cout << "Generating image without LoRA adapters applied, resulting image will be in baseline.bmp\n";
     image = pipe.generate(prompt,
         ov::genai::adapters(), // passing adapters in generate overrides adapters set in the constructor; adapters() means no adapters
-        ov::genai::generator(std::make_shared<ov::genai::CppStdGenerator>(42)),
         ov::genai::width(512),
         ov::genai::height(896),
-        ov::genai::num_inference_steps(20));
+        ov::genai::num_inference_steps(20),
+        ov::genai::rng_seed(42));
     imwrite("baseline.bmp", image, true);
 
     return EXIT_SUCCESS;
diff --git a/samples/python/image_generation/README.md b/samples/python/image_generation/README.md
index 0ddf57d882..3e53f40fc4 100644
--- a/samples/python/image_generation/README.md
+++ b/samples/python/image_generation/README.md
@@ -20,6 +20,10 @@ Users can change the sample code and play with the following generation paramete
 - Apply multiple different LoRA adapters and mix them with different blending coefficients
 - (Image to image and inpainting) Play with `strength` parameter to control how initial image is noised and reduce number of inference steps
 
+> [!NOTE]
+> OpenVINO GenAI is written in C++ and uses the `CppStdGenerator` random generator in its image generation pipelines, while the Diffusers library uses `torch.Generator` under the hood.
+> To get the same results as HuggingFace, pass a manually created `torch.Generator(device='cpu').manual_seed(seed)` to Diffusers generation pipelines and `openvino_genai.TorchGenerator(seed)` to OpenVINO GenAI pipelines as the value of the `generator` kwarg.
+
 ## Download and convert the models and tokenizers
 
 The `--upgrade-strategy eager` option is needed to ensure `optimum-intel` is upgraded to the latest version.
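To make the note above concrete, here is a hedged sketch of the parity recipe it describes. The Diffusers model id and the converted model path are assumptions, and bit-exact parity additionally presumes matching schedulers, step counts, and image sizes on both sides:

```python
# Sketch: seed both frameworks from torch so their initial latents agree.
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
import openvino_genai

seed = 42
prompt = "cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting"

# Diffusers side: a manually created CPU torch.Generator
hf_pipe = StableDiffusionPipeline.from_pretrained("dreamlike-art/dreamlike-anime-1.0")
hf_image = hf_pipe(prompt, generator=torch.Generator(device='cpu').manual_seed(seed)).images[0]

# OpenVINO GenAI side: TorchGenerator wraps the same torch RNG
ov_pipe = openvino_genai.Text2ImagePipeline("./dreamlike_anime_1_0_ov/FP16", "CPU")
ov_tensor = ov_pipe.generate(prompt, generator=openvino_genai.TorchGenerator(seed))
Image.fromarray(ov_tensor.data[0]).save("ov_image.bmp")
```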
@@ -41,7 +45,7 @@ Install [deployment-requirements.txt](../../deployment-requirements.txt) via `pi
 
 Prompt: `cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting`
 
-   ![](./text2image.bmp)
+   ![](./../../cpp/image_generation/512x512.bmp)
 
 ### Run with callback
 
@@ -85,7 +89,7 @@ Check the difference:
 
 With adapter | Without adapter
 :---:|:---:
-![](./lora.bmp) | ![](./baseline.bmp)
+![](./../../cpp/image_generation/lora.bmp) | ![](./../../cpp/image_generation/baseline.bmp)
 
 ## Run text to image with multiple devices
 
diff --git a/samples/python/image_generation/baseline.bmp b/samples/python/image_generation/baseline.bmp
deleted file mode 100644
index 1501f5960e..0000000000
--- a/samples/python/image_generation/baseline.bmp
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:ea0b60b64c4448448140a3cfb5e8609248ad35abd484ace1467d832e6966c941
-size 1376310
diff --git a/samples/python/image_generation/heterogeneous_stable_diffusion.py b/samples/python/image_generation/heterogeneous_stable_diffusion.py
index b1a2f9d5de..18f150816e 100644
--- a/samples/python/image_generation/heterogeneous_stable_diffusion.py
+++ b/samples/python/image_generation/heterogeneous_stable_diffusion.py
@@ -101,8 +101,7 @@ def main():
         height=height,
         guidance_scale=guidance_scale,
         num_inference_steps=number_of_inference_steps_per_image,
-        num_images_per_prompt=1,
-        generator=openvino_genai.CppStdGenerator(42)
+        num_images_per_prompt=1
     )
 
     image = Image.fromarray(image_tensor.data[0])
diff --git a/samples/python/image_generation/lora.bmp b/samples/python/image_generation/lora.bmp
deleted file mode 100644
index a0aaedb930..0000000000
--- a/samples/python/image_generation/lora.bmp
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:804bb8d49f1702422abf57c300af75fe75acbef60a9cf8ad5cfc9262b7532c95
-size 1376310
diff --git a/samples/python/image_generation/lora_text2image.py b/samples/python/image_generation/lora_text2image.py
index 95e31ca0ea..6a46099dc2 100644
--- a/samples/python/image_generation/lora_text2image.py
+++ b/samples/python/image_generation/lora_text2image.py
@@ -6,20 +6,6 @@
 import openvino as ov
 import openvino_genai
 
-import numpy as np
-import sys
-
-
-class Generator(openvino_genai.Generator):
-    def __init__(self, seed, mu=0.0, sigma=1.0):
-        openvino_genai.Generator.__init__(self)
-        np.random.seed(seed)
-        self.mu = mu
-        self.sigma = sigma
-
-    def next(self):
-        return np.random.normal(self.mu, self.sigma)
-
 
 def image_write(path: str, image_tensor: ov.Tensor):
     from PIL import Image
@@ -46,23 +32,23 @@ def main():
     # LoRA adapters passed to the constructor will be activated by default in next generates
     pipe = openvino_genai.Text2ImagePipeline(args.models_path, device, adapters=adapter_config)
 
+    print("Generating image with LoRA adapters applied, resulting image will be in lora.bmp")
     image = pipe.generate(prompt,
-                          generator=Generator(42),
                           width=512,
                           height=896,
-                          num_inference_steps=20)
+                          num_inference_steps=20,
+                          rng_seed=42)
 
     image_write("lora.bmp", image)
 
     print("Generating image without LoRA adapters applied, resulting image will be in baseline.bmp")
     image = pipe.generate(prompt,
                           # passing adapters in generate overrides adapters set in the constructor; openvino_genai.AdapterConfig() means no adapters
                           adapters=openvino_genai.AdapterConfig(),
-                          generator=Generator(42),
                           width=512,
                           height=896,
-                          num_inference_steps=20
-                          )
+                          num_inference_steps=20,
+                          rng_seed=42)
 
     image_write("baseline.bmp", image)
diff --git a/samples/python/image_generation/text2image.bmp b/samples/python/image_generation/text2image.bmp
deleted file mode 100644
index 54974556a4..0000000000
--- a/samples/python/image_generation/text2image.bmp
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:7c150896ec84f64d4f0cacd67f8f277e08d3ebb1c9a756d43fc80944db7a2ed4
-size 786486
diff --git a/samples/python/image_generation/text2image.py b/samples/python/image_generation/text2image.py
index 95d8c68e82..cba1eefd1d 100644
--- a/samples/python/image_generation/text2image.py
+++ b/samples/python/image_generation/text2image.py
@@ -6,17 +6,6 @@
 import openvino_genai
 from PIL import Image
 
-import numpy as np
-
-class Generator(openvino_genai.Generator):
-    def __init__(self, seed, mu=0.0, sigma=1.0):
-        openvino_genai.Generator.__init__(self)
-        np.random.seed(seed)
-        self.mu = mu
-        self.sigma = sigma
-
-    def next(self):
-        return np.random.normal(self.mu, self.sigma)
 
 
 def main():
@@ -33,9 +22,7 @@ def main():
         width=512,
         height=512,
         num_inference_steps=20,
-        num_images_per_prompt=1,
-        generator=Generator(42) # openvino_genai.CppStdGenerator can be used to have same images as C++ sample
-    )
+        num_images_per_prompt=1)
 
     image = Image.fromarray(image_tensor.data[0])
     image.save("image.bmp")
diff --git a/src/cpp/include/openvino/genai/generation_config.hpp b/src/cpp/include/openvino/genai/generation_config.hpp
index 2402f57fba..9d79240aa8 100644
--- a/src/cpp/include/openvino/genai/generation_config.hpp
+++ b/src/cpp/include/openvino/genai/generation_config.hpp
@@ -67,9 +67,9 @@ enum class StopCriteria { EARLY, HEURISTIC, NEVER };
 * @param top_k the number of highest probability vocabulary tokens to keep for top-k-filtering.
 * @param do_sample whether or not to use multinomial random sampling that add up to `top_p` or higher are kept.
 * @param repetition_penalty the parameter for repetition penalty. 1.0 means no penalty.
- * @param presence_penalty reduces absolute log prob if the token was generated at least once. Ignored for non continuous batching.
- * @param frequency_penalty reduces absolute log prob as many times as the token was generated. Ignored for non continuous batching.
- * @param rng_seed initializes random generator. Ignored for non continuous batching.
+ * @param presence_penalty reduces absolute log prob if the token was generated at least once.
+ * @param frequency_penalty reduces absolute log prob as many times as the token was generated.
+ * @param rng_seed initializes random generator.
 *
 * Speculative decoding parameters:
 * @param assistant_confidence_threshold the lower token probability of candidate to be validated by main model in case of static strategy candidates number update.
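The doc hunk above drops the "Ignored for non continuous batching" caveat, i.e. `rng_seed` is now meant to apply to the regular pipeline path as well. A hedged sketch of what that enables (the model path is an assumption; `rng_seed` is read from `generate()` kwargs via `update_generation_config`, as shown later in this diff):

```python
# Sketch: seeded sampling should now be reproducible outside
# continuous batching as well.
import openvino_genai

pipe = openvino_genai.LLMPipeline("./TinyLlama_ov", "CPU")  # path is an assumption

out1 = pipe.generate("Why is the sky blue?", do_sample=True, max_new_tokens=32, rng_seed=42)
out2 = pipe.generate("Why is the sky blue?", do_sample=True, max_new_tokens=32, rng_seed=42)
# identical seeds are expected to yield identical samples
assert out1 == out2
```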
@@ -174,7 +174,7 @@ static constexpr ov::Property<float> repetition_penalty{"repetition_penalty"};
 static constexpr ov::Property<int64_t> eos_token_id{"eos_token_id"};
 static constexpr ov::Property<float> presence_penalty{"presence_penalty"};
 static constexpr ov::Property<float> frequency_penalty{"frequency_penalty"};
-static constexpr ov::Property<size_t> rng_seed{"rng_seed"};
+extern OPENVINO_GENAI_EXPORTS ov::Property<size_t> rng_seed;
 static constexpr ov::Property<float> assistant_confidence_threshold{"assistant_confidence_threshold"};
 static constexpr ov::Property<size_t> num_assistant_tokens{"num_assistant_tokens"};
diff --git a/src/cpp/include/openvino/genai/image_generation/generation_config.hpp b/src/cpp/include/openvino/genai/image_generation/generation_config.hpp
index 50e576466d..bd7073520a 100644
--- a/src/cpp/include/openvino/genai/image_generation/generation_config.hpp
+++ b/src/cpp/include/openvino/genai/image_generation/generation_config.hpp
@@ -39,6 +39,12 @@ class OPENVINO_GENAI_EXPORTS Generator {
      */
     virtual ov::Tensor randn_tensor(const ov::Shape& shape);
 
+    /**
+     * Sets a new initial seed value for the random generator
+     * @param new_seed A new seed value
+     */
+    virtual void seed(size_t new_seed) = 0;
+
     /**
      * Default dtor defined to ensure working RTTI.
      */
@@ -58,9 +64,11 @@ class OPENVINO_GENAI_EXPORTS CppStdGenerator : public Generator {
 
     virtual float next() override;
 
+    virtual void seed(size_t new_seed) override;
+
 private:
-    std::mt19937 gen;
-    std::normal_distribution<float> normal;
+    std::mt19937 m_gen;
+    std::normal_distribution<float> m_normal;
 };
 
 /**
@@ -81,9 +89,17 @@ struct OPENVINO_GENAI_EXPORTS ImageGenerationConfig {
     size_t num_images_per_prompt = 1;
 
     /**
-     * Random generator to initial latents, add noise to initial images in case of image to image / inpainting pipelines
+     * Random generator used to initialize latents and to add noise to initial images in case of image to image / inpainting pipelines
+     * By default, the random generator is initialized as `CppStdGenerator(generation_config.rng_seed)`
+     * @note If `generator` is specified, it has higher priority than the `rng_seed` parameter.
+     */
+    std::shared_ptr<Generator> generator = nullptr;
+
+    /**
+     * Seed for the random generator
+     * @note If `generator` is specified, it has higher priority than the `rng_seed` parameter.
      */
-    std::shared_ptr<Generator> generator = std::make_shared<CppStdGenerator>(42);
+    size_t rng_seed = 42;
 
     float guidance_scale = 7.5f;
     int64_t height = -1;
@@ -91,7 +107,7 @@ struct OPENVINO_GENAI_EXPORTS ImageGenerationConfig {
     size_t num_inference_steps = 50;
 
     /**
-     * Max sequence lenght for T4 encoder / tokenizer used in SD3 / FLUX models
+     * Max sequence length for T5 encoder / tokenizer used in SD3 / FLUX models
      */
     int max_sequence_length = -1;
 
@@ -203,6 +219,12 @@ static constexpr ov::Property<float> strength{"strength"};
 */
static constexpr ov::Property<std::shared_ptr<Generator>> generator{"generator"};
 
+/**
+ * Seed for the random generator
+ * @note If `generator` is specified, it has higher priority than the `rng_seed` parameter.
+ */
+extern OPENVINO_GENAI_EXPORTS ov::Property<size_t> rng_seed;
+
/**
 * This parameters limits max sequence length for T5 encoder for SD3 and FLUX models.
 * T5 tokenizer output is padded with pad tokens to 'max_sequence_length' within a pipeline.
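The `@note` lines above define the precedence between the two knobs. A sketch of the intended behavior from Python (the model path is an assumption):

```python
# Sketch: 'generator' wins over 'rng_seed' when both are supplied.
import openvino_genai

pipe = openvino_genai.Text2ImagePipeline("./model_dir", "CPU")  # path is an assumption

# No explicit generator: a CppStdGenerator(42) is created under the hood.
img_a = pipe.generate("a lighthouse at dawn", rng_seed=42)

# Explicit generator: takes priority, rng_seed=7 is ignored.
img_b = pipe.generate("a lighthouse at dawn",
                      generator=openvino_genai.CppStdGenerator(42),
                      rng_seed=7)

# Both latents come from mt19937 seeded with 42, so the images are expected to match.
```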
diff --git a/src/cpp/src/generation_config.cpp b/src/cpp/src/generation_config.cpp
index 0829e8376a..189cfeded7 100644
--- a/src/cpp/src/generation_config.cpp
+++ b/src/cpp/src/generation_config.cpp
@@ -14,6 +14,8 @@
 namespace ov {
 namespace genai {
 
+ov::Property<size_t> rng_seed{"rng_seed"};
+
 GenerationConfig::GenerationConfig(const std::filesystem::path& json_path) {
     using utils::read_json_param;
 
@@ -21,7 +23,7 @@ GenerationConfig::GenerationConfig(const std::filesystem::path& json_path) {
     OPENVINO_ASSERT(f.is_open(), "Failed to open '", json_path, "' with generation config");
     nlohmann::json data = nlohmann::json::parse(f);
-    
+
     read_json_param(data, "max_new_tokens", max_new_tokens);
     read_json_param(data, "max_length", max_length);
     // note that ignore_eos is not present in HF GenerationConfig
@@ -103,6 +105,9 @@ void GenerationConfig::update_generation_config(const ov::AnyMap& config_map) {
     read_anymap_param(config_map, "echo", echo);
     read_anymap_param(config_map, "logprobs", logprobs);
     read_anymap_param(config_map, "adapters", adapters);
+
+    // TODO: add support for the 'generator' property, similar to image generation
+    read_anymap_param(config_map, "rng_seed", rng_seed);
 }
 
 size_t GenerationConfig::get_max_new_tokens(size_t prompt_length) const {
diff --git a/src/cpp/src/image_generation/flux_pipeline.hpp b/src/cpp/src/image_generation/flux_pipeline.hpp
index 716ba6b61b..e74cd441ce 100644
--- a/src/cpp/src/image_generation/flux_pipeline.hpp
+++ b/src/cpp/src/image_generation/flux_pipeline.hpp
@@ -4,7 +4,6 @@
 #pragma once
 
 #include 
-#include <ctime>
 
 #include "image_generation/diffusion_pipeline.hpp"
 #include "image_generation/numpy_utils.hpp"
diff --git a/src/cpp/src/image_generation/generation_config.cpp b/src/cpp/src/image_generation/generation_config.cpp
index 938034f628..ab098fabe5 100644
--- a/src/cpp/src/image_generation/generation_config.cpp
+++ b/src/cpp/src/image_generation/generation_config.cpp
@@ -27,11 +27,15 @@ ov::Tensor Generator::randn_tensor(const ov::Shape& shape) {
 }
 
 CppStdGenerator::CppStdGenerator(uint32_t seed)
-    : gen(seed), normal(0.0f, 1.0f) {
+    : m_gen(seed), m_normal(0.0f, 1.0f) {
 }
 
 float CppStdGenerator::next() {
-    return normal(gen);
+    return m_normal(m_gen);
+}
+
+void CppStdGenerator::seed(size_t new_seed) {
+    m_gen.seed(new_seed);
 }
 
 //
@@ -55,7 +59,6 @@ void ImageGenerationConfig::update_generation_config(const ov::AnyMap& propertie
     read_anymap_param(properties, "negative_prompt_2", negative_prompt_2);
     read_anymap_param(properties, "negative_prompt_3", negative_prompt_3);
     read_anymap_param(properties, "num_images_per_prompt", num_images_per_prompt);
-    read_anymap_param(properties, "generator", generator);
     read_anymap_param(properties, "guidance_scale", guidance_scale);
     read_anymap_param(properties, "height", height);
     read_anymap_param(properties, "width", width);
@@ -64,6 +67,25 @@ void ImageGenerationConfig::update_generation_config(const ov::AnyMap& propertie
     read_anymap_param(properties, "adapters", adapters);
     read_anymap_param(properties, "max_sequence_length", max_sequence_length);
 
+    // 'generator' has higher priority than the 'rng_seed' parameter
+    const bool have_generator_param = properties.find(ov::genai::generator.name()) != properties.end();
+    if (have_generator_param) {
+        read_anymap_param(properties, "generator", generator);
+    } else {
+        read_anymap_param(properties, "rng_seed", rng_seed);
+
+        // initialize the random generator with the given seed value
+        if (!generator) {
+            generator = std::make_shared<CppStdGenerator>(rng_seed);
+        }
+
+        const bool have_rng_seed = properties.find(ov::genai::rng_seed.name()) != properties.end();
+        if (have_rng_seed) {
+            // reset the seed, because the user has specified it manually
+            generator->seed(rng_seed);
+        }
+    }
+
     validate();
 }
diff --git a/src/cpp/src/image_generation/stable_diffusion_3_pipeline.hpp b/src/cpp/src/image_generation/stable_diffusion_3_pipeline.hpp
index 18a3e0346f..e3e720109d 100644
--- a/src/cpp/src/image_generation/stable_diffusion_3_pipeline.hpp
+++ b/src/cpp/src/image_generation/stable_diffusion_3_pipeline.hpp
@@ -4,7 +4,6 @@
 #pragma once
 
 #include 
-#include <ctime>
 
 #include "image_generation/diffusion_pipeline.hpp"
 #include "image_generation/numpy_utils.hpp"
@@ -453,11 +452,6 @@ class StableDiffusion3Pipeline : public DiffusionPipeline {
 
         check_inputs(generation_config, initial_image);
 
-        if (generation_config.generator == nullptr) {
-            uint32_t seed = time(NULL);
-            generation_config.generator = std::make_shared<CppStdGenerator>(seed);
-        }
-
         // 3. Prepare timesteps
         m_scheduler->set_timesteps(generation_config.num_inference_steps, generation_config.strength);
         std::vector<float> timesteps = m_scheduler->get_float_timesteps();
diff --git a/src/cpp/src/image_generation/stable_diffusion_pipeline.hpp b/src/cpp/src/image_generation/stable_diffusion_pipeline.hpp
index 4afbd3ac78..7549b67919 100644
--- a/src/cpp/src/image_generation/stable_diffusion_pipeline.hpp
+++ b/src/cpp/src/image_generation/stable_diffusion_pipeline.hpp
@@ -3,7 +3,6 @@
 
 #pragma once
 
-#include <ctime>
 #include 
 #include 
 
@@ -333,11 +332,6 @@ class StableDiffusionPipeline : public DiffusionPipeline {
 
         set_lora_adapters(generation_config.adapters);
 
-        if (generation_config.generator == nullptr) {
-            uint32_t seed = time(NULL);
-            generation_config.generator = std::make_shared<CppStdGenerator>(seed);
-        }
-
         m_scheduler->set_timesteps(generation_config.num_inference_steps, generation_config.strength);
         std::vector<int64_t> timesteps = m_scheduler->get_timesteps();
diff --git a/src/python/openvino_genai/__init__.py b/src/python/openvino_genai/__init__.py
index ca7c2c0b32..470ddd0cd8 100644
--- a/src/python/openvino_genai/__init__.py
+++ b/src/python/openvino_genai/__init__.py
@@ -11,7 +11,6 @@
 if hasattr(os, "add_dll_directory"):
     os.add_dll_directory(os.path.dirname(__file__))
 
-
 from .py_openvino_genai import (
     DecodedResults,
     EncodedResults,
@@ -75,6 +74,7 @@
     ImageGenerationConfig,
     Generator,
     CppStdGenerator,
+    TorchGenerator,
 )
 
 # Continuous batching
diff --git a/src/python/openvino_genai/__init__.pyi b/src/python/openvino_genai/__init__.pyi
index 4d74e17588..187e0a0a06 100644
--- a/src/python/openvino_genai/__init__.pyi
+++ b/src/python/openvino_genai/__init__.pyi
@@ -34,6 +34,7 @@ from openvino_genai.py_openvino_genai import T5EncoderModel
 from openvino_genai.py_openvino_genai import Text2ImagePipeline
 from openvino_genai.py_openvino_genai import TokenizedInputs
 from openvino_genai.py_openvino_genai import Tokenizer
+from openvino_genai.py_openvino_genai import TorchGenerator
 from openvino_genai.py_openvino_genai import UNet2DConditionModel
 from openvino_genai.py_openvino_genai import VLMPipeline
 from openvino_genai.py_openvino_genai import WhisperGenerationConfig
@@ -43,5 +44,5 @@ from openvino_genai.py_openvino_genai import WhisperRawPerfMetrics
 from openvino_genai.py_openvino_genai import draft_model
 import os as os
 from . import py_openvino_genai
-__all__ = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedResults', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationResult', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'InpaintingPipeline', 'LLMPipeline', 'PerfMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'Scheduler', 'SchedulerConfig', 'StopCriteria', 'StreamerBase', 'T5EncoderModel', 'Text2ImagePipeline', 'TokenizedInputs', 'Tokenizer', 'UNet2DConditionModel', 'VLMPipeline', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model', 'openvino', 'os', 'py_openvino_genai']
+__all__ = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedResults', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationResult', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'InpaintingPipeline', 'LLMPipeline', 'PerfMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'Scheduler', 'SchedulerConfig', 'StopCriteria', 'StreamerBase', 'T5EncoderModel', 'Text2ImagePipeline', 'TokenizedInputs', 'Tokenizer', 'TorchGenerator', 'UNet2DConditionModel', 'VLMPipeline', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model', 'openvino', 'os', 'py_openvino_genai']
 __version__: str = '2025.0.0.0'
diff --git a/src/python/openvino_genai/py_openvino_genai.pyi b/src/python/openvino_genai/py_openvino_genai.pyi
index 829d4844e8..8b8eb76b12 100644
--- a/src/python/openvino_genai/py_openvino_genai.pyi
+++ b/src/python/openvino_genai/py_openvino_genai.pyi
@@ -5,7 +5,7 @@ from __future__ import annotations
 import openvino._pyopenvino
 import os
 import typing
-__all__ = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedGenerationResult', 'EncodedResults', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationFinishReason', 'GenerationHandle', 'GenerationOutput', 'GenerationResult', 'GenerationStatus', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'InpaintingPipeline', 'LLMPipeline', 'MeanStdPair', 'PerfMetrics', 'PipelineMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'Scheduler', 'SchedulerConfig', 'StopCriteria', 'StreamerBase', 'T5EncoderModel', 'Text2ImagePipeline', 'TokenizedInputs', 'Tokenizer', 'UNet2DConditionModel', 'VLMDecodedResults', 'VLMPerfMetrics', 'VLMPipeline', 'VLMRawPerfMetrics', 'WhisperDecodedResultChunk', 'WhisperDecodedResults', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model']
+__all__ = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedGenerationResult', 'EncodedResults', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationFinishReason', 'GenerationHandle', 'GenerationOutput', 'GenerationResult', 'GenerationStatus', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'InpaintingPipeline', 'LLMPipeline', 'MeanStdPair', 'PerfMetrics', 'PipelineMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'Scheduler', 'SchedulerConfig', 'StopCriteria', 'StreamerBase', 'T5EncoderModel', 'Text2ImagePipeline', 'TokenizedInputs', 'Tokenizer', 'TorchGenerator', 'UNet2DConditionModel', 'VLMDecodedResults', 'VLMPerfMetrics', 'VLMPipeline', 'VLMRawPerfMetrics', 'WhisperDecodedResultChunk', 'WhisperDecodedResults', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model']
 class Adapter:
     """
     Immutable LoRA Adapter that carries the adaptation matrices and serves as unique adapter identifier.
     """
@@ -398,6 +398,8 @@ class CppStdGenerator(Generator):
         ...
     def randn_tensor(self, shape: openvino._pyopenvino.Shape) -> openvino._pyopenvino.Tensor:
         ...
+    def seed(self, new_seed: int) -> None:
+        ...
 class DecodedResults:
     """
@@ -804,7 +806,8 @@ class Image2ImagePipeline:
         height: int - height of resulting images,
         width: int - width of resulting images,
         num_inference_steps: int - number of inference steps,
-        generator: openvino_genai.CppStdGenerator or class inherited from openvino_genai.Generator - random generator,
+        rng_seed: int - a seed for the random number generator,
+        generator: openvino_genai.TorchGenerator, openvino_genai.CppStdGenerator or class inherited from openvino_genai.Generator - random generator,
         adapters: LoRA adapters,
         strength: strength for image to image generation. 1.0f means initial image is fully noised,
         max_sequence_length: int - length of t5_encoder_model input
@@ -836,6 +839,7 @@ class ImageGenerationConfig:
     num_inference_steps: int
     prompt_2: str | None
     prompt_3: str | None
+    rng_seed: int
     strength: float
     width: int
     def __init__(self) -> None:
@@ -903,7 +907,8 @@ class InpaintingPipeline:
         height: int - height of resulting images,
         width: int - width of resulting images,
         num_inference_steps: int - number of inference steps,
-        generator: openvino_genai.CppStdGenerator or class inherited from openvino_genai.Generator - random generator,
+        rng_seed: int - a seed for the random number generator,
+        generator: openvino_genai.TorchGenerator, openvino_genai.CppStdGenerator or class inherited from openvino_genai.Generator - random generator,
         adapters: LoRA adapters,
         strength: strength for image to image generation. 1.0f means initial image is fully noised,
         max_sequence_length: int - length of t5_encoder_model input
@@ -1576,7 +1581,8 @@ class Text2ImagePipeline:
         height: int - height of resulting images,
         width: int - width of resulting images,
         num_inference_steps: int - number of inference steps,
-        generator: openvino_genai.CppStdGenerator or class inherited from openvino_genai.Generator - random generator,
+        rng_seed: int - a seed for the random number generator,
+        generator: openvino_genai.TorchGenerator, openvino_genai.CppStdGenerator or class inherited from openvino_genai.Generator - random generator,
         adapters: LoRA adapters,
         strength: strength for image to image generation. 1.0f means initial image is fully noised,
         max_sequence_length: int - length of t5_encoder_model input
@@ -1649,6 +1655,18 @@ class Tokenizer:
         """
         Override a chat_template read from tokenizer_config.json.
         """
+class TorchGenerator(CppStdGenerator):
+    """
+    This class provides an OpenVINO GenAI Generator wrapper for torch.Generator
+    """
+    def __init__(self, seed: int) -> None:
+        ...
+    def next(self) -> float:
+        ...
+    def randn_tensor(self, shape: openvino._pyopenvino.Shape) -> openvino._pyopenvino.Tensor:
+        ...
+    def seed(self, new_seed: int) -> None:
+        ...
 class UNet2DConditionModel:
     """
     UNet2DConditionModel class.
diff --git a/src/python/py_image_generation_pipelines.cpp b/src/python/py_image_generation_pipelines.cpp
index 55be1708c1..da6ce6d21b 100644
--- a/src/python/py_image_generation_pipelines.cpp
+++ b/src/python/py_image_generation_pipelines.cpp
@@ -8,6 +8,7 @@
 #include 
 #include 
 #include 
+#include 
 
 #include "openvino/genai/image_generation/text2image_pipeline.hpp"
 #include "openvino/genai/image_generation/image2image_pipeline.hpp"
@@ -19,23 +20,7 @@
 namespace py = pybind11;
 namespace pyutils = ov::genai::pybind::utils;
 
-namespace ov {
-namespace genai {
-
-/// Trampoline class to support inheritance from Generator in Python
-class PyGenerator : public ov::genai::Generator {
-public:
-    float next() override {
-        PYBIND11_OVERRIDE_PURE(float, Generator, next);
-    }
-
-    ov::Tensor randn_tensor(const ov::Shape& shape) override {
-        PYBIND11_OVERRIDE(ov::Tensor, Generator, randn_tensor, shape);
-    }
-};
-
-} // namespace genai
-} // namespace ov
+using namespace pybind11::literals;
 
 namespace {
 
@@ -59,7 +44,8 @@ auto text2image_generate_docstring = R"(
         height: int - height of resulting images,
         width: int - width of resulting images,
         num_inference_steps: int - number of inference steps,
-        generator: openvino_genai.CppStdGenerator or class inherited from openvino_genai.Generator - random generator,
+        rng_seed: int - a seed for the random number generator,
+        generator: openvino_genai.TorchGenerator, openvino_genai.CppStdGenerator or class inherited from openvino_genai.Generator - random generator,
         adapters: LoRA adapters,
         strength: strength for image to image generation. 1.0f means initial image is fully noised,
         max_sequence_length: int - length of t5_encoder_model input
@@ -68,7 +54,102 @@
     :rtype: ov.Tensor
 )";
 
+// Trampoline class to support inheritance from Generator in Python
+class PyGenerator : public ov::genai::Generator {
+public:
+    float next() override {
+        PYBIND11_OVERRIDE_PURE(float, Generator, next);
+    }
+
+    ov::Tensor randn_tensor(const ov::Shape& shape) override {
+        PYBIND11_OVERRIDE(ov::Tensor, Generator, randn_tensor, shape);
+    }
+
+    void seed(size_t new_seed) override {
+        PYBIND11_OVERRIDE_PURE(void, Generator, seed, new_seed);
+    }
+};
+
+py::list to_py_list(const ov::Shape shape) {
+    py::list py_shape;
+    for (auto d : shape)
+        py_shape.append(d);
+
+    return py_shape;
+}
+
+class TorchGenerator : public ov::genai::CppStdGenerator {
+    py::module_ m_torch;
+    py::object m_torch_generator, m_float32;
+
+    void create_torch_generator(size_t seed) {
+        m_torch_generator = m_torch.attr("Generator")("device"_a="cpu").attr("manual_seed")(seed);
+    }
+public:
+    explicit TorchGenerator(uint32_t seed) : CppStdGenerator(seed) {
+        try {
+            m_torch = py::module_::import("torch");
+        } catch (const py::error_already_set& e) {
+            if (e.matches(PyExc_ModuleNotFoundError)) {
+                throw std::runtime_error("The 'torch' package is not installed. Please run 'pip install torch' or use the 'rng_seed' parameter.");
+            } else {
+                // Re-throw other exceptions
+                throw;
+            }
+        }
+
+        m_float32 = m_torch.attr("float32");
+        create_torch_generator(seed);
+    }
+
+    float next() override {
+        return m_torch.attr("randn")(1, "generator"_a=m_torch_generator, "dtype"_a=m_float32).attr("item")().cast<float>();
+    }
+
+    ov::Tensor randn_tensor(const ov::Shape& shape) override {
+        py::object torch_tensor = m_torch.attr("randn")(to_py_list(shape), "generator"_a=m_torch_generator, "dtype"_a=m_float32);
+        py::object numpy_tensor = torch_tensor.attr("numpy")();
+        py::array numpy_array = py::cast<py::array>(numpy_tensor);
+
+        if (!numpy_array.dtype().is(py::dtype::of<float>())) {
+            throw std::runtime_error("Expected a NumPy array with dtype float32");
+        }
+
+        class TorchTensorAllocator {
+            size_t m_total_size;
+            void * m_mutable_data;
+            py::object m_torch_tensor; // we need to hold torch.Tensor to avoid memory destruction
+
+        public:
+            TorchTensorAllocator(size_t total_size, void * mutable_data, py::object torch_tensor) :
+                m_total_size(total_size), m_mutable_data(mutable_data), m_torch_tensor(torch_tensor) { }
+
+            void* allocate(size_t bytes, size_t) const {
+                if (m_total_size == bytes) {
+                    return m_mutable_data;
+                }
+                throw std::runtime_error{"Unexpected number of bytes was requested to allocate."};
+            }
+
+            void deallocate(void*, size_t bytes, size_t) {
+                if (m_total_size != bytes) {
+                    throw std::runtime_error{"Unexpected number of bytes was requested to deallocate."};
+                }
+            }
+
+            bool is_equal(const TorchTensorAllocator& other) const noexcept {
+                return this == &other;
+            }
+        };
+
+        return ov::Tensor(ov::element::f32, shape,
+            TorchTensorAllocator(ov::shape_size(shape) * ov::element::f32.size(), numpy_array.mutable_data(), torch_tensor));
+    }
+
+    void seed(size_t new_seed) override {
+        create_torch_generator(new_seed);
+    }
+};
+
 } // namespace
 
@@ -81,16 +162,24 @@
 void init_flux_transformer_2d_model(py::module_& m);
 void init_autoencoder_kl(py::module_& m);
 
 void init_image_generation_pipelines(py::module_& m) {
-    py::class_<ov::genai::Generator, ov::genai::PyGenerator, std::shared_ptr<ov::genai::Generator>>(m, "Generator", "This class is used for storing pseudo-random generator.")
+    py::class_<ov::genai::Generator, ::PyGenerator, std::shared_ptr<ov::genai::Generator>>(m, "Generator", "This class is used for storing pseudo-random generator.")
         .def(py::init<>());
 
     py::class_<ov::genai::CppStdGenerator, ov::genai::Generator, std::shared_ptr<ov::genai::CppStdGenerator>>(m, "CppStdGenerator", "This class wraps std::mt19937 pseudo-random generator.")
         .def(py::init([](uint32_t seed) {
             return std::make_unique<ov::genai::CppStdGenerator>(seed);
-        }),
-        py::arg("seed"))
+        }), py::arg("seed"))
         .def("next", &ov::genai::CppStdGenerator::next)
-        .def("randn_tensor", &ov::genai::CppStdGenerator::randn_tensor, py::arg("shape"));
+        .def("randn_tensor", &ov::genai::CppStdGenerator::randn_tensor, py::arg("shape"))
+        .def("seed", &ov::genai::CppStdGenerator::seed, py::arg("new_seed"));
+
+    py::class_<::TorchGenerator, ov::genai::CppStdGenerator, std::shared_ptr<::TorchGenerator>>(m, "TorchGenerator", "This class provides an OpenVINO GenAI Generator wrapper for torch.Generator")
+        .def(py::init([](uint32_t seed) {
+            return std::make_unique<::TorchGenerator>(seed);
+        }), py::arg("seed"))
+        .def("next", &::TorchGenerator::next)
+        .def("randn_tensor", &::TorchGenerator::randn_tensor, py::arg("shape"))
+        .def("seed", &::TorchGenerator::seed, py::arg("new_seed"));
 
     // init image generation models
     init_clip_text_model(m);
@@ -122,6 +211,7 @@
         .def_readwrite("negative_prompt_2", &ov::genai::ImageGenerationConfig::negative_prompt_2)
         .def_readwrite("negative_prompt_3", &ov::genai::ImageGenerationConfig::negative_prompt_3)
         .def_readwrite("generator", &ov::genai::ImageGenerationConfig::generator)
+        .def_readwrite("rng_seed", &ov::genai::ImageGenerationConfig::rng_seed)
         .def_readwrite("guidance_scale", &ov::genai::ImageGenerationConfig::guidance_scale)
         .def_readwrite("height", &ov::genai::ImageGenerationConfig::height)
         .def_readwrite("width", &ov::genai::ImageGenerationConfig::width)
diff --git a/tools/who_what_benchmark/whowhatbench/text2image_evaluator.py b/tools/who_what_benchmark/whowhatbench/text2image_evaluator.py
index 1ff7ff5e21..0cced117e4 100644
--- a/tools/who_what_benchmark/whowhatbench/text2image_evaluator.py
+++ b/tools/who_what_benchmark/whowhatbench/text2image_evaluator.py
@@ -27,17 +27,6 @@
 }
 
 
-class Generator(openvino_genai.Generator):
-    def __init__(self, seed, rng, mu=0.0, sigma=1.0):
-        openvino_genai.Generator.__init__(self)
-        self.mu = mu
-        self.sigma = sigma
-        self.rng = rng
-
-    def next(self):
-        return torch.randn(1, generator=self.rng, dtype=torch.float32).item()
-
-
 @register_evaluator("text-to-image")
 class Text2ImageEvaluator(BaseEvaluator):
     def __init__(
@@ -171,7 +160,7 @@ def default_gen_image_fn(model, prompt, num_inference_steps, generator=None):
                     model,
                     prompt,
                     self.num_inference_steps,
-                    generator=Generator(self.seed, rng) if self.is_genai else rng
+                    generator=openvino_genai.TorchGenerator(self.seed) if self.is_genai else rng
                 )
                 image_path = os.path.join(image_dir, f"{i}.png")
                 image.save(image_path)