test.py (forked from hlinander/equivariant-posteriors)
#!/usr/bin/env python
"""Minimal end-to-end training run: a small dense model fit to a synthetic sine dataset."""
import torch
import torchmetrics as tm

from lib.train_dataclasses import TrainConfig
from lib.train_dataclasses import TrainEval
from lib.train_dataclasses import TrainRun
from lib.train_dataclasses import OptimizerConfig
from lib.train_dataclasses import ComputeConfig

from lib.metric import Metric
from lib.models.dense import DenseConfig
from lib.data_registry import DataSineConfig
from lib.train import load_or_create_state
from lib.train import do_training
from lib.ddp import ddp_setup


def main():
    # Set up the (here single-process) distributed environment with the gloo backend
    # and get the device id to run on.
    device_id = ddp_setup("gloo")
    print(f"Using device {device_id}")

    # Loss adapter: the training loop passes output and batch dicts, so unpack them for MSELoss.
    loss = torch.nn.MSELoss()

    def mse_loss(outputs, batch):
        return loss(outputs["logits"], batch["target"])

    # Model, data, loss and optimizer for the run.
    train_config = TrainConfig(
        model_config=DenseConfig(d_hidden=100),
        train_data_config=DataSineConfig(
            input_shape=torch.Size([1]), output_shape=torch.Size([1])
        ),
        val_data_config=DataSineConfig(
            input_shape=torch.Size([1]), output_shape=torch.Size([1])
        ),
        loss=mse_loss,
        optimizer=OptimizerConfig(optimizer=torch.optim.Adam, kwargs=dict()),
        batch_size=2,
    )

    # Track MAE and MSE for both training and validation.
    train_eval = TrainEval(
        train_metrics=[
            lambda: Metric(tm.functional.mean_absolute_error),
            lambda: Metric(tm.functional.mean_squared_error),
        ],
        validation_metrics=[
            lambda: Metric(tm.functional.mean_absolute_error),
            lambda: Metric(tm.functional.mean_squared_error),
        ],
    )

    # Full run description: compute setup, training config, metrics and schedule.
    train_run = TrainRun(
        compute_config=ComputeConfig(distributed=False, num_workers=1),
        train_config=train_config,
        train_eval=train_eval,
        epochs=20,
        save_nth_epoch=5,
        validate_nth_epoch=5,
    )

    # Resume from saved state if it exists, otherwise create fresh state, then train.
    state = load_or_create_state(train_run, device_id)
    do_training(train_run, state, device_id)


if __name__ == "__main__":
    main()