Skip to content

Commit

Permalink
Image generation: added TorchGenerator
Browse files Browse the repository at this point in the history
  • Loading branch information
ilya-lavrenov committed Dec 13, 2024
1 parent d17f716 commit 2947845
Show file tree
Hide file tree
Showing 11 changed files with 109 additions and 75 deletions.
4 changes: 4 additions & 0 deletions samples/python/image_generation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ Users can change the sample code and play with the following generation paramete
- Apply multiple different LoRA adapters and mix them with different blending coefficients
- (Image to image and inpainting) Play with `strength` parameter to control how initial image is noised and reduce number of inference steps

> [!NOTE]
> All samples pass `openvino_genai.TorchGenerator` as an argument to the `generate` call to align random number generation with the Diffusers library, where `torch.Generator` is used under the hood.
> To get the same results with Hugging Face Diffusers, pass a manually created `torch.Generator(device='cpu').manual_seed(42)` to the Diffusers generation pipelines to fix the `seed` parameter.
## Download and convert the models and tokenizers

The `--upgrade-strategy eager` option is needed to ensure `optimum-intel` is upgraded to the latest version.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def main():
guidance_scale=guidance_scale,
num_inference_steps=number_of_inference_steps_per_image,
num_images_per_prompt=1,
generator=openvino_genai.CppStdGenerator(42)
generator=openvino_genai.TorchGenerator(42)
)

image = Image.fromarray(image_tensor.data[0])
Expand Down
3 changes: 2 additions & 1 deletion samples/python/image_generation/image2image.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def main():
image = read_image(args.image)

image_tensor = pipe.generate(args.prompt, image,
strength=0.8 # controls how initial image is noised after being converted to latent space. `1` means initial image is fully noised
strength=0.8, # controls how initial image is noised after being converted to latent space. `1` means initial image is fully noised
generator=openvino_genai.TorchGenerator(42)
)

image = Image.fromarray(image_tensor.data[0])
Expand Down
3 changes: 2 additions & 1 deletion samples/python/image_generation/inpainting.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def main():
image = read_image(args.image)
mask_image = read_image(args.mask)

image_tensor = pipe.generate(args.prompt, image, mask_image)
image_tensor = pipe.generate(args.prompt, image, mask_image,
generator=openvino_genai.TorchGenerator(42))

image = Image.fromarray(image_tensor.data[0])
image.save("image.bmp")
Expand Down
23 changes: 4 additions & 19 deletions samples/python/image_generation/lora_text2image.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,6 @@

import openvino as ov
import openvino_genai
import numpy as np
import sys


class Generator(openvino_genai.Generator):
    """Gaussian random generator backed by a private NumPy ``RandomState``.

    Args:
        seed: Seed for the underlying Mersenne Twister state.
        mu: Mean of the normal distribution sampled by ``next``.
        sigma: Standard deviation of the normal distribution sampled by ``next``.
    """

    def __init__(self, seed, mu=0.0, sigma=1.0):
        # pybind11 base classes must be initialised through an explicit call.
        openvino_genai.Generator.__init__(self)
        # A dedicated RandomState produces the same sequence as
        # ``np.random.seed(seed)`` followed by ``np.random.normal`` but does not
        # reseed NumPy's process-wide state, so several generators (or unrelated
        # NumPy code) no longer disturb each other.
        self._rng = np.random.RandomState(seed)
        self.mu = mu
        self.sigma = sigma

    def next(self):
        """Return one sample drawn from N(mu, sigma)."""
        return self._rng.normal(self.mu, self.sigma)


def image_write(path: str, image_tensor: ov.Tensor):
from PIL import Image
Expand Down Expand Up @@ -48,21 +34,20 @@ def main():
pipe = openvino_genai.Text2ImagePipeline(args.models_path, device, adapters=adapter_config)
print("Generating image with LoRA adapters applied, resulting image will be in lora.bmp")
image = pipe.generate(prompt,
generator=Generator(42),
width=512,
height=896,
num_inference_steps=20)
num_inference_steps=20,
generator=openvino_genai.TorchGenerator(42))

image_write("lora.bmp", image)
print("Generating image without LoRA adapters applied, resulting image will be in baseline.bmp")
image = pipe.generate(prompt,
# passing adapters in generate overrides adapters set in the constructor; openvino_genai.AdapterConfig() means no adapters
adapters=openvino_genai.AdapterConfig(),
generator=Generator(42),
width=512,
height=896,
num_inference_steps=20
)
num_inference_steps=20,
generator=openvino_genai.TorchGenerator(42))
image_write("baseline.bmp", image)


Expand Down
15 changes: 2 additions & 13 deletions samples/python/image_generation/text2image.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,6 @@

import openvino_genai
from PIL import Image
import numpy as np

class Generator(openvino_genai.Generator):
    """Gaussian random generator backed by a private NumPy ``RandomState``.

    Args:
        seed: Seed for the underlying Mersenne Twister state.
        mu: Mean of the normal distribution sampled by ``next``.
        sigma: Standard deviation of the normal distribution sampled by ``next``.
    """

    def __init__(self, seed, mu=0.0, sigma=1.0):
        # pybind11 base classes must be initialised through an explicit call.
        openvino_genai.Generator.__init__(self)
        # Use an instance-local RandomState instead of ``np.random.seed``: the
        # sample sequence for a given seed is the same, but the global NumPy
        # random state is left untouched and instances stay independent.
        self._rng = np.random.RandomState(seed)
        self.mu = mu
        self.sigma = sigma

    def next(self):
        """Return one sample drawn from N(mu, sigma)."""
        return self._rng.normal(self.mu, self.sigma)


def main():
Expand All @@ -32,9 +21,9 @@ def main():
args.prompt,
width=512,
height=512,
num_inference_steps=20,
num_inference_steps=5,
num_images_per_prompt=1,
generator=Generator(42) # openvino_genai.CppStdGenerator can be used to have same images as C++ sample
generator=openvino_genai.TorchGenerator(42) # openvino_genai.CppStdGenerator can be used to have same images as C++ sample
)

image = Image.fromarray(image_tensor.data[0])
Expand Down
2 changes: 1 addition & 1 deletion src/python/openvino_genai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
if hasattr(os, "add_dll_directory"):
os.add_dll_directory(os.path.dirname(__file__))


from .py_openvino_genai import (
DecodedResults,
EncodedResults,
Expand Down Expand Up @@ -75,6 +74,7 @@
ImageGenerationConfig,
Generator,
CppStdGenerator,
TorchGenerator,
)

# Continuous batching
Expand Down
3 changes: 2 additions & 1 deletion src/python/openvino_genai/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ from openvino_genai.py_openvino_genai import T5EncoderModel
from openvino_genai.py_openvino_genai import Text2ImagePipeline
from openvino_genai.py_openvino_genai import TokenizedInputs
from openvino_genai.py_openvino_genai import Tokenizer
from openvino_genai.py_openvino_genai import TorchGenerator
from openvino_genai.py_openvino_genai import UNet2DConditionModel
from openvino_genai.py_openvino_genai import VLMPipeline
from openvino_genai.py_openvino_genai import WhisperGenerationConfig
Expand All @@ -43,5 +44,5 @@ from openvino_genai.py_openvino_genai import WhisperRawPerfMetrics
from openvino_genai.py_openvino_genai import draft_model
import os as os
from . import py_openvino_genai
__all__ = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedResults', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationResult', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'InpaintingPipeline', 'LLMPipeline', 'PerfMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'Scheduler', 'SchedulerConfig', 'StopCriteria', 'StreamerBase', 'T5EncoderModel', 'Text2ImagePipeline', 'TokenizedInputs', 'Tokenizer', 'UNet2DConditionModel', 'VLMPipeline', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model', 'openvino', 'os', 'py_openvino_genai']
__all__ = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedResults', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationResult', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'InpaintingPipeline', 'LLMPipeline', 'PerfMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'Scheduler', 'SchedulerConfig', 'StopCriteria', 'StreamerBase', 'T5EncoderModel', 'Text2ImagePipeline', 'TokenizedInputs', 'Tokenizer', 'TorchGenerator', 'UNet2DConditionModel', 'VLMPipeline', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model', 'openvino', 'os', 'py_openvino_genai']
__version__: str = '2025.0.0.0'
18 changes: 14 additions & 4 deletions src/python/openvino_genai/py_openvino_genai.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ from __future__ import annotations
import openvino._pyopenvino
import os
import typing
__all__ = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedGenerationResult', 'EncodedResults', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationFinishReason', 'GenerationHandle', 'GenerationOutput', 'GenerationResult', 'GenerationStatus', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'InpaintingPipeline', 'LLMPipeline', 'MeanStdPair', 'PerfMetrics', 'PipelineMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'Scheduler', 'SchedulerConfig', 'StopCriteria', 'StreamerBase', 'T5EncoderModel', 'Text2ImagePipeline', 'TokenizedInputs', 'Tokenizer', 'UNet2DConditionModel', 'VLMDecodedResults', 'VLMPerfMetrics', 'VLMPipeline', 'VLMRawPerfMetrics', 'WhisperDecodedResultChunk', 'WhisperDecodedResults', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model']
__all__ = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedGenerationResult', 'EncodedResults', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationFinishReason', 'GenerationHandle', 'GenerationOutput', 'GenerationResult', 'GenerationStatus', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'InpaintingPipeline', 'LLMPipeline', 'MeanStdPair', 'PerfMetrics', 'PipelineMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'Scheduler', 'SchedulerConfig', 'StopCriteria', 'StreamerBase', 'T5EncoderModel', 'Text2ImagePipeline', 'TokenizedInputs', 'Tokenizer', 'TorchGenerator', 'UNet2DConditionModel', 'VLMDecodedResults', 'VLMPerfMetrics', 'VLMPipeline', 'VLMRawPerfMetrics', 'WhisperDecodedResultChunk', 'WhisperDecodedResults', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model']
class Adapter:
"""
Immutable LoRA Adapter that carries the adaptation matrices and serves as unique adapter identifier.
Expand Down Expand Up @@ -801,7 +801,7 @@ class Image2ImagePipeline:
height: int - height of resulting images,
width: int - width of resulting images,
num_inference_steps: int - number of inference steps,
generator: openvino_genai.CppStdGenerator or class inherited from openvino_genai.Generator - random generator,
generator: openvino_genai.TorchGenerator, openvino_genai.CppStdGenerator or class inherited from openvino_genai.Generator - random generator,
adapters: LoRA adapters,
strength: strength for image to image generation. 1.0f means initial image is fully noised,
max_sequence_length: int - length of t5_encoder_model input
Expand Down Expand Up @@ -897,7 +897,7 @@ class InpaintingPipeline:
height: int - height of resulting images,
width: int - width of resulting images,
num_inference_steps: int - number of inference steps,
generator: openvino_genai.CppStdGenerator or class inherited from openvino_genai.Generator - random generator,
generator: openvino_genai.TorchGenerator, openvino_genai.CppStdGenerator or class inherited from openvino_genai.Generator - random generator,
adapters: LoRA adapters,
strength: strength for image to image generation. 1.0f means initial image is fully noised,
max_sequence_length: int - length of t5_encoder_model input
Expand Down Expand Up @@ -1564,7 +1564,7 @@ class Text2ImagePipeline:
height: int - height of resulting images,
width: int - width of resulting images,
num_inference_steps: int - number of inference steps,
generator: openvino_genai.CppStdGenerator or class inherited from openvino_genai.Generator - random generator,
generator: openvino_genai.TorchGenerator, openvino_genai.CppStdGenerator or class inherited from openvino_genai.Generator - random generator,
adapters: LoRA adapters,
strength: strength for image to image generation. 1.0f means initial image is fully noised,
max_sequence_length: int - length of t5_encoder_model input
Expand Down Expand Up @@ -1637,6 +1637,16 @@ class Tokenizer:
"""
Override a chat_template read from tokenizer_config.json.
"""
class TorchGenerator(CppStdGenerator):
    """
    This class provides OpenVINO GenAI Generator wrapper for torch.Generator
    """
    def __init__(self, seed: int) -> None:
        """
        Create a generator seeded with `seed`. If the `torch` module is not installed,
        an ImportWarning is issued and random generation falls back to CppStdGenerator.
        """
        ...
    def next(self) -> float:
        """
        Return a single float32 random value drawn via `torch.randn`
        (or from CppStdGenerator when `torch` is unavailable).
        """
        ...
    def randn_tensor(self, shape: openvino._pyopenvino.Shape) -> openvino._pyopenvino.Tensor:
        """
        Return a float32 tensor of the given shape filled via `torch.randn`
        (or by CppStdGenerator when `torch` is unavailable).
        """
        ...
class UNet2DConditionModel:
"""
UNet2DConditionModel class.
Expand Down
98 changes: 76 additions & 22 deletions src/python/py_image_generation_pipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <pybind11/stl_bind.h>
#include <pybind11/stl/filesystem.h>
#include <pybind11/functional.h>
#include <pybind11/numpy.h>

#include "openvino/genai/image_generation/text2image_pipeline.hpp"
#include "openvino/genai/image_generation/image2image_pipeline.hpp"
Expand All @@ -19,23 +20,7 @@
namespace py = pybind11;
namespace pyutils = ov::genai::pybind::utils;

namespace ov {
namespace genai {

/// Trampoline class to support inheritance from Generator in Python.
/// Each virtual method forwards to the Python-side override through the pybind11 override macros.
class PyGenerator : public ov::genai::Generator {
public:
    // PYBIND11_OVERRIDE_PURE: there is no C++ fallback, so a Python subclass has to implement `next`.
    float next() override {
        PYBIND11_OVERRIDE_PURE(float, Generator, next);
    }

    // PYBIND11_OVERRIDE: uses the Python `randn_tensor` override when present,
    // otherwise calls Generator::randn_tensor.
    ov::Tensor randn_tensor(const ov::Shape& shape) override {
        PYBIND11_OVERRIDE(ov::Tensor, Generator, randn_tensor, shape);
    }
};

} // namespace genai
} // namespace ov
using namespace pybind11::literals;

namespace {

Expand All @@ -59,7 +44,7 @@ auto text2image_generate_docstring = R"(
height: int - height of resulting images,
width: int - width of resulting images,
num_inference_steps: int - number of inference steps,
generator: openvino_genai.CppStdGenerator or class inherited from openvino_genai.Generator - random generator,
generator: openvino_genai.TorchGenerator, openvino_genai.CppStdGenerator or class inherited from openvino_genai.Generator - random generator,
adapters: LoRA adapters,
strength: strength for image to image generation. 1.0f means initial image is fully noised,
max_sequence_length: int - length of t5_encoder_model input
Expand All @@ -68,7 +53,70 @@ auto text2image_generate_docstring = R"(
:rtype: ov.Tensor
)";

// Trampoline class to support inheritance from Generator in Python.
// Each virtual method forwards to the Python-side override through the pybind11 override macros.
class PyGenerator : public ov::genai::Generator {
public:
    // PYBIND11_OVERRIDE_PURE: there is no C++ fallback, so a Python subclass has to implement `next`.
    float next() override {
        PYBIND11_OVERRIDE_PURE(float, Generator, next);
    }

    // PYBIND11_OVERRIDE: uses the Python `randn_tensor` override when present,
    // otherwise calls Generator::randn_tensor.
    ov::Tensor randn_tensor(const ov::Shape& shape) override {
        PYBIND11_OVERRIDE(ov::Tensor, Generator, randn_tensor, shape);
    }
};

// Converts an ov::Shape into a Python list holding one integer per dimension,
// in the same order. The shape is taken by const reference to avoid copying
// the underlying vector on every call.
py::list to_py_list(const ov::Shape& shape) {
    py::list py_shape;
    for (const auto dim : shape)
        py_shape.append(dim);

    return py_shape;
}

// Random generator that draws values from a CPU `torch.Generator`, so that generated
// noise matches what the Diffusers library produces for the same seed.
//
// When the `torch` module is not installed, an ImportWarning is issued and every call is
// delegated to the CppStdGenerator base class (seeded with the same value).
class TorchGenerator : public ov::genai::CppStdGenerator {
    py::module_ m_torch;  // stays a null handle when torch is unavailable; used as the "torch enabled" flag
    py::object m_torch_generator, m_float32;
public:
    // Seeds both the CppStdGenerator fallback and (if available) the torch.Generator.
    // Throws py::error_already_set for any Python error other than ModuleNotFoundError.
    explicit TorchGenerator(uint32_t seed) : CppStdGenerator(seed) {
        try {
            // Build everything in locals first: members are committed only when the whole
            // setup succeeded, so `m_torch` can never be set without a valid generator.
            py::module_ torch = py::module_::import("torch");
            py::object torch_generator = torch.attr("Generator")("device"_a="cpu").attr("manual_seed")(seed);
            py::object float32 = torch.attr("float32");

            m_torch = std::move(torch);
            m_torch_generator = std::move(torch_generator);
            m_float32 = std::move(float32);
        } catch (const py::error_already_set& e) {
            if (e.matches(PyExc_ModuleNotFoundError)) {
                // PyErr_WarnEx returns -1 when the warning is turned into an exception
                // (e.g. `-W error`); propagate it instead of leaving a pending Python error.
                if (PyErr_WarnEx(PyExc_ImportWarning, "'torch' module is not installed. Random generation will fall back to 'CppStdGenerator'", 1) < 0) {
                    throw py::error_already_set();
                }
            } else {
                // Re-throw other exceptions
                throw;
            }
        }
    }

    // Returns a single float32 sample from torch.randn, or from CppStdGenerator without torch.
    float next() override {
        if (!m_torch) {
            return CppStdGenerator::next();
        }

        return m_torch.attr("randn")(1, "generator"_a=m_torch_generator, "dtype"_a=m_float32).attr("item")().cast<float>();
    }

    // Returns an f32 tensor of the requested shape filled by torch.randn, or by
    // CppStdGenerator without torch. Throws std::runtime_error if torch hands back
    // data that is not float32.
    ov::Tensor randn_tensor(const ov::Shape& shape) override {
        if (!m_torch) {
            return CppStdGenerator::randn_tensor(shape);
        }

        py::object torch_tensor = m_torch.attr("randn")(to_py_list(shape), "generator"_a=m_torch_generator, "dtype"_a=m_float32);
        py::object numpy_tensor = torch_tensor.attr("numpy")();
        py::array numpy_array = py::cast<py::array>(numpy_tensor);

        if (!numpy_array.dtype().is(py::dtype::of<float>())) {
            throw std::runtime_error("Expected a NumPy array with dtype float32");
        }

        // Copy the samples into a tensor that owns its memory. Returning a view over the
        // torch buffer would dangle as soon as the torch tensor is released, i.e. on the
        // next randn_tensor() call or when this generator is destroyed.
        ov::Tensor result(ov::element::f32, shape);
        ov::Tensor(ov::element::f32, shape, numpy_array.mutable_data()).copy_to(result);
        return result;
    }
};

} // namespace

Expand All @@ -81,19 +129,25 @@ void init_flux_transformer_2d_model(py::module_& m);
void init_autoencoder_kl(py::module_& m);

void init_image_generation_pipelines(py::module_& m) {
py::class_<ov::genai::Generator, ov::genai::PyGenerator, std::shared_ptr<ov::genai::Generator>>(m, "Generator", "This class is used for storing pseudo-random generator.")
py::class_<ov::genai::Generator, ::PyGenerator, std::shared_ptr<ov::genai::Generator>>(m, "Generator", "This class is used for storing pseudo-random generator.")
.def(py::init<>());

py::class_<ov::genai::CppStdGenerator, ov::genai::Generator, std::shared_ptr<ov::genai::CppStdGenerator>>(m, "CppStdGenerator", "This class wraps std::mt19937 pseudo-random generator.")
.def(py::init([](
uint32_t seed
) {
.def(py::init([](uint32_t seed) {
return std::make_unique<ov::genai::CppStdGenerator>(seed);
}),
py::arg("seed"))
.def("next", &ov::genai::CppStdGenerator::next)
.def("randn_tensor", &ov::genai::CppStdGenerator::randn_tensor, py::arg("shape"));

py::class_<::TorchGenerator, ov::genai::CppStdGenerator, std::shared_ptr<::TorchGenerator>>(m, "TorchGenerator", "This class provides OpenVINO GenAI Generator wrapper for torch.Generator")
.def(py::init([](uint32_t seed) {
return std::make_unique<::TorchGenerator>(seed);
}),
py::arg("seed"))
.def("next", &::TorchGenerator::next)
.def("randn_tensor", &::TorchGenerator::randn_tensor, py::arg("shape"));

// init image generation models
init_clip_text_model(m);
init_clip_text_model_with_projection(m);
Expand Down
13 changes: 1 addition & 12 deletions tools/who_what_benchmark/whowhatbench/text2image_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,6 @@
}


class Generator(openvino_genai.Generator):
    """Adapter exposing a ``torch.Generator`` through the GenAI ``Generator`` interface.

    Args:
        seed: Accepted for signature compatibility; not used by this class
            (the supplied ``rng`` is expected to be seeded already).
        rng: Generator object forwarded to ``torch.randn`` on every ``next`` call.
        mu: Stored on the instance; not used by ``next``.
        sigma: Stored on the instance; not used by ``next``.
    """

    def __init__(self, seed, rng, mu=0.0, sigma=1.0):
        # pybind11 base classes are initialised via an explicit __init__ call.
        openvino_genai.Generator.__init__(self)
        self.rng = rng
        self.mu, self.sigma = mu, sigma

    def next(self):
        """Draw one float32 sample from ``torch.randn`` and return it as a Python float."""
        sample = torch.randn(1, generator=self.rng, dtype=torch.float32)
        return sample.item()


@register_evaluator("text-to-image")
class Text2ImageEvaluator(BaseEvaluator):
def __init__(
Expand Down Expand Up @@ -171,7 +160,7 @@ def default_gen_image_fn(model, prompt, num_inference_steps, generator=None):
model,
prompt,
self.num_inference_steps,
generator=Generator(self.seed, rng) if self.is_genai else rng
generator=openvino_genai.TorchGenerator(self.seed) if self.is_genai else rng
)
image_path = os.path.join(image_dir, f"{i}.png")
image.save(image_path)
Expand Down

0 comments on commit 2947845

Please sign in to comment.