forked from hkproj/pytorch-paligemma
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
38 lines (30 loc) · 1.31 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from modeling_gemma import PaliGemmaForConditionalGeneration, PaliGemmaConfig
from transformers import AutoTokenizer
import json
import glob
from safetensors import safe_open
from typing import Tuple
import os
def load_hf_model(model_path: str, device: str) -> Tuple[PaliGemmaForConditionalGeneration, AutoTokenizer]:
    """Load a PaliGemma model and its tokenizer from a local checkpoint directory.

    The directory must contain tokenizer files readable by
    ``AutoTokenizer.from_pretrained``, a ``config.json`` whose keys are passed
    as keyword arguments to ``PaliGemmaConfig``, and one or more
    ``*.safetensors`` weight shards.

    Args:
        model_path: Path to the checkpoint directory.
        device: Target device for the model, passed to ``Module.to``
            (e.g. ``"cpu"`` or ``"cuda"``).

    Returns:
        A ``(model, tokenizer)`` tuple. The tokenizer is configured to pad on
        the right.

    Raises:
        FileNotFoundError: If ``model_path`` contains no ``*.safetensors``
            files (loading would otherwise silently yield an untrained,
            randomly initialised model), or if ``config.json`` is missing.
    """
    # Load the tokenizer with right padding and verify the setting took effect.
    # NOTE(review): presumably downstream input construction relies on right
    # padding — confirm against the processing code.
    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="right")
    assert tokenizer.padding_side == "right"

    # Collect the weight shards. glob order is filesystem-dependent, so sort to
    # make the load order (and the winner of any duplicated key) reproducible.
    safetensors_files = sorted(glob.glob(os.path.join(model_path, "*.safetensors")))
    if not safetensors_files:
        # Without this check, strict=False below would accept an empty state
        # dict and return a model with random weights.
        raise FileNotFoundError(f"No *.safetensors files found in {model_path!r}")

    # Read every tensor of every shard into one flat state dict, on CPU.
    tensors = {}
    for safetensors_file in safetensors_files:
        with safe_open(safetensors_file, framework="pt", device="cpu") as f:
            for key in f.keys():
                tensors[key] = f.get_tensor(key)

    # Build the model configuration from config.json.
    with open(os.path.join(model_path, "config.json"), "r", encoding="utf-8") as f:
        model_config_file = json.load(f)
    config = PaliGemmaConfig(**model_config_file)

    # Instantiate the model and move it to the target device before loading;
    # load_state_dict copies the CPU tensors into the existing parameters.
    model = PaliGemmaForConditionalGeneration(config).to(device)

    # strict=False: parameters missing from the checkpoint (and checkpoint keys
    # the model does not define) are ignored instead of raising.
    # NOTE(review): presumably this exists for weights restored by
    # tie_weights() below (e.g. a tied output head) — confirm in modeling_gemma.
    model.load_state_dict(tensors, strict=False)

    # Re-tie shared weights after loading (tie_weights is defined in
    # modeling_gemma, not shown here).
    model.tie_weights()

    return (model, tokenizer)