Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove old style forms #457

Merged
merged 2 commits into from
Aug 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions docs/examples/ex22.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""
from skfem import *
from skfem.models.poisson import laplace
from skfem.helpers import grad
import numpy as np

m = MeshTri.init_lshaped()
Expand All @@ -32,27 +33,28 @@ def eval_estimator(m, u):
# interior residual
basis = InteriorBasis(m, e)

@functional
@Functional
def interior_residual(w):
h = w.h
x, y = w.x
return h**2 * load_func(x, y)**2
return h ** 2 * load_func(x, y) ** 2

eta_K = interior_residual.elemental(basis, w=basis.interpolate(u))

# facet jump
fbasis = [FacetBasis(m, e, side=i) for i in [0, 1]]
w = [fbasis[i].interpolate(u) for i in [0, 1]]
w = {'u' + str(i + 1): fbasis[i].interpolate(u) for i in [0, 1]}

@functional
@Functional
def edge_jump(w):
h = w.h
n = w.n
du1, du2 = w.dw
return h * ((du1[0] - du2[0])*n[0] +\
(du1[1] - du2[1])*n[1])**2
dw1 = grad(w['u1'])
dw2 = grad(w['u2'])
return h * ((dw1[0] - dw2[0]) * n[0] +\
(dw1[1] - dw2[1]) * n[1]) ** 2

eta_E = edge_jump.elemental(fbasis[0], w=w)
eta_E = edge_jump.elemental(fbasis[0], **w)

tmp = np.zeros(m.facets.shape[1])
np.add.at(tmp, fbasis[0].find, eta_E)
Expand Down
8 changes: 2 additions & 6 deletions skfem/assembly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@

from .basis import Basis, InteriorBasis, FacetBasis
from .dofs import Dofs, DofsView
from .form import Form, BilinearForm, LinearForm, Functional,\
bilinear_form, linear_form, functional
from .form import Form, BilinearForm, LinearForm, Functional


def asm(form: Form,
Expand All @@ -78,7 +77,4 @@ def asm(form: Form,
"DofsView",
"BilinearForm",
"LinearForm",
"Functional",
"bilinear_form",
"linear_form",
"functional"]
"Functional"]
6 changes: 3 additions & 3 deletions skfem/assembly/form/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .form import Form # noqa
from .bilinear_form import bilinear_form, BilinearForm # noqa
from .linear_form import linear_form, LinearForm # noqa
from .functional import functional, Functional # noqa
from .bilinear_form import BilinearForm # noqa
from .linear_form import LinearForm # noqa
from .functional import Functional # noqa
32 changes: 1 addition & 31 deletions skfem/assembly/form/bilinear_form.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Optional, Any
from typing import Optional, Any

import numpy as np

Expand Down Expand Up @@ -74,33 +74,3 @@ def assemble(self,

def _kernel(self, u, v, w, dx):
return np.sum(self.form(*u, *v, w) * dx, axis=1)


def bilinear_form(form: Callable) -> BilinearForm:

# for backwards compatibility
from .form_parameters import FormParameters

import warnings
warnings.warn("The old style @bilinear_form wrapper is deprecated. "
"Consider using the new style forms, defined via "
"@BilinearForm.", DeprecationWarning)

class ClassicBilinearForm(BilinearForm):

def _kernel(self, u, v, w, dx):
u = u[0]
v = v[0]
W = {k: w[k].f for k in w}
if 'w' in w:
W['dw'] = w['w'].df
if u.ddf is not None:
return np.sum(self.form(u=u.f, du=u.df, ddu=u.ddf,
v=v.f, dv=v.df, ddv=v.ddf,
w=FormParameters(**W)) * dx, axis=1)
else:
return np.sum(self.form(u=u.f, du=u.df,
v=v.f, dv=v.df,
w=FormParameters(**W)) * dx, axis=1)

return ClassicBilinearForm(form)
13 changes: 0 additions & 13 deletions skfem/assembly/form/form_parameters.py

This file was deleted.

23 changes: 1 addition & 22 deletions skfem/assembly/form/functional.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Dict
from typing import Dict

from numpy import ndarray
import numpy as np
Expand Down Expand Up @@ -31,24 +31,3 @@ def assemble(self,
v: Basis,
**kwargs) -> float:
return np.sum(self.elemental(v, **kwargs))


def functional(form: Callable) -> Functional:

# for backwards compatibility
from .form_parameters import FormParameters

import warnings
warnings.warn("The old style @functional wrapper is deprecated. "
"Consider using the new style forms, defined via "
"@Functional.", DeprecationWarning)

class ClassicFunctional(Functional):

def _kernel(self, w, dx):
W = {k: w[k].f for k in w}
if 'w' in w:
W['dw'] = w['w'].df
return np.sum(self.form(w=FormParameters(**W)) * dx, axis=1)

return ClassicFunctional(form)
29 changes: 1 addition & 28 deletions skfem/assembly/form/linear_form.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Optional
from typing import Optional

import numpy as np
from numpy import ndarray
Expand Down Expand Up @@ -43,30 +43,3 @@ def assemble(self,

def _kernel(self, v, w, dx):
return np.sum(self.form(*v, w) * dx, axis=1)


def linear_form(form: Callable) -> LinearForm:

# for backwards compatibility
from .form_parameters import FormParameters

import warnings
warnings.warn("The old style @linear_form wrapper is deprecated. "
"Consider using the new style forms, defined via "
"@LinearForm.", DeprecationWarning)

class ClassicLinearForm(LinearForm):

def _kernel(self, v, w, dx):
v = v[0]
W = {k: w[k].f for k in w}
if 'w' in w:
W['dw'] = w['w'].df
if v.ddf is not None:
return np.sum(self.form(v=v.f, dv=v.df, ddv=v.ddf,
w=FormParameters(**W)) * dx, axis=1)
else:
return np.sum(self.form(v=v.f, dv=v.df,
w=FormParameters(**W)) * dx, axis=1)

return ClassicLinearForm(form)
17 changes: 9 additions & 8 deletions tests/test_assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

import numpy as np

from skfem import (BilinearForm, LinearForm, Functional, asm, bilinear_form,
linear_form, solve, functional)
from skfem import BilinearForm, LinearForm, Functional, asm, solve
from skfem.element import (ElementQuad1, ElementQuadS2, ElementHex1,
ElementHexS2, ElementTetP0, ElementTetP1,
ElementTetP2, ElementTriP1, ElementQuad2,
Expand Down Expand Up @@ -163,16 +162,18 @@ class BasisInterpolatorQuadS2(BasisInterpolator):


class BasisInterpolatorMorley(BasisInterpolator):

case = (MeshTri, ElementTriMorley)

def initOnes(self, basis):
@bilinear_form
def mass(u, du, ddu, v, dv, ddv, w):

@BilinearForm
def mass(u, v, w):
return u * v

@linear_form
def ones(v, dv, ddv, w):
return 1.0 * v
@LinearForm
def ones(v, w):
return 1. * v

M = asm(mass, basis)
f = asm(ones, basis)
Expand Down Expand Up @@ -256,7 +257,7 @@ def runTest(self):
e = ElementQuad1()
basis = InteriorBasis(m, e)

@functional
@Functional
def x_squared(w):
return w.x[0] ** 2

Expand Down