Physics-informed neural networks (PINNs) were introduced by Raissi, Perdikaris and Karniadakis in 2019 as a method of finding numerical solutions of partial differential equations (in both continuous-time and discrete-time formulations), as well as parameterising those equations using data. In this repository, I concentrate on
- finding a numerical solution to the continuous-time advection equation (the "forward" problem);
- parameterising the advection equation using observations (the "inverse" problem). There is a challenge for readers to solve!
No doubt these have been done by many other authors, but I'm trying to teach myself!
Towards the end of this page, I outline another feature of Raissi, Perdikaris and Karniadakis's work, namely discrete time integration via Runge-Kutta. This appears to be startlingly novel and important, but I only summarise it here.
Python with TensorFlow is used to solve the PINNs in this repository. To run the Python scripts, `matplotlib` and `tensorflow` are needed. Install them using, for instance,
conda install -c conda-forge tensorflow
conda install -c conda-forge matplotlib
The approach assumes that the dynamics of a system are described by a differential equation of the form $\frac{\partial u}{\partial t} + \mathcal{N}[u; \lambda] = 0$.
Here, $u = u(t, x)$ is the quantity of interest, $\mathcal{N}$ is a differential operator that may involve spatial derivatives of $u$, and $\lambda$ denotes the parameters of the equation.
For instance, in the heat equation case with a uniform conductivity $\lambda$, the operator is $\mathcal{N}[u; \lambda] = -\lambda \nabla^{2} u$.
The PINN approach builds a standard neural network that outputs $u$, given the inputs $t$ and $x$.
Here the unusual ingredient is the loss function, which involves derivatives of the network output, such as $\partial u/\partial t$ and $\partial u/\partial x$, so that the differential equation itself can be penalised.
These derivatives are calculated using automatic differentiation. It is easy to differentiate a single activation function with respect to its input. Automatic differentiation is just doing that derivative, and using heaps of chain rules to propagate the differential operator through the entire neural network. So, automatic differentiation can easily give $\partial u/\partial t$, $\partial u/\partial x$, and higher derivatives if required.
An alternate view of this is that the (automatic) differentiation of the original neural network gives another neural network: one that outputs $\partial u/\partial t$ (say) instead of $u$, given the same inputs.
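As a concrete stand-alone illustration (this snippet is not from the repository; the tiny network and sample points are invented purely for demonstration), TensorFlow's GradientTape can differentiate a network's output with respect to its input:

import tensorflow as tf
# Toy example: differentiate a small network's output u with respect to its input x
model = tf.keras.Sequential([
    tf.keras.Input(shape = (1, )),
    tf.keras.layers.Dense(10, activation = 'tanh'),
    tf.keras.layers.Dense(1)
])
x = tf.constant([[0.1], [0.2], [0.3]])  # three sample input points
with tf.GradientTape() as tp:
    tp.watch(x)        # x is a constant tensor, so tell the tape to track it
    u = model(x)       # u(x), as given by the network
du_dx = tp.gradient(u, x)  # du/dx at the three points, obtained via chain rules
print(du_dx.numpy())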
The remainder of the PINN approach appears to be just "gloss". This "gloss" may critically influence convergence in many cases, so may be practically vital, but it is not the focus of this page. For instance, Raissi, Perdikaris and Karniadakis use automatic differentiation to find the derivatives of the neural network with respect to its internal parameters (biases and weights), and hence use a Newton method to converge. The critical point for this page is that the loss function is unusual.
As a primer for the advection problem, consider building a neural network that finds $u = u(x)$ satisfying $\mathrm{d}u/\mathrm{d}x = f(x)$ with $u(0) = u_{0}$, for a given function $f$ and given constant $u_{0}$.
This means the neural network is integrating the function $f$. In terms of the general setup above:
- $\mathrm{d}u/\mathrm{d}x = f$ is equivalent to the PDE
- $u(0) = u_{0}$ is equivalent to the boundary and initial conditions.
This problem is different from the usual nonlinear regression using a neural network. In the usual case, the values, $u$, at the sample points, $x$, would be supplied as training data, and the loss would measure the misfit between the network output and those values. Here, only the derivative of $u$ (through $f$) and its value at a single point are prescribed.
The code to perform the integration is in integrate.py.
The loss function is a linear combination of the conditions. The differential-equation condition is implemented by the following function, where `x` are the points on the domain at which the `function_values` (the values of $f$) are specified:
def loss_de(x, function_values):
''' Returns sum_over_x(|du/dx - function_values|^2) / number_of_x_points
'''
# First, use TensorFlow automatic differentiation to evaluate du/dx, at the points x, where u is given by the NN model
with tf.GradientTape(persistent = True) as tp:
tp.watch(x)
u = model(x) # "model" is the NN's output value, given x
u_x = tp.gradient(u, x)
del tp
# The loss is just the mean-squared du/dx - function_values
return tf.reduce_mean(tf.square(u_x - function_values))
The constraint $u(0) = u_{0}$ is implemented by the following function. Here `x` are the points on the domain (recall that the first point, `x[0]`, is the boundary point $x = 0$):
def loss_bdy(x, val_at_zero):
''' Evaluate the boundary condition, ie u(0) - val_at_zero, where u is given by the NN model
'''
val = model(tf.convert_to_tensor([x[0]], dtype = tf.float32)) - val_at_zero # "model" is the NN predicted value, given x
return tf.reduce_mean(tf.square(val))
Finally, the loss used in the neural-network training process is a weighted linear combination of these. In this code snippet, `X` are the points on the domain at which the `fcn` values are specified, and `value_at_0` is the prescribed value $u_{0}$:
@tf.function # decorate for speed
def loss(ytrue, ypred):
''' The loss used by the training algorithm. Note that ytrue and ypred are not used,
but TensorFlow specifies these arguments
'''
bdy_weight = 1
de_weight = 1
return bdy_weight * loss_bdy(X, value_at_0) + de_weight * loss_de(X, fcn)
The following figure shows the result.
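Once trained, the network can also be checked against a known antiderivative. The following is only a sketch: it assumes `model` is the trained network from integrate.py, and it assumes, purely for illustration, that the integrand is $f(x) = \cos x$ on $0 \leq x \leq 1$ with $u(0) = 0$, so the exact answer is $u(x) = \sin x$:

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
# Sketch: compare the trained network with the exact integral of cos(x), which is sin(x)
x_plot = np.linspace(0.0, 1.0, 101).astype(np.float32)
u_nn = model(tf.reshape(x_plot, [-1, 1])).numpy().flatten()  # the reshape may need adjusting to the model's input shape
u_exact = np.sin(x_plot)
plt.plot(x_plot, u_exact, 'k-', label = 'exact')
plt.plot(x_plot, u_nn, 'r.', label = 'PINN')
plt.legend()
plt.show()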
The advection equation in 1 spatial dimension is
$$\frac{\partial u}{\partial t} + v \frac{\partial u}{\partial x} = 0 .$$
Here:
- $t$ is time
- $x$ is the spatial coordinate
- $u = u(t, x)$ is the thing we're trying to find. It could represent temperature of a fluid, or concentration of a pollutant, for instance
- $v$ is the advection velocity. It could be the velocity of the fluid, for instance. Assume it is independent of $x$ and $t$.
The analytical solution of the advection equation is $u(t, x) = g(x - v t)$, where $g(x) = u(0, x)$ is the initial condition: the initial profile is simply translated with velocity $v$.
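As a quick check, by the chain rule,
$$\frac{\partial}{\partial t} g(x - vt) = -v\, g'(x - vt), \qquad \frac{\partial}{\partial x} g(x - vt) = g'(x - vt),$$
so $\partial u/\partial t + v\, \partial u/\partial x = 0$, as required.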
The reason for choosing the advection equation is that its analytical solution is known and straightforward, and that the advection equation is quite tricky to solve using standard numerical techniques. In naive schemes, the numerical solution suffers from over-shoots and under-shoots. For instance, if the initial condition satisfies $0 \leq u \leq 1$, then at later times the numerical solution breaks these bounds (producing, for example, $u > 1$ or $u < 0$ at some points).
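As a brief stand-alone illustration (not part of the repository), here is the classic Lax-Wendroff finite-difference scheme applied to a sharp front. Although the scheme is stable, the computed solution typically violates the bounds $0 \leq u \leq 1$ near the front:

import numpy as np
# Lax-Wendroff scheme for u_t + v u_x = 0, starting from a sharp front with 0 <= u <= 1
v, dx, dt = 1.0, 0.01, 0.005
C = v * dt / dx                      # Courant number = 0.5, so the scheme is stable
x = np.arange(-1.0, 1.0, dx)
u = np.where(x < -0.5, 1.0, 0.0)     # sharp front: u = 1 on the left, u = 0 on the right
for step in range(100):
    un = u.copy()
    u[1:-1] = (un[1:-1]
               - 0.5 * C * (un[2:] - un[:-2])
               + 0.5 * C**2 * (un[2:] - 2.0 * un[1:-1] + un[:-2]))
    # boundary values are held fixed (u = 1 on the left, u = 0 on the right)
print("min u =", u.min(), "max u =", u.max())  # the bounds 0 <= u <= 1 are typically broken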
Assume the spatial domain is bounded: $-1 \leq x \leq 1$.
Here the initial condition, $u(0, x)$, is a cubic smoothed step (the `front` function used in the code below) that decreases from $u = 1$ near the left boundary to $u = 0$.
Assume the boundary conditions are Dirichlet (fixed value) on the left: $u(t, -1) = 1$;
and Neumann (fixed flux, which is zero in this case) on the right: $\partial u/\partial x|_{x = 1} = 0$.
Assume the velocity, $v$, is constant and positive.
The PINN must have two inputs, $t$ and $x$, and one output, $u$. The network is built using Keras:
depth = 5 # depth of NN
width = 10 # width of fully-connected NN
activation = 'relu' # alternatives 'selu', 'softplus', 'sigmoid', 'tanh', 'elu', 'relu'
epochs = 1000 # training epochs
batch_size = 1000 # batch size
model = Sequential()
model.add(Dense(width, kernel_initializer = 'normal', activation = activation, input_shape = (2, )))
for d in range(1, depth):
model.add(Dense(width, kernel_initializer = 'normal', activation = activation))
model.add(Dense(1, kernel_initializer = 'normal', activation = activation))
In order to evaluate the loss function, points on the space-time boundary as well as in its interior need to be defined. The initial conditions, boundary conditions and the advection equation are evaluated (using the neural network) at these points. In the following code snippet, the `front` function is the cubic defined above.
num_initial = 10000
num_dirichlet = 101
num_neumann = 102
num_interior = 10000
T_initial = tf.constant([0] * num_initial, dtype = tf.float32)
X_initial = tf.random.uniform(shape = [num_initial], minval = -1, maxval = 1, dtype = tf.float32)
vals_initial = front(X_initial)
T_dirichlet = tf.random.uniform(shape = [num_dirichlet], minval = 0, maxval = 1, dtype = tf.float32)
X_dirichlet = tf.constant([-1] * num_dirichlet, dtype = tf.float32)
vals_dirichlet = tf.constant([1] * num_dirichlet, dtype = tf.float32)
T_neumann = tf.random.uniform(shape = [num_neumann], minval = 0, maxval = 1, dtype = tf.float32)
X_neumann = tf.constant([1] * num_neumann, dtype = tf.float32)
flux_neumann = tf.constant([0] * num_neumann, dtype = tf.float32)
T_interior = tf.random.uniform(shape = [num_interior], minval = 0, maxval = 1, dtype = tf.float32)
X_interior = tf.random.uniform(shape = [num_interior], minval = -1, maxval = 1, dtype = tf.float32)
The loss function is a weighted sum of four terms. The first two are the initial condition (evaluated at `(T_initial, X_initial)`) and the Dirichlet boundary condition (evaluated at `(T_dirichlet, X_dirichlet)`), as implemented by:
def loss_dirichlet(t, x, u_desired):
''' Evaluate the initial condition or Dirichlet boundary condition (both are "fixed u" conditions), ie
sum_over(t, x)(|u - u_desired|^2) / number_of_(t, x)_points, where u is given by the NN model
'''
u_vals = tf.reshape(model(tf.stack([t, x], 1)), [len(t)]) # "model" is the NN predicted value, given (t, x)
return tf.reduce_mean(tf.square(u_vals - u_desired))
The third is the Neumann condition (evaluated at `(T_neumann, X_neumann)`):
def loss_neumann(t, x, flux_desired):
''' Evaluate the Neumann boundary condition, ie
sum_over(t, x)(|du/dx - flux_desired|^2) / number_of_(t, x)_points, where u is given by the NN model
'''
# First, use TensorFlow automatic differentiation to evaluate du/dx, at the points (t, x), where u is given by the NN model
with tf.GradientTape(persistent = True) as tp:
tp.watch(x)
u = model(tf.stack([t, x], 1)) # "model" is the NN predicted value, u, given (t, x)
u_x = tp.gradient(u, x)
del tp
# Now return the loss
return tf.reduce_mean(tf.square(u_x - flux_desired))
The fourth term is the advection equation (evaluated at `(T_interior, X_interior)`):
def loss_de(t, x):
''' Returns sum_over_(t, x)(|du/dt + velocity * du/dx|^2) / number_of_(t, x)_points
'''
# First, use TensorFlow automatic differentiation to evaluate du/dt and du/dx, at the points (t, x), where u is given by the NN model
with tf.GradientTape(persistent = True) as tp:
tp.watch(t)
tp.watch(x)
u = model(tf.stack([t, x], 1)) # "model" is the NN predicted value, u, given (t, x)
u_t = tp.gradient(u, t)
u_x = tp.gradient(u, x)
del tp
# The loss is just the mean-squared du/dt + velocity * du/dx
return tf.reduce_mean(tf.square(u_t + velocity * u_x))
These four terms contribute to the overall loss function:
@tf.function # decorate for speed
def loss(ytrue, ypred):
''' The loss used by the training algorithm. Note that ytrue and ypred are not used,
but TensorFlow specifies these arguments
'''
weight_initial = 1
weight_dirichlet = 1
weight_neumann = 1
weight_de = 1
return weight_initial * loss_dirichlet(T_initial, X_initial, vals_initial) + weight_dirichlet * loss_dirichlet(T_dirichlet, X_dirichlet, vals_dirichlet) + weight_neumann * loss_neumann(T_neumann, X_neumann, flux_neumann) + weight_de * loss_de(T_interior, X_interior)
Now it is simply a matter of compiling the neural network with this loss function, and training in the usual way:
model.compile(loss = loss, optimizer = 'adam') # note the specification of the loss function
# Note that the dummy values are not used in the training because the loss function uses different arguments.
dummy_tx = [[0, 0]]
dummy_true = [0]
history = model.fit(dummy_tx, dummy_true, epochs = epochs, batch_size = batch_size, verbose = 1)
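Once trained, the PINN can be compared with the analytical solution $u(t, x) = \mathrm{front}(x - v t)$. The following is only a sketch (it assumes `model`, `front` and `velocity` are defined as in advection.py; the plotting details are illustrative):

import numpy as np
import matplotlib.pyplot as plt
# Compare the trained PINN with the analytical solution at a fixed time
t_plot = 0.5
x_plot = np.linspace(-1.0, 1.0, 201).astype(np.float32)
tx = np.stack([np.full_like(x_plot, t_plot), x_plot], axis = 1)  # (t, x) pairs
u_pinn = model(tx).numpy().flatten()
u_true = front(x_plot - velocity * t_plot).numpy()
plt.plot(x_plot, u_true, 'k-', label = 'analytic')
plt.plot(x_plot, u_pinn, 'r.', label = 'PINN')
plt.legend()
plt.show()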
The results are pleasing!
- Obviously, the boundary conditions need not be at the boundary, and these could be experimental measurements of some quantity. This is related to parameter identification
- Is the PINN approach any better (computationally cheaper, more accurate, more robust) than: (a) finite-element; (b) using something else like a bunch of radial basis functions?
- In some cases, an explicit time-stepping approach is advantageous to ensure robustness. Would the PINN approach that is solving "everything, everywhere, all at once" actually work? This is related to discrete time integration
- Could adaptive positioning of the points reduce computational time and/or increase robustness and/or increase accuracy?
- How is something like mass conservation implemented?
- How about known fluxes (from an injector inside the domain, or from the boundary) - can randomly-placed points really ensure the correct total flux?
- Could PINNs be re-jigged to do explicit time-stepping? Yes!
When the DE contains unknown parameters, $\lambda$, the "inverse" problem is to determine those parameters from observations of the system. Two examples:
- In groundwater models, the speed of groundwater flow is governed by a parameter called "hydraulic conductivity" ($\lambda$ is hydraulic conductivity). For instance, a blob of sandstone usually has high conductivity, while a blob of clay usually has low conductivity. Water can pass rapidly through the sand, but moves only very slowly through the clay. Imagine the situation where groundwater head ($u$) has been measured at various times and places throughout a body of rock that contains a patchwork of sandstone and clay blobs. The inverse problem is to use those measurements to infer the spatial distribution of hydraulic conductivity, that is, to determine where the sandstone and clay blobs are.
- Populations of mosquitoes are often prescribed a "carrying capacity" ($\lambda$ is carrying capacity), which is the number of mosquitoes when things are in equilibrium (eg, after many generations of breeding and dying). If the carrying capacity is large, then the mosquito population can grow large, while if the carrying capacity is small, the mosquito population is limited to a small number. Carrying capacity varies in time. For instance, in wet summers in Africa, the carrying capacity is large, while in dry winters it is low. Carrying capacity also varies in space. For instance, near to population centres with plentiful standing water, the carrying capacity is large, while in the desert, it is small. The population is dynamic and rarely (or never) achieves its theoretical equilibrium at the carrying capacity. Imagine the situation where mosquito numbers ($u$) have been measured at various times and places throughout a landscape. The inverse problem is to find the carrying capacity as a function of time and space given those observations.
In both these problems, the underlying DE is assumed to be known, but it contains one or more unknown parameters, $\lambda$, that must be inferred from the observations of $u$.
To solve the inverse problem, further terms are added to the loss function.
These penalise the mismatch between the neural network's prediction of $u$ and the observed values of $u$ at the observation points.
The neural network is trained by finding the optimal parameters in this problem, which are the neural network weights and biases and the unknown DE parameters, $\lambda$ (the advection velocity in the example below).
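Schematically, and in my notation rather than necessarily the repository's, the augmented loss has the form
$$\mathrm{Loss}(\theta, \lambda) = w_{\mathrm{obs}}\, \frac{1}{N_{\mathrm{obs}}} \sum_{i=1}^{N_{\mathrm{obs}}} \bigl| u_{\mathrm{NN}}(t_{i}, x_{i}; \theta) - u_{i}^{\mathrm{obs}} \bigr|^{2} + w_{\mathrm{de}}\, \frac{1}{N_{\mathrm{de}}} \sum_{j=1}^{N_{\mathrm{de}}} \bigl| \mathcal{R}[u_{\mathrm{NN}}; \lambda](t_{j}, x_{j}) \bigr|^{2} + (\text{initial and boundary terms}) ,$$
where $\theta$ are the network weights and biases, $\lambda$ are the unknown DE parameters, and $\mathcal{R}$ is the DE residual (for advection, $\mathcal{R}[u; v] = \partial u/\partial t + v\, \partial u/\partial x$).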
In summary, the procedure described below will:
- find the unknown advection velocity, given some observational data
- solve the advection equation to find $u = u(t, x)$.
Because the solution of the advection equation is known, as described above, observational data can be easily generated. The code is observations.py. No noise is added in this example:
velocity = 1.2 # the true velocity, that will be discovered, hopefully, by the inverse process
def front(x):
    # cubic smoothed step: 1 for x <= -1, 0 for x >= -1 + 2 * window, smooth cubic in between
    window = 0.2
    s = x + 1 - window
    cubic = 0.5 + s * (tf.pow(s, 2) - 3 * tf.pow(window, 2)) / (4 * tf.pow(window, 3))
    return tf.where(x <= -1, 1, tf.where(x >= -1 + 2 * window, 0, cubic))
num_points = 10000
X = tf.random.uniform(shape = [num_points], minval = -1, maxval = 1, dtype = tf.float32).numpy()
T = tf.random.uniform(shape = [num_points], minval = 0, maxval = 1, dtype = tf.float32).numpy()
u = front(X - velocity * T).numpy() # the solution to the advection equation
with open("observations.csv", "w") as f:
f.write("#true velocity = " + str(velocity) + "\n")
f.write("T,X,u\n")
for pt in range(num_points):
f.write(str(T[pt]) + "," + str(X[pt]) + "," + str(u[pt]) + "\n")
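In equation form (a direct transcription of the code, with window $w = 0.2$ and $s = x + 1 - w$), the initial profile is
$$\mathrm{front}(x) = \begin{cases} 1, & x \leq -1, \\ \tfrac{1}{2} + \dfrac{s\,(s^{2} - 3 w^{2})}{4 w^{3}}, & -1 < x < -1 + 2w, \\ 0, & x \geq -1 + 2w, \end{cases}$$
which decreases smoothly from 1 to 0 over $-1 \leq x \leq -0.6$.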
The code may be found in inverse.py. It is very similar to advection.py, but to minimise the loss, the derivatives with respect to the unknown velocity are also needed, so that gradient descent can adjust the velocity (as well as the neural-net weights and biases) to minimise the loss. This necessitates explicitly defining the weights and biases, the neural-network architecture, and the gradient descent itself. Hence, things like
model = Sequential()
model.add(Dense(...))
model.fit(...)
expand to many lines of code.
First, `velocity` is now an unknown variable, so it is defined and initialised using
velocity = tf.Variable(0.5, dtype = tf.float32) # advection velocity with initial guess = 0.5
The observations are read from file:
obs = pd.read_csv("observations.csv", comment = "#")
T_obs = tf.constant(obs['T'], dtype = tf.float32)
X_obs = tf.constant(obs['X'], dtype = tf.float32)
vals_obs = tf.constant(obs['u'], dtype = tf.float32)
The various loss functions are the same as for the advection.py case, with `loss_dirichlet` being used to ensure the neural network output matches the observations:
def loss_dirichlet(t, x, u_desired):
''' Evaluate the observations or Dirichlet boundary condition (both are "fixed u" conditions), ie
sum_over(t, x)(|u - u_desired|^2) / number_of_(t, x)_points, where u is given by the NN model
'''
u_vals = tf.reshape(model(tf.stack([t, x], 1)), [len(t)]) # "model" is the NN predicted value, given (t, x)
return tf.reduce_mean(tf.square(u_vals - u_desired))
The following rather large block of code defines the weights and biases as a whole lot of `tf.Variable`s. This is in contrast to advection.py, where their definition was hidden inside the lines `model = Sequential()` and `model.add(Dense(...))`. The weights are Glorot initialised and the biases are initialised to zero. The reason the weights and biases are explicitly defined is that later they'll appear in the explicit definition of the neural net. They are also collected into the `params` list, which will appear in the gradient-descent step below.
depth = 5 # depth of NN
width = 10 # width of fully-connected NN
def glorot_init_weight(in_dim, out_dim):
return tf.Variable(tf.random.truncated_normal(shape = [in_dim, out_dim], mean = 0.0, stddev = np.sqrt(2.0 / (in_dim + out_dim)), dtype = tf.float32), dtype = tf.float32)
def zero_init_bias(in_dim, out_dim):
return tf.Variable(tf.zeros(shape = [in_dim, out_dim], dtype = tf.float32), dtype = tf.float32)
weights = [] # all the weights
biases = [] # all the biases
params = [] # all the weights and biases AND THE VELOCITY!
# 2 inputs = (t, x)
w = glorot_init_weight(2, width)
b = zero_init_bias(1, width)
weights.append(w)
biases.append(b)
params.append(w)
params.append(b)
for d in range(1, depth):
w = glorot_init_weight(width, width)
b = zero_init_bias(1, width)
weights.append(w)
biases.append(b)
params.append(w)
params.append(b)
# 1 output = u
w = glorot_init_weight(width, 1)
b = zero_init_bias(1, 1)
weights.append(w)
biases.append(b)
params.append(w)
params.append(b)
Now for a very important line, which is the whole reason for defining everything so explicitly:
params.append(velocity)
This means that the automatic differentiation used in the gradient descent (below) computes the derivative of the loss with respect to the velocity as well as the weights and biases, so the optimizer updates all of them together.
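As a toy stand-alone illustration (not from inverse.py; all names here are invented), including an extra tf.Variable in the parameter list is all that is needed for the optimizer to update it alongside the weights:

import tensorflow as tf
w = tf.Variable(0.3)            # stands in for a network weight
phys_param = tf.Variable(0.5)   # stands in for the unknown PDE parameter
params = [w, phys_param]
optimizer = tf.keras.optimizers.Adam()
with tf.GradientTape() as tp:
    loss_value = tf.square(w - 1.0) + tf.square(phys_param - 1.2)  # a pretend loss
grads = tp.gradient(loss_value, params)        # gradients with respect to BOTH variables
optimizer.apply_gradients(zip(grads, params))  # both w and phys_param are nudged to reduce the loss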
The neural network can now be defined. It explicitly uses the weights and biases that were just defined, so that the derivatives of its output (and hence of the loss) with respect to every entry in `params` can be computed:
def model(x):
# This uses the relu activation function. relu(x) = max(x, 0)
# Some other alternatives like tanh, elu and softplus could be used (see the commented-out lines below)
z = x
for d in range(depth):
w = weights[d]
b = biases[d]
zp = tf.add(tf.matmul(z, w), b)
#z = tf.math.log(1 + tf.math.exp(zp)) # softplus activation
#z = tf.math.tanh(zp) # tanh activation
#z = tf.where(zp > 0, zp, tf.math.exp(zp) - 1) # elu activation
z = tf.math.maximum(zp, 0) # relu activation
w = weights[depth]
b = biases[depth]
zp = tf.add(tf.matmul(z, w), b)
return tf.math.maximum(zp, 0) # might like to change this activation function if the hidden layers are also changed
The following piece of code does one step of the gradient descent. Note the appearance of `params`, which includes the velocity, in the calls to `tp.gradient` and `optimizer.apply_gradients`:
optimizer = tf.keras.optimizers.Adam()
@tf.function # decorate for speed
def gradient_descent():
with tf.GradientTape(persistent = True) as tp:
epoch_loss = loss()
gradient = tp.gradient(epoch_loss, params) # d(loss)/d(params)
del tp
# because params includes weights, biases and velocity, the following line
# alters all of these to reduce the loss
optimizer.apply_gradients(zip(gradient, params))
return epoch_loss
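The zero-argument `loss()` called inside `gradient_descent()` is not reproduced here. A plausible sketch, assuming (and this is an assumption) that inverse.py defines the same collocation points and loss terms as advection.py, plus the observation term built from `T_obs`, `X_obs` and `vals_obs`, is:

def loss():
    ''' Sketch only: observations + initial condition + boundary conditions + DE residual.
        The weights are illustrative. '''
    weight_obs = 1
    weight_initial = 1
    weight_dirichlet = 1
    weight_neumann = 1
    weight_de = 1
    return (weight_obs * loss_dirichlet(T_obs, X_obs, vals_obs)
            + weight_initial * loss_dirichlet(T_initial, X_initial, vals_initial)
            + weight_dirichlet * loss_dirichlet(T_dirichlet, X_dirichlet, vals_dirichlet)
            + weight_neumann * loss_neumann(T_neumann, X_neumann, flux_neumann)
            + weight_de * loss_de(T_interior, X_interior))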
The gradient descent is now simple:
epochs = 1000 # training epochs
for epoch in range(epochs):
epoch_loss = gradient_descent()
print("epoch =", epoch, "loss =", epoch_loss.numpy(), "velocity =", velocity.numpy())
After 1000 training epochs, the algorithm has predicted that velocity = 1.19 (the true value is 1.2) and it has produced a neural network that predicts $u = u(t, x)$ over the whole space-time domain.
Another set of observational data is found in challenge.csv. In this set, there are 2 unknown advection velocities:
- the first velocity applies when $t < 0.5$
- the second velocity applies when $t \geq 0.5$
Can you find these velocities?
You might like to approach this challenge by modifying inverse.py. First, try to run `inverse.py`: if you can't, then check the required packages. The code inverse.py can be modified in the following ways:
- Define a new `tf.Variable` called `velocity_end` that will be active for $t \geq 0.5$. That is, add another line similar to `velocity = tf.Variable(0.5, dtype = tf.float32)`.
- Instead of reading the observations from `observations.csv`, use `challenge.csv` instead. That is, change the line `obs = pd.read_csv("observations.csv", comment = "#")`.
- Change the PDE, specifically the PDE loss, to include `velocity_end` instead of just `velocity`. That is, change the snippet `u_t + velocity * u_x` by using `tf.where(t < 0.5, ...)` instead of just `velocity` (see the sketch after this list).
- Append `velocity_end` to the `params` list in a similar way to `params.append(velocity)`.
- At the end of each training iteration, print both `velocity` and `velocity_end`. That is, add to the line `print("epoch =", epoch, "loss =", epoch_loss.numpy()...`.
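As a sketch of the third step (hypothetical code, assuming the same structure as the loss_de in advection.py and that velocity_end has been defined):

def loss_de(t, x):
    ''' Sketch: DE loss with a time-dependent advection velocity, which is "velocity"
        for t < 0.5 and "velocity_end" for t >= 0.5 (both tf.Variables) '''
    with tf.GradientTape(persistent = True) as tp:
        tp.watch(t)
        tp.watch(x)
        u = model(tf.stack([t, x], 1))
    u_t = tp.gradient(u, t)
    u_x = tp.gradient(u, x)
    del tp
    v = tf.where(t < 0.5, velocity, velocity_end)  # time-dependent velocity
    return tf.reduce_mean(tf.square(u_t + v * u_x))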
Now you should be able to run the code. You'll probably find pretty poor results: look at the loss, particularly the observation loss, to see that the algorithm hasn't converged particularly well. Remember, the animation that is produced shows the PINN result correctly, but the "True" result (black curve) is from the previous case with just one velocity, so can be ignored. To improve things, you might like to:
- try more training epochs, for instance `epochs = 10000`. Oh dear, it seems as if the convergence stalls at `loss = 0.05`
- try a different activation function, for instance `tanh` (see the `model(x)` function). Aha, after 10000 epochs, the loss has reduced to 0.003 when using `tanh`, which is pleasing.
- try weighting the terms in the loss function differently (eg `weight_de`)
- try changing the parameters within the optimizer (the `Adam` parameters)
- try using a different weight and bias initialisation
- experiment with your own ideas!
What are your best values for `velocity` and `velocity_end`? Hint: the true value of `velocity` is positive, while `velocity_end` is negative.
Suppose that the dynamics of a system are described by the differential equation $\partial u/\partial t + \mathcal{N}[u] = 0$.
Here $u = u(t, x)$ and $\mathcal{N}$ is a (possibly nonlinear) operator involving spatial derivatives of $u$.
The usual fourth-order explicit Runge-Kutta scheme for integrating this is a particular case of the general $q$-stage Runge-Kutta method. Given $u^{n}$, the solution at time $t^{n}$, the scheme constructs $q$ intermediate "stage" values, each involving $\mathcal{N}$ evaluated at weighted combinations of the stages.
Then $u^{n+1}$, the solution at time $t^{n} + \Delta t$, is $u^{n}$ plus a weighted combination of those evaluations.
In these expressions, the coefficients of the weighted combinations are fixed numbers (the Butcher tableau) that define the particular Runge-Kutta scheme.
Here $\Delta t$ is the time step.
The simple, yet revolutionary, idea of Raissi, Perdikaris and Karniadakis is to build a neural network that has input $x$ (the spatial position) and $q + 1$ outputs: the $q$ Runge-Kutta stage values and the solution at the new time step. The loss function penalises any violation of the Runge-Kutta relations, given the known solution at the current time step; a sketch is given below.
In that loss, the spatio-temporal dependencies have been mostly suppressed for clarity of exposition, and the "mean" runs over all spatial points, which are chosen randomly within the domain, as in the advection-equation example. The "bdy" term in the loss comes from the boundary conditions. Because all of the outputs are produced by a single neural network, the stage values are obtained together at little extra cost.
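A sketch of the scheme, in my notation (which may not match the repository's): the network outputs the $q$ stage values $u^{n+c_{1}}(x), \ldots, u^{n+c_{q}}(x)$ together with $u^{n+1}(x)$, and each Runge-Kutta relation is rearranged so that it "predicts" the known solution $u^{n}$,
$$u^{n} \approx u^{n+c_{i}} + \Delta t \sum_{j=1}^{q} a_{ij}\, \mathcal{N}\bigl[u^{n+c_{j}}\bigr] \quad (i = 1, \ldots, q), \qquad u^{n} \approx u^{n+1} + \Delta t \sum_{j=1}^{q} b_{j}\, \mathcal{N}\bigl[u^{n+c_{j}}\bigr],$$
so the loss is
$$\mathrm{Loss} = \sum_{i=1}^{q} \operatorname{mean}_{x} \Bigl| u^{n} - u^{n+c_{i}} - \Delta t \sum_{j} a_{ij}\, \mathcal{N}\bigl[u^{n+c_{j}}\bigr] \Bigr|^{2} + \operatorname{mean}_{x} \Bigl| u^{n} - u^{n+1} - \Delta t \sum_{j} b_{j}\, \mathcal{N}\bigl[u^{n+c_{j}}\bigr] \Bigr|^{2} + \mathrm{bdy} .$$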
The rather dramatic consequence of this scheme is that the number of Runge-Kutta stages, $q$, can be made very large at almost no additional computational cost, so a single, very large time step can be taken while remaining stable and accurate.