-
Notifications
You must be signed in to change notification settings - Fork 896
/
linear_probe.py
56 lines (41 loc) · 1.82 KB
/
linear_probe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# Mirror of the Linear Probe Evaluation Script
# from the official CLIP Repository.
import mlx.core as mx
import numpy as np
from image_processor import CLIPImageProcessor
from mlx.data.datasets import load_cifar10
from model import CLIPModel
from PIL import Image
from sklearn.linear_model import LogisticRegression
from tqdm import tqdm
def get_cifar10(batch_size, root=None):
    """Build the batched CIFAR-10 train and test splits.

    Args:
        batch_size: Batch size handed to ``.batch()`` on each split.
        root: Forwarded unchanged to ``load_cifar10`` for both splits
            (presumably the dataset location; verify against mlx.data).

    Returns:
        A ``(train, test)`` tuple holding the two batched splits.
    """
    train_batches = load_cifar10(root=root).batch(batch_size)

    # The test split is selected with train=False and batched the same way.
    held_out = load_cifar10(train=False, root=root)
    return train_batches, held_out.batch(batch_size)
def get_features(model, image_proc, iter):
    """Embed every batch yielded by ``iter`` and collect features and labels.

    Args:
        model: Object exposing ``get_image_features(x)``.
        image_proc: Callable that turns a list of PIL images into the
            model's input.
        iter: Iterable of batches; each batch is indexed with ``"image"``
            and ``"label"``. NOTE: the name shadows the builtin ``iter``
            inside this function; it is kept for caller compatibility.

    Returns:
        A ``(features, labels)`` tuple: the per-batch embeddings and the
        per-batch labels, each joined with ``mx.concatenate``.
    """
    feature_chunks, label_chunks = [], []

    for sample in tqdm(iter):
        raw_images = sample["image"]
        raw_labels = sample["label"]

        # Each entry is wrapped in a PIL image before preprocessing.
        pil_images = [Image.fromarray(arr) for arr in raw_images]
        pixels = image_proc(pil_images)
        targets = mx.array(raw_labels)

        embeds = model.get_image_features(pixels)
        # Evaluate this batch's embeddings now, before moving to the next one.
        mx.eval(embeds)

        feature_chunks.append(embeds)
        label_chunks.append(targets)

    features = mx.concatenate(feature_chunks)
    labels = mx.concatenate(label_chunks)
    return features, labels
if __name__ == "__main__":
    # Load the CLIP model and its matching image preprocessor from "mlx_model".
    clip = CLIPModel.from_pretrained("mlx_model")
    preprocessor = CLIPImageProcessor.from_pretrained("mlx_model")

    # Extract image features for both CIFAR-10 splits (train first, then test).
    train_batches, test_batches = get_cifar10(batch_size=256)
    train_x, train_y = get_features(clip, preprocessor, train_batches)
    test_x, test_y = get_features(clip, preprocessor, test_batches)

    # Fit the linear probe: a logistic regression on the extracted features.
    # NOTE: C is fixed here; it should really be chosen with a hyperparameter
    # sweep on a validation split.
    probe = LogisticRegression(
        C=0.316,
        max_iter=1000,
        random_state=0,
        verbose=1,
    )
    probe.fit(train_x, train_y)

    # Score the probe on the test features and report accuracy in percent.
    predicted = probe.predict(test_x)
    hits = test_y.squeeze() == predicted
    accuracy = hits.mean().item() * 100
    print(f"Accuracy = {accuracy:.3f}")