This repository has been archived by the owner on Oct 4, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathcomponents.py
159 lines (123 loc) · 6.25 KB
/
components.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import warnings
warnings.simplefilter("ignore")
import logging # noqa: E402
import os # noqa: E402
from functools import partial # noqa: E402
from subprocess import Popen # noqa: E402
import gradio # noqa: E402
import torch # noqa: E402
import torchvision.transforms as T # noqa: E402
from lightning.app.components.python import TracerPythonScript # noqa: E402
from lightning.app.components.serve import ServeGradio # noqa: E402
from lightning.app.storage import Path # noqa: E402
from quick_start.download import download_data # noqa: E402
# Module-level logger, named after this module so log records can be filtered per component file.
logger = logging.getLogger(__name__)
class PyTorchLightningScript(TracerPythonScript):
    """Run a PyTorch Lightning script and start TensorBoard when training begins.

    A callback is injected into the ``Trainer`` at runtime (by tracing
    ``Trainer.__init__``) so that a TensorBoard server is launched on this
    work's ``host`` and ``port`` as soon as training starts. Once the script
    has finished, the best checkpoint is exported to TorchScript and exposed
    through ``best_model_path`` and ``best_model_score``.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``TracerPythonScript`` and reset the results."""
        super().__init__(*args, **kwargs)

        # Filled in by `on_after_run` once the traced script has completed.
        self.best_model_path = None
        self.best_model_score = None

    def configure_tracer(self):
        """Return the parent tracer, extended to inject a TensorBoard launcher.

        The ``Trainer.__init__`` call made by the traced script is intercepted
        and a ``TensorboardServerLauncher`` callback is added to its
        ``callbacks`` argument.
        """
        # Imported lazily, as in the original component, so that importing this
        # module does not pull in `lightning.pytorch`.
        from lightning.pytorch import Trainer
        from lightning.pytorch.callbacks import Callback

        tracer = super().configure_tracer()

        class TensorboardServerLauncher(Callback):
            """Callback starting a TensorBoard server when training starts."""

            def __init__(self, work):
                # `work` is the current ``PyTorchLightningScript`` work.
                self._work = work

            def on_train_start(self, trainer, *_):
                # `host` and `port` are provided so TensorBoard is reachable
                # when the work runs in the cloud.
                # An argument list (no shell) is used so that a log directory
                # containing quotes or shell metacharacters cannot break or
                # alter the command.
                self._work._process = Popen(
                    [
                        "tensorboard",
                        "--logdir",
                        str(trainer.logger.log_dir),
                        "--host",
                        str(self._work.host),
                        "--port",
                        str(self._work.port),
                    ]
                )

        def trainer_pre_fn(self, *args, work=None, **kwargs):
            # Runs before ``Trainer.__init__``; here `self` is the Trainer
            # being constructed, not the work.
            # The script may omit `callbacks`, pass `None`, or pass a single
            # callback, so normalise to a fresh list before appending.
            callbacks = kwargs.get("callbacks")
            if callbacks is None:
                callbacks = []
            elif isinstance(callbacks, Callback):
                callbacks = [callbacks]
            else:
                callbacks = list(callbacks)
            callbacks.append(TensorboardServerLauncher(work))
            kwargs["callbacks"] = callbacks
            return {}, args, kwargs

        # Patch `Trainer.__init__` so the callback is injected with a
        # reference to this work.
        tracer.add_traced(Trainer, "__init__", pre_fn=partial(trainer_pre_fn, work=self))
        return tracer

    def run(self, *args, **kwargs):
        """Prepare the demo environment, then execute the training script.

        All positional and keyword arguments are forwarded unchanged to
        ``TracerPythonScript.run``.
        """
        ######### [DEMO PURPOSE] #########
        # 1. Download a pre-trained model to speed up the demo.
        download_data(
            "https://pl-flash-data.s3.amazonaws.com/assets_lightning/demo_weights.pt",
            "./",
        )

        # 2. Add some arguments to the Trainer to make training faster.
        self.script_args += [
            "--trainer.limit_train_batches=12",
            "--trainer.limit_val_batches=4",
            "--trainer.callbacks=ModelCheckpoint",
            "--trainer.callbacks.monitor=val_loss",
        ]

        # 3. Utilities
        warnings.simplefilter("ignore")
        ######### [DEMO PURPOSE] #########

        # Logged once; the original emitted this exact message twice in a row.
        logger.info("Running train_script: %s", self.script_path)

        # 4. Execute the parent run method.
        super().run(*args, **kwargs)

    def on_after_run(self, script_globals):
        """Export the best model to TorchScript and record its path and score.

        Args:
            script_globals: Globals of the executed script. It must contain a
                ``"cli"`` entry exposing ``trainer`` (the code reads
                ``trainer.lightning_module`` and ``trainer.checkpoint_callback``).
        """
        # 1. Once the script has finished, its globals give access to the
        #    LightningCLI object and, through it, to the trainer.
        trainer = script_globals["cli"].trainer
        lightning_module = trainer.lightning_module
        checkpoint_callback = trainer.checkpoint_callback

        # 2. Load the best weights found by the checkpoint callback.
        checkpoint = torch.load(checkpoint_callback.best_model_path)

        # 3. Restore those weights and export the model to TorchScript.
        lightning_module.load_state_dict(checkpoint["state_dict"])
        lightning_module.to_torchscript("model_weight.pt")

        # 4. Use lightning.app.storage.Path to create a reference to the
        #    torchscripted model. When running in the cloud on multiple
        #    machines, passing this reference to another work automatically
        #    triggers a transfer.
        self.best_model_path = Path("model_weight.pt")

        # 5. Keep track of the metrics.
        self.best_model_score = float(checkpoint_callback.best_model_score)
class ImageServeGradio(ServeGradio):
    """Serve a TorchScript 10-class image classifier through a Gradio UI.

    The interface takes a 28x28 PIL image as input and shows the per-class
    scores as a Gradio ``Label`` output.
    """

    inputs = gradio.inputs.Image(type="pil", shape=(28, 28))
    outputs = gradio.outputs.Label(num_top_classes=10)

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``ServeGradio`` and initialise the state."""
        super().__init__(*args, **kwargs)
        self.examples = None
        self.best_model_path = None
        self._transform = None
        # Maps a class index to the label shown in the UI ("0" ... "9").
        self._labels = {idx: str(idx) for idx in range(10)}

    def run(self, best_model_path):
        """Fetch the demo examples, remember the model path and start serving.

        Args:
            best_model_path: Path to the TorchScript model that
                ``build_model`` loads.
        """
        ######### [DEMO PURPOSE] #########
        # Download some examples so it works locally and in the cloud
        # (issue with gradio on loading the images).
        download_data(
            "https://pl-flash-data.s3.amazonaws.com/assets_lightning/images.tar.gz",
            "./",
        )
        # `os.listdir` returns entries in arbitrary order; sort them so the
        # examples are listed identically on every machine.
        self.examples = [os.path.join("./images", f) for f in sorted(os.listdir("./images"))]
        ######### [DEMO PURPOSE] #########

        self.best_model_path = best_model_path
        self._transform = T.Compose([T.Resize((28, 28)), T.ToTensor()])
        super().run()

    def predict(self, img):
        """Return a ``{label: score}`` mapping for the given image.

        Args:
            img: The image received from the Gradio input (configured above
                with ``type="pil"``).
        """
        with torch.inference_mode():
            # 1. Transform the image into a tensor, keep only its first
            #    channel, then add two leading dimensions.
            img = self._transform(img)[0]
            img = img.unsqueeze(0).unsqueeze(0)

            # 2. Apply the model and exponentiate its output.
            #    NOTE(review): this yields probabilities only if the model
            #    returns log-probabilities (e.g. log-softmax) - confirm.
            prediction = torch.exp(self.model(img))

            # 3. Return the data in the `gr.outputs.Label` format.
            return {label: prediction[0][idx].item() for idx, label in self._labels.items()}

    def build_model(self):
        """Load the TorchScript model and prepare it for inference.

        Returns:
            The loaded model, in eval mode, with gradients disabled on all of
            its parameters.
        """
        # 1. Load the best model. It was exported with `to_torchscript`, so
        #    `torch.jit.load` is the matching loader; `torch.load` only handles
        #    such archives through a deprecated dispatch and rejects them when
        #    `weights_only=True`.
        model = torch.jit.load(str(self.best_model_path))

        # 2. Prepare the model for predictions.
        for p in model.parameters():
            p.requires_grad = False
        model.eval()

        # 3. Return the model.
        return model