-
Notifications
You must be signed in to change notification settings - Fork 102
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Gemma #358
Add Gemma #358
Conversation
I'm not familiar with Bumblebee, so I may be doing something wrong, but when I tried out this branch my Phoenix app crashes when I try to load the model: defmodule NoEx.Application do
use Application
@hf_token "abcd..."
@impl true
def start(_type, _args) do
IO.inspect("NoEx.Application.start")
{:ok, nx_model_info} =
Bumblebee.load_model(
{:hf, "google/gemma-7b-it", [auth_token: @hf_token]},
spec_overrides: [num_labels: 10]
)
IO.inspect(nx_model_info)
{:ok, nx_tokenizer} =
Bumblebee.load_tokenizer({:hf, "google/gemma-7b-it", [auth_token: @hf_token]})
IO.inspect(nx_tokenizer)
{:ok, nx_gen_config} =
Bumblebee.load_generation_config({:hf, "google/gemma-7b-it", [auth_token: @hf_token]})
IO.inspect(nx_gen_config)
children = [
NoExWeb.Telemetry,
NoEx.Repo,
{DNSCluster, query: Application.get_env(:no_ex, :dns_cluster_query) || :ignore},
{Phoenix.PubSub, name: NoEx.PubSub},
# Nx
{Nx.Serving, serving: nx_serving, name: NoEx.Serving, batch_timeout: 100},
# Start the Finch HTTP client for sending emails
{Finch, name: NoEx.Finch},
# Start a worker by calling: NoEx.Worker.start_link(arg)
# {NoEx.Worker, arg},
# Start to serve requests, typically the last entry
NoExWeb.Endpoint
]
# See https://hexdocs.pm/elixir/Supervisor.html
# for other strategies and supported options
opts = [strategy: :one_for_one, name: NoEx.Supervisor]
Supervisor.start_link(children, opts)
end
end It never makes it past the $ mix phx.server
Compiling 1 file (.ex)
"NoEx.Application.start"
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1709498689.304461 5486422 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
[1] 74844 killed mix phx.server |
@seanmor5 the official checkpoint ties embeddings, so I changed loading to FTR I also used the config values from hf/transformers tests, which is even smaller, generally we want the tiny checkpoints to be as small as possible :) The Codefrom transformers import GemmaConfig, GemmaModel, GemmaForCausalLM, GemmaForSequenceClassification
config = GemmaConfig(
# vocab_size=99,
vocab_size=1024,
hidden_size=32,
num_hidden_layers=2,
num_attention_heads=4,
num_key_value_heads=2,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
is_decoder=False,
initializer_range=0.02,
pad_token_id=0,
head_dim=8,
)
for c in [GemmaModel, GemmaForCausalLM, GemmaForSequenceClassification]:
name = c.__name__
c(config).save_pretrained(f"tmp/bumblebee-testing/tiny-random-{name}", repo_id=f"bumblebee-testing/tiny-random-{name}", push_to_hub=True) @seanmor5 I also added you to the |
Yeah I think it's fine to change |
@kurtome the "killed" log probably means that the OS killed the process because it was getting close to OOM. This is on CPU right? How much RAM do you have? You probably want to do |
Resolves #357
Gemma has
attention_bias
config, which is similar to ouruse_qkv_bias
but not really accurate because there is attention output bias too. I addeduse_attention_bias
instead, but wondering if we should change all instances touse_attention_bias
?