-
Notifications
You must be signed in to change notification settings - Fork 13
/
deeponet_poisson.py
74 lines (61 loc) · 1.79 KB
/
deeponet_poisson.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import sys
import deepxde as dde
import numpy as np
def get_data(
    fname_train,
    fname_test,
    residual=False,
    stackbranch=False,
    stacktrunk=False,
    num_train=500,
):
    """Load a random training subset and the full test set for a DeepONet.

    Each ``.npz`` file must contain ``X0`` (branch input), ``X1`` (trunk
    input) and ``y`` (target).  ``y_low`` is also read when ``stackbranch``
    is set, and ``y_low_x`` when ``stacktrunk`` or ``residual`` is set.

    Args:
        fname_train: Path to the training ``.npz`` file.
        fname_test: Path to the test ``.npz`` file.
        residual: If True, targets become ``y - y_low_x``.
        stackbranch: If True, append the ``y_low`` columns to the branch input.
        stacktrunk: If True, append the ``y_low_x`` columns to the trunk input.
        num_train: Number of training rows drawn without replacement, using
            the global ``np.random`` state.

    Returns:
        ``(X_train, y_train, X_test, y_test)``, where each ``X_*`` is a
        ``(branch_input, trunk_input)`` tuple.

    Raises:
        ValueError: If ``num_train`` exceeds the number of training rows.
    """
    X_train, y_train = _load_split(
        fname_train, residual, stackbranch, stacktrunk, num_samples=num_train
    )
    X_test, y_test = _load_split(fname_test, residual, stackbranch, stacktrunk)
    return X_train, y_train, X_test, y_test


def _load_split(fname, residual, stackbranch, stacktrunk, num_samples=None):
    """Load one ``.npz`` split, optionally subsampling ``num_samples`` rows."""
    # The context manager closes the file handle np.load keeps open for .npz.
    with np.load(fname) as d:
        y = d["y"]
        if num_samples is None:
            idx = slice(None)
        else:
            # Sample from the rows actually present rather than an assumed
            # dataset size, so smaller files don't fail with IndexError.
            idx = np.random.choice(len(y), size=num_samples, replace=False)
        y = y[idx]
        X_branch = d["X0"][idx]
        if stackbranch:
            X_branch = np.hstack((X_branch, d["y_low"][idx]))
        X_trunk = d["X1"][idx]
        if stacktrunk:
            X_trunk = np.hstack((X_trunk, d["y_low_x"][idx]))
        if residual:
            # Out-of-place, so mismatched dtypes don't break in-place casting.
            y = y - d["y_low_x"][idx]
    return (X_branch, X_trunk), y
def run(data, net, lr, epochs):
    """Train ``net`` on ``data`` with Adam, then plot the loss history.

    Args:
        data: DeepXDE data object to train on.
        net: DeepXDE network to train.
        lr: Adam learning rate.
        epochs: Number of training epochs passed to ``model.train``.

    Returns:
        ``(model, losshistory, train_state)``, so callers can reuse the
        trained model (e.g. for prediction) instead of it being discarded.
    """
    model = dde.Model(data, net)
    model.compile("adam", lr=lr)
    losshistory, train_state = model.train(epochs=epochs)
    # Display the plots only; nothing is written to disk.
    dde.saveplot(losshistory, train_state, issave=False, isplot=True)
    return model, losshistory, train_state
def main():
    """Train a DeepONet on the Poisson dataset and plot its loss history.

    The data paths are relative to the current working directory.
    """
    fname_train = "../data/train.npz"
    fname_test = "../data/test.npz"
    X_train, y_train, X_test, y_test = get_data(
        fname_train, fname_test, residual=True, stackbranch=False, stacktrunk=False
    )
    data = dde.data.OpDataSet(
        X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
    )

    # Derive the input widths from the loaded data instead of hard-coding
    # them, so the network stays consistent when stackbranch/stacktrunk add
    # columns.  Assumes OpNN takes the input width as the first layer size.
    m = X_train[0].shape[1]  # branch input width
    dim_x = X_train[1].shape[1]  # trunk input width
    width = 5
    net = dde.maps.OpNN(
        [m, width, width],
        [dim_x, width],
        "selu",
        "LeCun normal",
    )
    lr = 0.0001
    epochs = 50000
    run(data, net, lr, epochs)
if __name__ == "__main__":
    # Run training only when executed as a script, not on import.
    main()