Skip to content

Commit

Permalink
creating functions to avoid superfluous computation of 0 gradient val…
Browse files Browse the repository at this point in the history
…ues for non-local ivs; cleaning up unit test harness
  • Loading branch information
ralberd committed Nov 2, 2023
1 parent 9956926 commit 112b10b
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 47 deletions.
35 changes: 35 additions & 0 deletions optimism/inverse/test/MechanicsInverse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from collections import namedtuple

from optimism.JaxConfig import *
from optimism import Mechanics
from optimism import FunctionSpace

MechanicsInverseFunctions = namedtuple('MechanicsInverseFunctions',
['partial_ivs_update_partial_ivs_prev'])

def _compute_updated_internal_variables_gradient(dispGrads, states, dt, compute_state_new, gradient_shape):
dgQuadPointRavel = dispGrads.reshape(dispGrads.shape[0]*dispGrads.shape[1],*dispGrads.shape[2:])
stQuadPointRavel = states.reshape(states.shape[0]*states.shape[1],*states.shape[2:])
statesNew = vmap(compute_state_new, (0, 0, None))(dgQuadPointRavel, stQuadPointRavel, dt)
return statesNew.reshape(gradient_shape)

def create_mechanics_inverse_functions(functionSpace, mode2D, materialModel, pressureProjectionDegree=None):
fs = functionSpace

if mode2D == 'plane strain':
grad_2D_to_3D = Mechanics.plane_strain_gradient_transformation
elif mode2D == 'axisymmetric':
raise NotImplementedError

modify_element_gradient = grad_2D_to_3D
if pressureProjectionDegree is not None:
raise NotImplementedError

def compute_partial_ivs_update_partial_ivs_prev(U, stateVariables, dt=0.0):
dispGrads = FunctionSpace.compute_field_gradient(fs, U, modify_element_gradient)
update_gradient = jacfwd(materialModel.compute_state_new, argnums=1)
grad_shape = stateVariables.shape + (stateVariables.shape[2],)
return _compute_updated_internal_variables_gradient(dispGrads, stateVariables, dt,\
update_gradient, grad_shape)

return MechanicsInverseFunctions(jit(compute_partial_ivs_update_partial_ivs_prev))
160 changes: 113 additions & 47 deletions optimism/inverse/test/test_J2Plastic_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from optimism.test.TestFixture import TestFixture
from optimism.test.MeshFixture import MeshFixture

# misc. modules in test directory for now while I test
import MechanicsInverse

def make_disp_grad_from_strain(strain):
return linalg.expm(strain) - np.identity(3)

Expand Down Expand Up @@ -73,7 +76,6 @@ def setUp(self):
self.compute_state_new = jax.jit(materialModel.compute_state_new)
self.compute_initial_state = materialModel.compute_initial_state

# self.compute_state_new_derivs = jax.jit(jax.jacrev(self.compute_state_new, (0, 1)))
self.compute_state_new_derivs = jax.jit(jax.jacfwd(self.compute_state_new, (0, 1)))

def test_jax_computation_of_state_derivs_at_elastic_step(self):
Expand Down Expand Up @@ -121,56 +123,63 @@ def test_jax_computation_of_state_derivs_at_plastic_step(self):


class J2GlobalMeshUpdateGradsFixture(MeshFixture):
def setUp(self):
@classmethod
def setUpClass(cls):
dispGrad0 = np.array([[0.4, -0.2],
[-0.04, 0.68]])
self.mesh, self.U = self.create_mesh_and_disp(4,4,[0.,1.],[0.,1.],
mf = MeshFixture()
cls.mesh, cls.U = mf.create_mesh_and_disp(4,4,[0.,1.],[0.,1.],
lambda x: dispGrad0@x)

E = 100.0
poisson = 0.321
H = 1e-2 * E
Y0 = 0.3 * E

self.props = {'elastic modulus': E,
cls.props = {'elastic modulus': E,
'poisson ratio': poisson,
'yield strength': Y0,
'kinematics': 'small deformations',
'hardening model': 'linear',
'hardening modulus': H}

self.materialModel = J2.create_material_model_functions(self.props)

self.quadRule = QuadratureRule.create_quadrature_rule_on_triangle(degree=1)

self.fs = FunctionSpace.construct_function_space(self.mesh, self.quadRule)

self.mechFuncs = Mechanics.create_mechanics_functions(self.fs,
cls.materialModel = J2.create_material_model_functions(cls.props)
cls.quadRule = QuadratureRule.create_quadrature_rule_on_triangle(degree=1)
cls.fs = FunctionSpace.construct_function_space(cls.mesh, cls.quadRule)
cls.mechFuncs = Mechanics.create_mechanics_functions(cls.fs,
"plane strain",
self.materialModel)
cls.materialModel)
cls.ivs_prev = cls.mechFuncs.compute_initial_state()

EBCs = [FunctionSpace.EssentialBC(nodeSet='all_boundary', component=0),
FunctionSpace.EssentialBC(nodeSet='all_boundary', component=1)]
cls.dofManager = FunctionSpace.DofManager(cls.fs, 2, EBCs)
cls.Ubc = cls.dofManager.get_bc_values(cls.U)

self.dofManager = FunctionSpace.DofManager(self.fs, 2, EBCs)
p = Objective.Params(None, cls.ivs_prev, None, None, None)
UuGuess = 0.0*cls.dofManager.get_unknown_values(cls.U)

self.Ubc = self.dofManager.get_bc_values(self.U)

def test_state_derivs_at_elastic_step(self):
def compute_energy(Uu, p):
U = cls.dofManager.create_field(Uu, cls.Ubc)
internalVariables = p[1]
return cls.mechFuncs.compute_strain_energy(U, internalVariables)

internalVariables = self.mechFuncs.compute_initial_state()
objective = Objective.Objective(compute_energy, UuGuess, p)
cls.Uu = EqSolver.nonlinear_equation_solve(objective, UuGuess, p, EqSolver.get_settings(), useWarmStart=False)
U = cls.dofManager.create_field(cls.Uu, cls.Ubc)
cls.ivs = cls.mechFuncs.compute_updated_internal_variables(U, cls.ivs_prev)

p = Objective.Params(None, internalVariables, None, None, None)
def test_state_derivs_at_elastic_step(self):

def update_internal_vars_test(Uu, internalVars):
def update_internal_vars_test(Uu, ivs_prev):
U = self.dofManager.create_field(Uu)
internalVariablesNew = self.mechFuncs.compute_updated_internal_variables(U, internalVars)
return internalVariablesNew
ivs = self.mechFuncs.compute_updated_internal_variables(U, ivs_prev)
return ivs

Uu = 0.0*self.dofManager.get_unknown_values(self.U)

update_internal_variables_derivs = jax.jacfwd(update_internal_vars_test, (0,1))
dc_du, dc_dc_n = update_internal_variables_derivs(Uu, p[1])
dc_du, dc_dc_n = update_internal_variables_derivs(Uu, self.ivs_prev)

nElems = Mesh.num_elements(self.mesh)
nQpsPerElem = QuadratureRule.len(self.quadRule)
Expand All @@ -187,40 +196,28 @@ def update_internal_vars_test(Uu, internalVars):

def test_state_derivs_at_plastic_step(self):

initialInternalVariables = self.mechFuncs.compute_initial_state()

p = Objective.Params(None, initialInternalVariables, None, None, None)

def update_internal_vars_test(U, internalVars):
internalVariablesNew = self.mechFuncs.compute_updated_internal_variables(U, internalVars)
return internalVariablesNew

def compute_energy(Uu, p):
U = self.dofManager.create_field(Uu, self.Ubc)
internalVariables = p[1]
return self.mechFuncs.compute_strain_energy(U, internalVariables)

UuGuess = 0.0*self.dofManager.get_unknown_values(self.U)
objective = Objective.Objective(compute_energy, UuGuess, p)
Uu = EqSolver.nonlinear_equation_solve(objective, UuGuess, p, EqSolver.get_settings(), useWarmStart=False)
U = self.dofManager.create_field(Uu, self.Ubc)

internalVariables = update_internal_vars_test(U, p[1])
def update_internal_vars_test(U, ivs_prev):
ivs = self.mechFuncs.compute_updated_internal_variables(U, ivs_prev)
return ivs

update_internal_variables_derivs = jax.jacfwd(update_internal_vars_test, (0,1))
dc_du, dc_dc_n = update_internal_variables_derivs(U, p[1])

U = self.dofManager.create_field(self.Uu, self.Ubc)
dc_du, dc_dc_n = update_internal_variables_derivs(U, self.ivs_prev)

nElems = Mesh.num_elements(self.mesh)
nQpsPerElem = QuadratureRule.len(self.quadRule)
nIntVars = 10
nDims = 2
nNodes = Mesh.num_nodes(self.mesh)

self.assertEqual(dc_du.shape, (nElems,nQpsPerElem,nIntVars,U.shape[0],U.shape[1]))
self.assertEqual(dc_du.shape, (nElems,nQpsPerElem,nIntVars,nNodes,nDims))
self.assertEqual(dc_dc_n.shape, (nElems,nQpsPerElem,nIntVars,nElems,nQpsPerElem,nIntVars))

for i in range(0,nElems):
for j in range(0,nQpsPerElem):
state = internalVariables[i,j,:]
initial_state = initialInternalVariables[i,j,:]
state = self.ivs[i,j,:]
initial_state = self.ivs_prev[i,j,:]

conn = self.mesh.conns[i]
Uele = U[conn]
Expand All @@ -233,11 +230,80 @@ def compute_energy(Uu, p):

dc_dugrad_gold, dc_dc_n_gold = small_strain_linear_hardening_analytic_gradients(dispGrad, state, initial_state, self.props)

self.assertArrayNear(dc_dc_n[i,j,:,i,j,:].ravel(), dc_dc_n_gold.ravel(), 10)

dc_duele_gold = np.tensordot(dc_dugrad_gold, Be_mat, axes=1)
dc_er_du = dc_du[i,j,:,:,:]

self.assertArrayNear(dc_er_du[:,conn,:].ravel(), dc_duele_gold.reshape(10,3,2).ravel(), 10)
self.assertArrayNear(np.delete(dc_er_du, conn, axis=1).ravel(), np.zeros((10,nNodes-conn.shape[0],2)).ravel(), 10)

for p in range(0,nElems):
for q in range(0,nQpsPerElem):
if(i == p and j == q):
self.assertArrayNear(dc_dc_n[i,j,:,i,j,:].ravel(), dc_dc_n_gold.ravel(), 10)
else:
self.assertArrayNear(dc_dc_n[i,j,:,p,q,:].ravel(), np.zeros((nIntVars,nIntVars)).ravel(), 10)

def test_state_derivs_computed_locally_at_plastic_step(self):

mechInverseFuncs = MechanicsInverse.create_mechanics_inverse_functions(self.fs,
"plane strain",
self.materialModel)

U = self.dofManager.create_field(self.Uu, self.Ubc)
dc_dc_n = mechInverseFuncs.partial_ivs_update_partial_ivs_prev(U, self.ivs_prev)

nElems = Mesh.num_elements(self.mesh)
nQpsPerElem = QuadratureRule.len(self.quadRule)
nIntVars = 10

self.assertEqual(dc_dc_n.shape, (nElems,nQpsPerElem,nIntVars,nIntVars))

for i in range(0,nElems):
for j in range(0,nQpsPerElem):
state = self.ivs[i,j,:]
initial_state = self.ivs_prev[i,j,:]

conn = self.mesh.conns[i]
Uele = U[conn]
shapeGrads = self.fs.shapeGrads[i,j,:,:]
dispGrad = TensorMath.tensor_2D_to_3D(np.tensordot(Uele,shapeGrads,axes=[0,0]))

_, dc_dc_n_gold = small_strain_linear_hardening_analytic_gradients(dispGrad, state, initial_state, self.props)

self.assertArrayNear(dc_dc_n[i,j,:,:].ravel(), dc_dc_n_gold.ravel(), 10)

def test_internal_variables_updates_jacobian_vector_products(self):

def energy_function_ravel(Uu, ivs):
internal_vars = ivs.reshape(self.ivs.shape)
U = self.dofManager.create_field(Uu, self.Ubc)
return self.mechFuncs.compute_strain_energy(U, internal_vars)

def energy_function(Uu, ivs):
U = self.dofManager.create_field(Uu, self.Ubc)
return self.mechFuncs.compute_strain_energy(U, ivs)

def update_internal_vars_test(Uu, ivs_prev):
internal_vars = ivs_prev.reshape(self.ivs.shape)
U = self.dofManager.create_field(Uu, self.Ubc)
return self.mechFuncs.compute_updated_internal_variables(U, internal_vars).ravel()

mechInverseFuncs = MechanicsInverse.create_mechanics_inverse_functions(self.fs,
"plane strain",
self.materialModel)

U = self.dofManager.create_field(self.Uu, self.Ubc)

key = jax.random.PRNGKey(0)
mu_dummy = jax.random.uniform(key, (np.prod(np.array(self.ivs.shape)),))

dc_dc_n_raveled = jax.jacfwd(update_internal_vars_test, 1)(self.Uu, self.ivs_prev.ravel())
prodGold = np.tensordot(mu_dummy, dc_dc_n_raveled, axes=1)

dc_dc_n = mechInverseFuncs.partial_ivs_update_partial_ivs_prev(U, self.ivs_prev)
prodReduced = np.einsum('ijk,ijkn->ijn', mu_dummy.reshape(self.ivs.shape), dc_dc_n)

self.assertArrayNear(prodGold.ravel(), prodReduced.ravel(), 12)



Expand Down

0 comments on commit 112b10b

Please sign in to comment.