Skip to content

Commit

Permalink
Merge pull request #457 from kinnala/remove-old-style-forms
Browse files Browse the repository at this point in the history
Remove old style forms
  • Loading branch information
kinnala authored Aug 14, 2020
2 parents 4aa87e1 + e4c521c commit 5d07cad
Show file tree
Hide file tree
Showing 8 changed files with 27 additions and 119 deletions.
18 changes: 10 additions & 8 deletions docs/examples/ex22.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""
from skfem import *
from skfem.models.poisson import laplace
from skfem.helpers import grad
import numpy as np

m = MeshTri.init_lshaped()
Expand All @@ -32,27 +33,28 @@ def eval_estimator(m, u):
# interior residual
basis = InteriorBasis(m, e)

@functional
@Functional
def interior_residual(w):
h = w.h
x, y = w.x
return h**2 * load_func(x, y)**2
return h ** 2 * load_func(x, y) ** 2

eta_K = interior_residual.elemental(basis, w=basis.interpolate(u))

# facet jump
fbasis = [FacetBasis(m, e, side=i) for i in [0, 1]]
w = [fbasis[i].interpolate(u) for i in [0, 1]]
w = {'u' + str(i + 1): fbasis[i].interpolate(u) for i in [0, 1]}

@functional
@Functional
def edge_jump(w):
h = w.h
n = w.n
du1, du2 = w.dw
return h * ((du1[0] - du2[0])*n[0] +\
(du1[1] - du2[1])*n[1])**2
dw1 = grad(w['u1'])
dw2 = grad(w['u2'])
return h * ((dw1[0] - dw2[0]) * n[0] +\
(dw1[1] - dw2[1]) * n[1]) ** 2

eta_E = edge_jump.elemental(fbasis[0], w=w)
eta_E = edge_jump.elemental(fbasis[0], **w)

tmp = np.zeros(m.facets.shape[1])
np.add.at(tmp, fbasis[0].find, eta_E)
Expand Down
8 changes: 2 additions & 6 deletions skfem/assembly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@

from .basis import Basis, InteriorBasis, FacetBasis
from .dofs import Dofs, DofsView
from .form import Form, BilinearForm, LinearForm, Functional,\
bilinear_form, linear_form, functional
from .form import Form, BilinearForm, LinearForm, Functional


def asm(form: Form,
Expand All @@ -78,7 +77,4 @@ def asm(form: Form,
"DofsView",
"BilinearForm",
"LinearForm",
"Functional",
"bilinear_form",
"linear_form",
"functional"]
"Functional"]
6 changes: 3 additions & 3 deletions skfem/assembly/form/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .form import Form # noqa
from .bilinear_form import bilinear_form, BilinearForm # noqa
from .linear_form import linear_form, LinearForm # noqa
from .functional import functional, Functional # noqa
from .bilinear_form import BilinearForm # noqa
from .linear_form import LinearForm # noqa
from .functional import Functional # noqa
32 changes: 1 addition & 31 deletions skfem/assembly/form/bilinear_form.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Optional, Any
from typing import Optional, Any

import numpy as np

Expand Down Expand Up @@ -74,33 +74,3 @@ def assemble(self,

def _kernel(self, u, v, w, dx):
return np.sum(self.form(*u, *v, w) * dx, axis=1)


def bilinear_form(form: Callable) -> BilinearForm:

# for backwards compatibility
from .form_parameters import FormParameters

import warnings
warnings.warn("The old style @bilinear_form wrapper is deprecated. "
"Consider using the new style forms, defined via "
"@BilinearForm.", DeprecationWarning)

class ClassicBilinearForm(BilinearForm):

def _kernel(self, u, v, w, dx):
u = u[0]
v = v[0]
W = {k: w[k].f for k in w}
if 'w' in w:
W['dw'] = w['w'].df
if u.ddf is not None:
return np.sum(self.form(u=u.f, du=u.df, ddu=u.ddf,
v=v.f, dv=v.df, ddv=v.ddf,
w=FormParameters(**W)) * dx, axis=1)
else:
return np.sum(self.form(u=u.f, du=u.df,
v=v.f, dv=v.df,
w=FormParameters(**W)) * dx, axis=1)

return ClassicBilinearForm(form)
13 changes: 0 additions & 13 deletions skfem/assembly/form/form_parameters.py

This file was deleted.

23 changes: 1 addition & 22 deletions skfem/assembly/form/functional.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Dict
from typing import Dict

from numpy import ndarray
import numpy as np
Expand Down Expand Up @@ -31,24 +31,3 @@ def assemble(self,
v: Basis,
**kwargs) -> float:
return np.sum(self.elemental(v, **kwargs))


def functional(form: Callable) -> Functional:

# for backwards compatibility
from .form_parameters import FormParameters

import warnings
warnings.warn("The old style @functional wrapper is deprecated. "
"Consider using the new style forms, defined via "
"@Functional.", DeprecationWarning)

class ClassicFunctional(Functional):

def _kernel(self, w, dx):
W = {k: w[k].f for k in w}
if 'w' in w:
W['dw'] = w['w'].df
return np.sum(self.form(w=FormParameters(**W)) * dx, axis=1)

return ClassicFunctional(form)
29 changes: 1 addition & 28 deletions skfem/assembly/form/linear_form.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Optional
from typing import Optional

import numpy as np
from numpy import ndarray
Expand Down Expand Up @@ -43,30 +43,3 @@ def assemble(self,

def _kernel(self, v, w, dx):
return np.sum(self.form(*v, w) * dx, axis=1)


def linear_form(form: Callable) -> LinearForm:

# for backwards compatibility
from .form_parameters import FormParameters

import warnings
warnings.warn("The old style @linear_form wrapper is deprecated. "
"Consider using the new style forms, defined via "
"@LinearForm.", DeprecationWarning)

class ClassicLinearForm(LinearForm):

def _kernel(self, v, w, dx):
v = v[0]
W = {k: w[k].f for k in w}
if 'w' in w:
W['dw'] = w['w'].df
if v.ddf is not None:
return np.sum(self.form(v=v.f, dv=v.df, ddv=v.ddf,
w=FormParameters(**W)) * dx, axis=1)
else:
return np.sum(self.form(v=v.f, dv=v.df,
w=FormParameters(**W)) * dx, axis=1)

return ClassicLinearForm(form)
17 changes: 9 additions & 8 deletions tests/test_assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

import numpy as np

from skfem import (BilinearForm, LinearForm, Functional, asm, bilinear_form,
linear_form, solve, functional)
from skfem import BilinearForm, LinearForm, Functional, asm, solve
from skfem.element import (ElementQuad1, ElementQuadS2, ElementHex1,
ElementHexS2, ElementTetP0, ElementTetP1,
ElementTetP2, ElementTriP1, ElementQuad2,
Expand Down Expand Up @@ -163,16 +162,18 @@ class BasisInterpolatorQuadS2(BasisInterpolator):


class BasisInterpolatorMorley(BasisInterpolator):

case = (MeshTri, ElementTriMorley)

def initOnes(self, basis):
@bilinear_form
def mass(u, du, ddu, v, dv, ddv, w):

@BilinearForm
def mass(u, v, w):
return u * v

@linear_form
def ones(v, dv, ddv, w):
return 1.0 * v
@LinearForm
def ones(v, w):
return 1. * v

M = asm(mass, basis)
f = asm(ones, basis)
Expand Down Expand Up @@ -256,7 +257,7 @@ def runTest(self):
e = ElementQuad1()
basis = InteriorBasis(m, e)

@functional
@Functional
def x_squared(w):
return w.x[0] ** 2

Expand Down

0 comments on commit 5d07cad

Please sign in to comment.