Skip to content

Commit

Permalink
Automatically detect diffusers params files (#301)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko authored Dec 12, 2023
1 parent eca3735 commit d1783dc
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 72 deletions.
47 changes: 19 additions & 28 deletions lib/bumblebee.ex
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,13 @@ defmodule Bumblebee do
@tokenizer_special_tokens_filename "special_tokens_map.json"
@generation_filename "generation_config.json"
@scheduler_filename "scheduler_config.json"
@pytorch_params_filename "pytorch_model.bin"
@safetensors_params_filename "model.safetensors"

@params_filenames [
"pytorch_model.bin",
"diffusion_pytorch_model.bin",
"model.safetensors",
"diffusion_pytorch_model.safetensors"
]

@transformers_class_to_model %{
"AlbertForMaskedLM" => {Bumblebee.Text.Albert, :for_masked_language_modeling},
Expand Down Expand Up @@ -534,36 +539,22 @@ defmodule Bumblebee do
end

defp infer_params_filename(repo_files, nil = _filename) do
cond do
Map.has_key?(repo_files, @pytorch_params_filename) ->
{@pytorch_params_filename, false}

Map.has_key?(repo_files, @pytorch_params_filename <> ".index.json") ->
{@pytorch_params_filename, true}

Map.has_key?(repo_files, @safetensors_params_filename) ->
{@safetensors_params_filename, false}

Map.has_key?(repo_files, @safetensors_params_filename <> ".index.json") ->
{@safetensors_params_filename, true}

true ->
raise ArgumentError,
"none of the expected parameters files found in the repository." <>
" If the file exists under an unusual name, try specifying :params_filename"
end
Enum.find_value(@params_filenames, &lookup_params_filename(repo_files, &1)) ||
raise ArgumentError,
"none of the expected parameters files found in the repository." <>
" If the file exists under an unusual name, try specifying :params_filename"
end

defp infer_params_filename(repo_files, filename) do
cond do
Map.has_key?(repo_files, filename) ->
{filename, false}

Map.has_key?(repo_files, filename <> ".index.json") ->
{filename, true}
lookup_params_filename(repo_files, filename) ||
raise ArgumentError, "could not find file #{inspect(filename)} in the repository"
end

true ->
raise ArgumentError, "could not find file #{inspect(filename)} in the repository"
defp lookup_params_filename(repo_files, filename) do
cond do
Map.has_key?(repo_files, filename) -> {filename, false}
Map.has_key?(repo_files, filename <> ".index.json") -> {filename, true}
true -> nil
end
end

Expand Down
15 changes: 2 additions & 13 deletions lib/bumblebee/diffusion/stable_diffusion.ex
Original file line number Diff line number Diff line change
Expand Up @@ -75,20 +75,9 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
repository_id = "CompVis/stable-diffusion-v1-4"
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"})
{:ok, clip} = Bumblebee.load_model({:hf, repository_id, subdir: "text_encoder"})
{:ok, unet} =
Bumblebee.load_model({:hf, repository_id, subdir: "unet"},
params_filename: "diffusion_pytorch_model.bin"
)
{:ok, vae} =
Bumblebee.load_model({:hf, repository_id, subdir: "vae"},
architecture: :decoder,
params_filename: "diffusion_pytorch_model.bin"
)
{:ok, unet} = Bumblebee.load_model({:hf, repository_id, subdir: "unet"})
{:ok, vae} = Bumblebee.load_model({:hf, repository_id, subdir: "vae"}, architecture: :decoder)
{:ok, scheduler} = Bumblebee.load_scheduler({:hf, repository_id, subdir: "scheduler"})
{:ok, featurizer} = Bumblebee.load_featurizer({:hf, repository_id, subdir: "feature_extractor"})
{:ok, safety_checker} = Bumblebee.load_model({:hf, repository_id, subdir: "safety_checker"})
Expand Down
15 changes: 2 additions & 13 deletions notebooks/stable_diffusion.livemd
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,9 @@ Stable Diffusion is composed of several separate models and preprocessors, so we
repository_id = "CompVis/stable-diffusion-v1-4"

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"})

{:ok, clip} = Bumblebee.load_model({:hf, repository_id, subdir: "text_encoder"})

{:ok, unet} =
Bumblebee.load_model({:hf, repository_id, subdir: "unet"},
params_filename: "diffusion_pytorch_model.bin"
)

{:ok, vae} =
Bumblebee.load_model({:hf, repository_id, subdir: "vae"},
architecture: :decoder,
params_filename: "diffusion_pytorch_model.bin"
)

{:ok, unet} = Bumblebee.load_model({:hf, repository_id, subdir: "unet"})
{:ok, vae} = Bumblebee.load_model({:hf, repository_id, subdir: "vae"}, architecture: :decoder)
{:ok, scheduler} = Bumblebee.load_scheduler({:hf, repository_id, subdir: "scheduler"})
{:ok, featurizer} = Bumblebee.load_featurizer({:hf, repository_id, subdir: "feature_extractor"})
{:ok, safety_checker} = Bumblebee.load_model({:hf, repository_id, subdir: "safety_checker"})
Expand Down
12 changes: 2 additions & 10 deletions test/bumblebee/diffusion/stable_diffusion_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,11 @@ defmodule Bumblebee.Diffusion.StableDiffusionTest do
repository_id = "bumblebee-testing/tiny-stable-diffusion"

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"})

{:ok, clip} = Bumblebee.load_model({:hf, repository_id, subdir: "text_encoder"})

{:ok, unet} =
Bumblebee.load_model({:hf, repository_id, subdir: "unet"},
params_filename: "diffusion_pytorch_model.bin"
)
{:ok, unet} = Bumblebee.load_model({:hf, repository_id, subdir: "unet"})

{:ok, vae} =
Bumblebee.load_model({:hf, repository_id, subdir: "vae"},
architecture: :decoder,
params_filename: "diffusion_pytorch_model.bin"
)
Bumblebee.load_model({:hf, repository_id, subdir: "vae"}, architecture: :decoder)

{:ok, scheduler} = Bumblebee.load_scheduler({:hf, repository_id, subdir: "scheduler"})

Expand Down
3 changes: 1 addition & 2 deletions test/bumblebee/diffusion/unet_2d_conditional_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ defmodule Bumblebee.Diffusion.UNet2DConditionalTest do
test ":base" do
assert {:ok, %{model: model, params: params, spec: spec}} =
Bumblebee.load_model(
{:hf, "hf-internal-testing/tiny-stable-diffusion-torch", subdir: "unet"},
params_filename: "diffusion_pytorch_model.bin"
{:hf, "hf-internal-testing/tiny-stable-diffusion-torch", subdir: "unet"}
)

assert %Bumblebee.Diffusion.UNet2DConditional{architecture: :base} = spec
Expand Down
9 changes: 3 additions & 6 deletions test/bumblebee/diffusion/vae_kl_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ defmodule Bumblebee.Diffusion.VaeKlTest do
test ":base" do
assert {:ok, %{model: model, params: params, spec: spec}} =
Bumblebee.load_model(
{:hf, "hf-internal-testing/tiny-stable-diffusion-torch", subdir: "vae"},
params_filename: "diffusion_pytorch_model.bin"
{:hf, "hf-internal-testing/tiny-stable-diffusion-torch", subdir: "vae"}
)

assert %Bumblebee.Diffusion.VaeKl{architecture: :base} = spec
Expand Down Expand Up @@ -39,8 +38,7 @@ defmodule Bumblebee.Diffusion.VaeKlTest do
assert {:ok, %{model: model, params: params, spec: spec}} =
Bumblebee.load_model(
{:hf, "hf-internal-testing/tiny-stable-diffusion-torch", subdir: "vae"},
architecture: :decoder,
params_filename: "diffusion_pytorch_model.bin"
architecture: :decoder
)

assert %Bumblebee.Diffusion.VaeKl{architecture: :decoder} = spec
Expand Down Expand Up @@ -70,8 +68,7 @@ defmodule Bumblebee.Diffusion.VaeKlTest do
assert {:ok, %{model: model, params: params, spec: spec}} =
Bumblebee.load_model(
{:hf, "hf-internal-testing/tiny-stable-diffusion-torch", subdir: "vae"},
architecture: :encoder,
params_filename: "diffusion_pytorch_model.bin"
architecture: :encoder
)

assert %Bumblebee.Diffusion.VaeKl{architecture: :encoder} = spec
Expand Down

0 comments on commit d1783dc

Please sign in to comment.