Add Gemma #358

Merged
2 commits merged into main from sm-gemma on Mar 4, 2024

Conversation

@seanmor5 (Contributor) commented on Mar 2, 2024

Resolves #357

Gemma has an attention_bias config, which is similar to our use_qkv_bias, but that name isn't quite accurate here because there is an attention output bias too. I added use_attention_bias instead, but I'm wondering if we should change all instances to use_attention_bias?
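For context, a minimal sketch of what a single attention bias flag would gate (module, function, and layer names here are illustrative assumptions, not the actual Bumblebee internals):

# Hypothetical sketch: one use_attention_bias flag controls the bias on the
# query/key/value projections and on the attention output projection alike,
# mirroring HF's attention_bias; use_qkv_bias only covered the first three.
defmodule AttentionBiasSketch do
  def projections(hidden_state, attention_output, spec) do
    bias? = spec.use_attention_bias

    %{
      query: Axon.dense(hidden_state, spec.hidden_size, use_bias: bias?, name: "self_attention.query"),
      key: Axon.dense(hidden_state, spec.hidden_size, use_bias: bias?, name: "self_attention.key"),
      value: Axon.dense(hidden_state, spec.hidden_size, use_bias: bias?, name: "self_attention.value"),
      output: Axon.dense(attention_output, spec.hidden_size, use_bias: bias?, name: "self_attention.output")
    }
  end
end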

@kurtome commented on Mar 3, 2024

I'm not familiar with Bumblebee, so I may be doing something wrong, but when I tried out this branch, my Phoenix app crashed while loading the model:

defmodule NoEx.Application do
  use Application

  @hf_token "abcd..."

  @impl true
  def start(_type, _args) do
    IO.inspect("NoEx.Application.start")

    {:ok, nx_model_info} =
      Bumblebee.load_model(
        {:hf, "google/gemma-7b-it", [auth_token: @hf_token]},
        spec_overrides: [num_labels: 10]
      )

    IO.inspect(nx_model_info)

    {:ok, nx_tokenizer} =
      Bumblebee.load_tokenizer({:hf, "google/gemma-7b-it", [auth_token: @hf_token]})

    IO.inspect(nx_tokenizer)

    {:ok, nx_gen_config} =
      Bumblebee.load_generation_config({:hf, "google/gemma-7b-it", [auth_token: @hf_token]})

    IO.inspect(nx_gen_config)

    children = [
      NoExWeb.Telemetry,
      NoEx.Repo,
      {DNSCluster, query: Application.get_env(:no_ex, :dns_cluster_query) || :ignore},
      {Phoenix.PubSub, name: NoEx.PubSub},
      # Nx
      {Nx.Serving, serving: nx_serving, name: NoEx.Serving, batch_timeout: 100},
      # Start the Finch HTTP client for sending emails
      {Finch, name: NoEx.Finch},
      # Start a worker by calling: NoEx.Worker.start_link(arg)
      # {NoEx.Worker, arg},
      # Start to serve requests, typically the last entry
      NoExWeb.Endpoint
    ]

    # See https://hexdocs.pm/elixir/Supervisor.html
    # for other strategies and supported options
    opts = [strategy: :one_for_one, name: NoEx.Supervisor]
    Supervisor.start_link(children, opts)
  end
end

It never makes it past the Bumblebee.load_model line.

$ mix phx.server

Compiling 1 file (.ex)
"NoEx.Application.start"
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1709498689.304461 5486422 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
[1]    74844 killed     mix phx.server
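Note: the snippet above references nx_serving without showing its definition; presumably it is built from the loaded model, tokenizer, and generation config with something along these lines (an assumption about the elided code, not part of the original report):

# Assumed definition of the nx_serving referenced in the child spec above,
# built from the pieces the report does load; options are illustrative.
nx_serving =
  Bumblebee.Text.generation(nx_model_info, nx_tokenizer, nx_gen_config,
    compile: [batch_size: 1, sequence_length: 512],
    defn_options: [compiler: EXLA]
  )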

@jonatanklosko (Member) commented:

@seanmor5 the official checkpoint ties embeddings, so I changed the loading to map "language_modeling_head.output" => "model.embed_tokens" (we could add a config option, but it's unlikely there's an untied version, and we want to actually address tied embeddings properly eventually). I generated the tiny config to not include the tied embeddings.
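A rough sketch of the idea as a params mapping (only the quoted pair comes from this PR; the token embedding layer name is an assumption for illustration):

# With tied embeddings, both Axon layer names load their kernel from the same
# checkpoint tensor, "model.embed_tokens".
%{
  "embedder.token_embedding" => "model.embed_tokens",
  "language_modeling_head.output" => "model.embed_tokens"
}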

FTR I also used the config values from the hf/transformers tests, which are even smaller; generally we want the tiny checkpoints to be as small as possible :) The utils/create_dummy_models.py script in hf/transformers didn't work for me (it also didn't for Llama), though it did for Bert in the past. So instead of digging into this too much, I just created the checkpoints by hand:

from transformers import GemmaConfig, GemmaModel, GemmaForCausalLM, GemmaForSequenceClassification

config = GemmaConfig(
  # vocab_size=99,
  vocab_size=1024,
  hidden_size=32,
  num_hidden_layers=2,
  num_attention_heads=4,
  num_key_value_heads=2,
  intermediate_size=37,
  hidden_act="gelu",
  hidden_dropout_prob=0.1,
  attention_probs_dropout_prob=0.1,
  max_position_embeddings=512,
  type_vocab_size=16,
  is_decoder=False,
  initializer_range=0.02,
  pad_token_id=0,
  head_dim=8,
)

# Push one tiny checkpoint per architecture to the bumblebee-testing org
for c in [GemmaModel, GemmaForCausalLM, GemmaForSequenceClassification]:
  name = c.__name__
  c(config).save_pretrained(
    f"tmp/bumblebee-testing/tiny-random-{name}",
    repo_id=f"bumblebee-testing/tiny-random-{name}",
    push_to_hub=True,
  )

@seanmor5 I also added you to the bumblebee-testing org, so you don't have to push random repos into your account :p

@jonatanklosko (Member) commented:

I added use_attention_bias instead, but wondering if we should change all instances to use_attention_bias?

Yeah, I think it's fine to change use_qkv_bias to use_attention_bias for consistent naming, and have the model implementation pass it either to the qkv or to the qkvo projections. I will do this in a separate commit!

@jonatanklosko merged commit 20855ad into main on Mar 4, 2024
2 checks passed
@jonatanklosko deleted the sm-gemma branch on Mar 4, 2024 at 09:44
@jonatanklosko (Member) commented:

@kurtome the "killed" log probably means that the OS killed the process because it was getting close to OOM. This is on CPU, right? How much RAM do you have? You probably want to do load_model(..., type: :f16) to reduce the memory usage, though a 7B model may be too much for a CPU to run at a reasonable speed; you can try the 2B checkpoint too.
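For example, a reduced-memory load along those lines (the 2B repo name and token handling mirror the earlier report; adjust as needed):

# Load the smaller 2B instruction-tuned checkpoint with f16 parameters to
# roughly halve parameter memory compared to the default f32.
hf_token = System.fetch_env!("HF_TOKEN")

{:ok, model_info} =
  Bumblebee.load_model(
    {:hf, "google/gemma-2b-it", [auth_token: hf_token]},
    type: :f16
  )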

Linked issue: Add support for "google/gemma-7b-it"