import torch
from torch import nn
from torch.nn import functional as F
from attention import SelfAttention
"""
File Name: clip.py
Description: This script implements the CLIP text encoder used to produce prompt embeddings, loading the
             pre-trained weights v1-5-pruned-emaonly.ckpt from https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main
How to reproduce the architecture the original CLIP text encoder was trained with?
    Load the model from v1-5-pruned-emaonly.ckpt:
    >>> models = model_loader.preload_models_from_standard_weights(model_file='../checkpoint/v1-5-pruned-emaonly.ckpt', DEVICE='cpu')
    >>> models['clip']
    to get the following model specification:
CLIP(
(embedding): CLIPEmbedding(
(token_embedding): Embedding(49408, 768)
)
(layers): ModuleList(
(0-11): 12 x CLIPLayer(
(layernorm_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attention): SelfAttention(
(in_proj): Linear(in_features=768, out_features=2304, bias=True)
(out_proj): Linear(in_features=768, out_features=768, bias=True)
)
(layernorm_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(linear_1): Linear(in_features=768, out_features=3072, bias=True)
(linear_2): Linear(in_features=3072, out_features=768, bias=True)
)
)
(layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
Where does the Embedding input dim come from?
    vocab.json contains 49408 key-value pairs mapping words/special tokens to integer ids,
    downloaded from https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/tokenizer
How to use the pre-trained CLIP for inference?
    Load the weights into an instance of the CLIP class:
    >>> clip = CLIP()
    >>> clip.load_state_dict(state_dict['clip'], strict=True)
    See model_converter.py for how the checkpoint is converted into this state_dict layout.
"""
class CLIPEmbedding(nn.Module):
def __init__(self, n_vocab: int, n_embd: int, n_token: int):
super().__init__()
self.token_embedding = nn.Embedding(n_vocab, n_embd)
# A learnable weight matrix encodes the position information for each token
self.position_embedding = nn.Parameter(torch.zeros((n_token, n_embd)))
def forward(self, tokens):
# (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
x = self.token_embedding(tokens)
        # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
x += self.position_embedding
return x
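
# A small shape-check sketch for CLIPEmbedding (illustrative only, random token ids):
# token ids of shape (Batch_Size, Seq_Len) become (Batch_Size, Seq_Len, Dim) after the
# token embedding lookup plus the learnable position embedding.
def _demo_clip_embedding():
    emb = CLIPEmbedding(n_vocab=49408, n_embd=768, n_token=77)
    tokens = torch.randint(0, 49408, (2, 77))  # (Batch_Size=2, Seq_Len=77)
    out = emb(tokens)
    assert out.shape == (2, 77, 768)
    return out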
class CLIPLayer(nn.Module):
def __init__(self, n_head: int, n_embd: int):
super().__init__()
# Pre-attention norm
self.layernorm_1 = nn.LayerNorm(n_embd)
# Self attention
self.attention = SelfAttention(n_head, n_embd)
        # Pre-FFN (feedforward) norm
self.layernorm_2 = nn.LayerNorm(n_embd)
# Feedforward layer
self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
self.linear_2 = nn.Linear(4 * n_embd, n_embd)
def forward(self, x):
# (Batch_Size, Seq_Len, Dim)
residue = x
### SELF ATTENTION ###
# (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
x = self.layernorm_1(x)
# (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
x = self.attention(x, causal_mask=True)
# (Batch_Size, Seq_Len, Dim) + (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
x += residue
### FEEDFORWARD LAYER ###
# Apply a feedforward layer where the hidden dimension is 4 times the embedding dimension.
residue = x
# (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
x = self.layernorm_2(x)
# (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, 4 * Dim)
x = self.linear_1(x)
# (Batch_Size, Seq_Len, 4 * Dim) -> (Batch_Size, Seq_Len, 4 * Dim)
x = x * torch.sigmoid(1.702 * x) # QuickGELU activation function
# (Batch_Size, Seq_Len, 4 * Dim) -> (Batch_Size, Seq_Len, Dim)
x = self.linear_2(x)
# (Batch_Size, Seq_Len, Dim) + (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
x += residue
return x
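
# A minimal forward-pass sketch for one CLIPLayer (illustrative only): the layer is
# shape-preserving, mapping (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
# through causal self-attention and the QuickGELU feedforward block, each with a
# residual connection.
def _demo_clip_layer():
    layer = CLIPLayer(n_head=12, n_embd=768)
    x = torch.randn(2, 77, 768)  # (Batch_Size, Seq_Len, Dim)
    out = layer(x)
    assert out.shape == x.shape
    return out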
class CLIP(nn.Module):
def __init__(self):
super().__init__()
self.embedding = CLIPEmbedding(49408, 768, 77)
self.layers = nn.ModuleList([
CLIPLayer(12, 768) for i in range(12)
])
self.layernorm = nn.LayerNorm(768)
def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
tokens = tokens.type(torch.long)
# (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
state = self.embedding(tokens)
# Apply encoder layers similar to the Transformer's encoder.
for layer in self.layers:
# (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
state = layer(state)
# (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
output = self.layernorm(state)
return output
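

# An end-to-end usage sketch. Assumptions: Hugging Face `transformers` is installed and the
# tokenizer files (vocab.json, merges.txt) live under ../data; loading the real checkpoint
# weights is described in the module docstring and skipped here, so this runs with random
# initialization and only demonstrates shapes.
if __name__ == "__main__":
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer("../data/vocab.json", merges_file="../data/merges.txt")
    clip = CLIP()  # randomly initialized; see the docstring for loading pre-trained weights
    prompt = "a photograph of an astronaut riding a horse"
    tokens = tokenizer.batch_encode_plus(
        [prompt], padding="max_length", max_length=77
    ).input_ids
    tokens = torch.tensor(tokens, dtype=torch.long)  # (1, 77)
    context = clip(tokens)  # (1, 77, 768) prompt embedding fed to the diffusion UNet
    print(context.shape)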