closes #118; fix segfault in python operator
chengzeyi committed Feb 9, 2024
1 parent f6d3158 commit ce7de20
Showing 2 changed files with 47 additions and 8 deletions.
37 changes: 37 additions & 0 deletions community/reproduce_vae_segfault.py
@@ -0,0 +1,37 @@
import torch
import torch.nn.functional as F

from diffusers import AutoencoderKL

from sfast.compilers.stable_diffusion_pipeline_compiler import (
    compile_vae,
    CompilationConfig,
)

device = torch.device("cuda:0")

SD_2_1_DIFFUSERS_MODEL = "stabilityai/stable-diffusion-2-1"
variant = {"variant": "fp16"}
vae_orig = AutoencoderKL.from_pretrained(
    SD_2_1_DIFFUSERS_MODEL,
    subfolder="vae",
    torch_dtype=torch.float16,
    **variant,
)

vae_orig.to(device)

sfast_config = CompilationConfig.Default()
sfast_config.enable_xformers = False
sfast_config.enable_triton = True
sfast_config.enable_cuda_graph = False
vae = compile_vae(vae_orig, sfast_config)

sample_imgs = torch.randn(4, 3, 128, 128, dtype=vae.dtype, device=device)
latents1 = torch.randn(4, 4, 16, 16, dtype=vae.dtype, device=device)

latents = vae.encode(sample_imgs).latent_dist.sample()

sample_imgs_dup = sample_imgs.clone().detach().requires_grad_(True)
latents2 = vae_orig.encode(sample_imgs_dup).latent_dist.sample()
print("Test done")
18 changes: 10 additions & 8 deletions src/sfast/csrc/jit/python_operator.cpp
@@ -28,16 +28,18 @@ void RegisterCustomPythonOperator(const std::string &schema,
   auto arguments = parsed_schema.arguments();
   auto returns = parsed_schema.returns();
 
-  std::shared_ptr<py::function> func_ptr(
+  std::shared_ptr<const py::function> func_ptr(
       new py::function(py::reinterpret_borrow<const py::function>(
-          py::handle(const_cast<PyObject *>(py_callable.get())))),
+          py::handle(py_callable.get()))),
       [](py::function *ptr) {
-        // Check if the current thread is holding the GIL
-        if (PyGILState_Check()) {
-          delete ptr;
-        } else {
-          py::gil_scoped_acquire gil;
-          delete ptr;
+        if (Py_IsInitialized()) {
+          // Check if the current thread is holding the GIL
+          if (PyGILState_Check()) {
+            delete ptr;
+          } else {
+            py::gil_scoped_acquire gil;
+            delete ptr;
+          }
         }
       });

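The fix wraps the deleter body in a Py_IsInitialized() guard: once the interpreter has been finalized, neither acquiring the GIL nor decref'ing the function is safe, so the object is deliberately leaked instead of crashing at shutdown. A minimal standalone sketch of the resulting pattern (hypothetical guard_callable helper, embedded interpreter assumed; not the sfast build):

#include <pybind11/embed.h>

#include <memory>

namespace py = pybind11;

// Shutdown-safe shared ownership of a Python callable held from C++.
std::shared_ptr<const py::function> guard_callable(py::function f) {
  return std::shared_ptr<const py::function>(
      new py::function(std::move(f)), [](py::function *ptr) {
        if (Py_IsInitialized()) {
          if (PyGILState_Check()) {
            delete ptr; // Already holding the GIL; the decref is safe.
          } else {
            py::gil_scoped_acquire gil; // Take the GIL before the decref.
            delete ptr;
          }
        }
        // Interpreter already finalized: intentionally leak, never crash.
      });
}

int main() {
  py::scoped_interpreter interp;
  auto fn = guard_callable(py::eval("lambda x: x + 1").cast<py::function>());
  py::print((*fn)(41)); // prints 42
}

The const pointee mirrors the commit's type change and presumably documents that the stored callable is only ever invoked, never rebound.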
