Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PaddlePaddle Hackathon 2] 94. Add Paddle as a new backend of DeepXDE #562

Merged
merged 53 commits into from
May 20, 2022
Merged
Show file tree
Hide file tree
Changes from 49 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
d710791
add paddlepaddle support
AndPuQing Mar 8, 2022
38fc0a0
add paddle support
AndPuQing Mar 8, 2022
dbcf6f4
fix optimizer error
AndPuQing Mar 9, 2022
9d04820
add pinn ode_system support
AndPuQing Mar 9, 2022
b75d8c0
fix init error
AndPuQing Mar 9, 2022
e2abd4b
fix training error
AndPuQing Mar 9, 2022
03f881a
fix grad error
AndPuQing Mar 9, 2022
f20369c
fix ndim error
AndPuQing Mar 10, 2022
9f844f3
test
AndPuQing Mar 10, 2022
6cdeeb3
add save and load support
AndPuQing Mar 10, 2022
cfa4750
fix formatting
AndPuQing Mar 10, 2022
d4f1907
add docs
AndPuQing Mar 10, 2022
fa0c3a2
Merge branch 'lululxvi:master' into master
Mar 10, 2022
e517370
fix write mistake
AndPuQing Mar 10, 2022
77ba91f
Merge branch 'master' of github.com:AndPuQing/deepxde
AndPuQing Mar 10, 2022
1f00abb
change paddlepaddle to paddle
AndPuQing Mar 11, 2022
c249502
add paddle navigation
AndPuQing Mar 11, 2022
49d267d
merge
AndPuQing Mar 12, 2022
7dcdc17
Merge branch 'lululxvi-master'
AndPuQing Mar 12, 2022
e1a8b9f
merge
AndPuQing Mar 14, 2022
f7f9d70
add aux support
AndPuQing Mar 14, 2022
080ca76
Merge branch 'master'
AndPuQing Mar 14, 2022
faed64e
Merge branch 'lululxvi-master'
AndPuQing Mar 14, 2022
55c9033
fix compile_paddle error
AndPuQing Mar 15, 2022
c1be75d
Merge branch 'master
AndPuQing Apr 3, 2022
5477584
Merge branch 'lululxvi-master'
AndPuQing Apr 3, 2022
2f1b0a9
Merge branch 'lululxvi:master' into master
AndPuQing Apr 29, 2022
49b72ef
fix grad error
AndPuQing May 10, 2022
b165f99
Merge branch 'master' of github.com:AndPuQing/deepxde
AndPuQing May 10, 2022
ea4a8ef
Merge branch 'master'
AndPuQing May 10, 2022
631dc6f
Merge branch 'lululxvi-master'
AndPuQing May 10, 2022
26173ca
fix contradiction
AndPuQing May 10, 2022
04f2949
fix merge error
AndPuQing May 11, 2022
b28e052
Merge branch 'master'
AndPuQing May 12, 2022
678674b
Merge branch 'lululxvi-master'
AndPuQing May 12, 2022
285e102
fix the specification
AndPuQing May 12, 2022
4a2c470
Merge branch 'lululxvi:master' into master
AndPuQing May 12, 2022
26f2195
fix grad and pde update
AndPuQing May 12, 2022
29f65a8
Merge branch 'master' of github.com:AndPuQing/deepxde
AndPuQing May 12, 2022
15000c4
fix train_step error
AndPuQing May 12, 2022
c944f2d
fix specification problem
AndPuQing May 12, 2022
6dabe4e
cast int
AndPuQing May 16, 2022
6deb3d2
Merge branch 'master'
AndPuQing May 16, 2022
a2507f9
Merge branch 'lululxvi-master'
AndPuQing May 16, 2022
665930d
fix paddle ndim and shape return
AndPuQing May 16, 2022
fb441ea
fix paddle develop require
AndPuQing May 16, 2022
3615f8a
Merge branch 'lululxvi:master' into master
AndPuQing May 17, 2022
ef1e499
update
AndPuQing May 17, 2022
c1486ad
remove .numpy
AndPuQing May 17, 2022
546f316
update
AndPuQing May 17, 2022
dc86a59
Merge branch 'lululxvi:master' into master
AndPuQing May 18, 2022
5ea5ecd
fix format
AndPuQing May 18, 2022
74cc4cc
paddle.autograd.grad -> paddle.grad
AndPuQing May 19, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ docs/_build/
.ipynb_checkpoints

# VSCode
.vscode/
.vscode/
4 changes: 2 additions & 2 deletions deepxde/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _missing_api(*args, **kwargs):


def load_backend(mod_name):
if mod_name not in ["tensorflow.compat.v1", "tensorflow", "pytorch", "jax"]:
if mod_name not in ["tensorflow.compat.v1", "tensorflow", "pytorch", "jax", "paddle"]:
raise NotImplementedError("Unsupported backend: %s" % mod_name)

print("Using backend: %s\n" % mod_name, file=sys.stderr, flush=True)
Expand Down Expand Up @@ -72,7 +72,7 @@ def get_preferred_backend():
config_dict = json.load(config_file)
backend_name = config_dict.get("backend", "").lower()

if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch", "jax"]:
if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch", "jax", "paddle"]:
return backend_name
print(
"Deepxde backend not selected or invalid. Assuming tensorflow.compat.v1 for now.",
Expand Down
1 change: 1 addition & 0 deletions deepxde/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
tf = None
torch = None
jax = None
paddle = None
AndPuQing marked this conversation as resolved.
Show resolved Hide resolved

###############################################################################
# Tensor, data type and context interfaces
Expand Down
1 change: 1 addition & 0 deletions deepxde/backend/paddle/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .tensor import * # pylint: disable=redefined-builtin
110 changes: 110 additions & 0 deletions deepxde/backend/paddle/tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""paddle backend implementation"""
import paddle


if paddle.device.is_compiled_with_cuda():
paddle.device.set_device("gpu")

lib = paddle


def data_type_dict():
return {
"float16": paddle.float16,
"float32": paddle.float32,
"float64": paddle.float64,
"uint8": paddle.uint8,
"int8": paddle.int8,
"int16": paddle.int16,
"int32": paddle.int32,
"int64": paddle.int64,
"bool": paddle.bool,
}


def is_tensor(obj):
return paddle.is_tensor(obj)


def shape(input_tensor):
return input_tensor.shape


def ndim(input_tensor):
return input_tensor.ndim


def Variable(initial_value, dtype=None):
return paddle.to_tensor(initial_value, dtype=dtype, stop_gradient=False)


def as_tensor(data, dtype=None):
if isinstance(data, paddle.Tensor) or paddle.is_tensor(data):
AndPuQing marked this conversation as resolved.
Show resolved Hide resolved
if dtype is None or data.dtype == dtype:
return data
return data.astype(dtype)
return paddle.to_tensor(data, dtype=dtype)


def from_numpy(np_array):
return paddle.to_tensor(np_array)


def to_numpy(input_tensor):
return input_tensor.detach().cpu().numpy()


def elu(x):
return paddle.nn.functional.elu(x)


def relu(x):
return paddle.nn.functional.relu(x)


def selu(x):
return paddle.nn.functional.selu(x)


def sigmoid(x):
return paddle.nn.functional.sigmoid(x)


def silu(x):
return paddle.nn.functional.silu(x)


def sin(x):
return paddle.sin(x)


def square(x):
return paddle.square(x)


def tanh(x):
return paddle.tanh(x)


def mean(input_tensor, dim, keepdims=False):
return paddle.mean(input_tensor, axis=dim, keepdim=keepdims)


def reduce_mean(input_tensor):
return paddle.mean(input_tensor)


def sum(input_tensor, dim, keepdims=False):
return paddle.sum(input_tensor, axis=dim, keepdim=keepdims)


def reduce_sum(input_tensor):
return paddle.sum(input_tensor)


def zeros(shape, dtype):
return paddle.zeros(shape, dtype=dtype)


def zeros_like(input_tensor):
return paddle.zeros_like(input_tensor)
4 changes: 2 additions & 2 deletions deepxde/backend/set_default_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def set_default_backend(backend_name):
print(
'Setting the default backend to "{}". You can change it in the '
"~/.deepxde/config.json file or export the DDEBACKEND environment variable. "
"Valid options are: tensorflow.compat.v1, tensorflow, pytorch, jax (all lowercase)".format(
"Valid options are: tensorflow.compat.v1, tensorflow, pytorch, jax, paddle (all lowercase)".format(
backend_name
)
)
Expand All @@ -25,7 +25,7 @@ def set_default_backend(backend_name):
"backend",
nargs=1,
type=str,
choices=["tensorflow.compat.v1", "tensorflow", "pytorch", "jax"],
choices=["tensorflow.compat.v1", "tensorflow", "pytorch", "jax", "paddle"],
help="Set default backend",
)
args = parser.parse_args()
Expand Down
2 changes: 2 additions & 0 deletions deepxde/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ def on_train_begin(self):
self.value = [var.numpy() for var in self.var_list]
elif backend_name == "pytorch":
self.value = [var.detach().item() for var in self.var_list]
elif backend_name == "paddle":
AndPuQing marked this conversation as resolved.
Show resolved Hide resolved
self.value = [var.detach().item() for var in self.var_list]
print(
self.model.train_state.epoch,
list_to_str(self.value, precision=self.precision),
Expand Down
4 changes: 3 additions & 1 deletion deepxde/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np

from .backend import backend_name, tf, torch
from .backend import backend_name, tf, torch, paddle
from .real import Real

random_seed = None
Expand Down Expand Up @@ -57,6 +57,8 @@ def set_random_seed(seed):
elif backend_name == "tensorflow":
tf.random.set_seed(seed) # tf CPU seed
os.environ["TF_DETERMINISTIC_OPS"] = "1"
elif backend_name == "paddle":
AndPuQing marked this conversation as resolved.
Show resolved Hide resolved
AndPuQing marked this conversation as resolved.
Show resolved Hide resolved
paddle.seed(seed)
elif backend_name == "pytorch":
torch.manual_seed(seed)
elif backend_name == "jax":
Expand Down
3 changes: 2 additions & 1 deletion deepxde/data/pde.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __init__(
self.test()

def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):
if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch"]:
if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch", "paddle"]:
outputs_pde = outputs
elif backend_name == "jax":
# JAX requires pure functions
Expand All @@ -152,6 +152,7 @@ def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):
)

bcs_start = np.cumsum([0] + self.num_bcs)
bcs_start = list(map(int, bcs_start))
error_f = [fi[bcs_start[-1] :] for fi in f]
losses = [
loss_fn[i](bkd.zeros_like(error), error) for i, error in enumerate(error_f)
Expand Down
15 changes: 9 additions & 6 deletions deepxde/gradients.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__all__ = ["jacobian", "hessian"]

from .backend import backend_name, tf, torch, jax
from .backend import backend_name, tf, torch, jax, paddle


class Jacobian:
Expand All @@ -18,7 +18,7 @@ def __init__(self, ys, xs):
self.ys = ys
self.xs = xs

if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch"]:
if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch", "paddle"]:
self.dim_y = ys.shape[1]
elif backend_name == "jax":
# For backend jax, a tuple of a jax array and a callable is passed as one of
Expand Down Expand Up @@ -50,6 +50,9 @@ def __call__(self, i=0, j=None):
self.J[i] = torch.autograd.grad(
y, self.xs, grad_outputs=torch.ones_like(y), create_graph=True
)[0]
elif backend_name == "paddle":
y = self.ys[:, i : i + 1] if self.dim_y > 1 else self.ys
self.J[i] = paddle.autograd.grad(y, self.xs, create_graph=True)[0]
AndPuQing marked this conversation as resolved.
Show resolved Hide resolved
elif backend_name == "jax":
# Here, we use jax.grad to compute the gradient of a function. This is
# different from TensorFlow and PyTorch that the input of a function is
Expand All @@ -67,7 +70,7 @@ def __call__(self, i=0, j=None):
grad_fn = jax.grad(lambda x: self.ys[1](x)[i])
self.J[i] = (jax.vmap(grad_fn)(self.xs), grad_fn)

if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch"]:
if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch", "paddle"]:
return (
self.J[i] if j is None or self.dim_x == 1 else self.J[i][:, j : j + 1]
)
Expand Down Expand Up @@ -141,7 +144,7 @@ def __call__(self, ys, xs, i=0, j=None):
# f(x)
if backend_name in ["tensorflow.compat.v1", "tensorflow"]:
key = (ys.ref(), xs.ref())
elif backend_name == "pytorch":
elif backend_name in ["pytorch", "paddle"]:
key = (ys, xs)
elif backend_name == "jax":
key = (id(ys[0]), id(xs))
Expand Down Expand Up @@ -197,7 +200,7 @@ class Hessian:
"""

def __init__(self, y, xs, component=None, grad_y=None):
if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch"]:
if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch", "paddle"]:
dim_y = y.shape[1]
elif backend_name == "jax":
dim_y = y[0].shape[0]
Expand Down Expand Up @@ -239,7 +242,7 @@ def __init__(self):
def __call__(self, y, xs, component=None, i=0, j=0, grad_y=None):
if backend_name in ["tensorflow.compat.v1", "tensorflow"]:
key = (y.ref(), xs.ref(), component)
elif backend_name == "pytorch":
elif backend_name in ["pytorch", "paddle"]:
key = (y, xs, component)
elif backend_name == "jax":
key = (id(y[0]), id(xs), component)
Expand Down
2 changes: 1 addition & 1 deletion deepxde/icbc/boundary_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def wrapper_cache_auxiliary(X, beg, end, aux_var):
return wrapper_nocache
if utils.get_num_args(func) == 2:
return wrapper_nocache_auxiliary
if backend_name == "pytorch":
if backend_name in ["paddle", "pytorch"]:
AndPuQing marked this conversation as resolved.
Show resolved Hide resolved
if utils.get_num_args(func) == 1:
return wrapper_cache
if utils.get_num_args(func) == 2:
Expand Down
Loading