Skip to content

Commit

Permalink
Merge pull request #139 from DedalusProject/lu_transpose
Browse files Browse the repository at this point in the history
Transposed LU factorizations
  • Loading branch information
kburns authored May 12, 2021
2 parents fa9ce7f + 85ee999 commit e1541b9
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 49 deletions.
7 changes: 4 additions & 3 deletions dedalus/core/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
logger = logging.getLogger(__name__.split('.')[-1])
DEFAULT_LIBRARY = config['transforms'].get('DEFAULT_LIBRARY')
FFTW_RIGOR = config['transforms-fftw'].get('PLANNING_RIGOR')
DIRICHLET_PRECONDITIONING = lambda: config['matrix construction'].getboolean('DIRICHLET_PRECONDITIONING')


class Basis:
Expand Down Expand Up @@ -299,7 +300,7 @@ def __init__(self, name, base_grid_size, interval=(-1,1), dealias=1, tau_after_p

def default_meta(self):
return {'constant': False,
'dirichlet': True}
'dirichlet': DIRICHLET_PRECONDITIONING()}

@CachedMethod
def grid(self, scale=1.):
Expand Down Expand Up @@ -657,7 +658,7 @@ def __init__(self, name, base_grid_size, interval=(-1,1), dealias=1, tau_after_p

def default_meta(self):
return {'constant': False,
'dirichlet': True}
'dirichlet': DIRICHLET_PRECONDITIONING()}

@CachedMethod
def grid(self, scale=1.):
Expand Down Expand Up @@ -1392,7 +1393,7 @@ def __init__(self, name, base_grid_size, edge=0.0, stretch=1.0, dealias=1, tau_a

def default_meta(self):
return {'constant': False,
'dirichlet': True,
'dirichlet': DIRICHLET_PRECONDITIONING(),
'envelope': True}

@CachedMethod
Expand Down
131 changes: 89 additions & 42 deletions dedalus/core/pencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,34 +278,32 @@ def _build_coupled_matrices(self, problem, names, cacheid=None):
LHS_blocks[name].append(eq_blocks)

# Combine blocks
left_perm = left_permutation(zbasis, n_vars, pencil_eqs)
right_perm = right_permutation(zbasis, problem)
left_perm = left_permutation(zbasis, n_vars, pencil_eqs, bc_top=problem.BC_TOP, interleave_subbases=problem.INTERLEAVE_SUBBASES)
right_perm = right_permutation(zbasis, problem, interleave_subbases=problem.INTERLEAVE_SUBBASES)
self.pre_left = left_perm @ sparse.block_diag(pre_left_diags, format='csr', dtype=zdtype)
self.pre_right = sparse.block_diag(pre_right_diags, format='csr', dtype=zdtype) @ right_perm
LHS_matrices = {name: left_perm @ fast_bmat(LHS_blocks[name]).tocsr() for name in names}

# Store minimal-entry matrices for fast dot products
for name, matrix in LHS_matrices.items():
# Store full matrix
matrix.eliminate_zeros()
setattr(self, name+'_full', matrix.tocsr().copy())
# Truncate entries
matrix.data[np.abs(matrix.data) < problem.entry_cutoff] = 0
matrix.eliminate_zeros()
# Store truncated matrix
setattr(self, name, matrix.tocsr().copy())

# Store expanded right-preconditioned matrices
# Apply right preconditioning
if self.pre_right is not None:
for name in names:
LHS_matrices[name] = LHS_matrices[name] @ self.pre_right
# Build expanded LHS matrix to store matrix combinations
self.LHS = zeros_with_pattern(*LHS_matrices.values()).tocsr()
# Store expanded matrices for fast combination
for name, matrix in LHS_matrices.items():
matrix = expand_pattern(matrix, self.LHS)
setattr(self, name+'_exp', matrix.tocsr().copy())
if problem.STORE_EXPANDED_MATRICES:
# Apply right preconditioning
if self.pre_right is not None:
for name in names:
LHS_matrices[name] = LHS_matrices[name] @ self.pre_right
# Build expanded LHS matrix to store matrix combinations
self.LHS = zeros_with_pattern(*LHS_matrices.values()).tocsr()
# Store expanded matrices for fast combination
for name, matrix in LHS_matrices.items():
matrix = expand_pattern(matrix, self.LHS)
setattr(self, name+'_exp', matrix.tocsr().copy())


def fast_bmat(blocks):
Expand Down Expand Up @@ -352,15 +350,21 @@ def simple_reorder(N0, N1):
return sparse_perm(perm_indeces, len(perm_indeces)).tocsr()


def left_permutation(zbasis, n_vars, eqs):
def left_permutation(zbasis, n_vars, eqs, bc_top, interleave_subbases):
"""
Left permutation keeping match rows first, and inverting equation nesting:
Input: Equations > Subbases > modes
Output: Modes > Subbases > Equations
Left permutation acting on equations.
bc_top determines if constant equations are placed at the top or bottom of the matrix.
Input ordering:
Equations > Subbases > Modes
Output ordering with interleave_subbases=True:
Modes > Subbases > Equations
Output ordering with interleave_subbases=False:
Subbases > Modes > Equations
"""
# Compute hierarchy or input equation indeces
nmatch = n_vars * (len(zbasis.subbases) - 1)
# Compute list heirarchy of indeces
i = i0 = nmatch
i = nmatch
L0 = []
for eq in eqs:
L1 = []
Expand All @@ -382,30 +386,64 @@ def left_permutation(zbasis, n_vars, eqs):
i += 1
L1.append(L2)
L0.append(L1)
# Reverse list hierarchy
indeces = []
for i in range(i0):
indeces.append(i)
# Match indeces
match_indeces = []
for i in range(nmatch):
match_indeces.append(i)
n1max = len(L0)
n2max = max(len(L1) for L1 in L0)
n3max = max(len(L2) for L1 in L0 for L2 in L1)
for n3 in range(n3max):
# Constant and nonconstant equation indeces
const_indeces = []
nonconst_indeces = []
if interleave_subbases:
for n3 in range(n3max):
for n2 in range(n2max):
for n1 in range(n1max):
if eqs[n1]['LHS'].meta[zbasis.name]['constant']:
try:
const_indeces.append(L0[n1][n2][n3])
except IndexError:
continue
else:
try:
nonconst_indeces.append(L0[n1][n2][n3])
except IndexError:
continue
else:
for n2 in range(n2max):
for n1 in range(n1max):
try:
indeces.append(L0[n1][n2][n3])
except IndexError:
continue
for n3 in range(n3max):
for n1 in range(n1max):
if eqs[n1]['LHS'].meta[zbasis.name]['constant']:
try:
const_indeces.append(L0[n1][n2][n3])
except IndexError:
continue
else:
try:
nonconst_indeces.append(L0[n1][n2][n3])
except IndexError:
continue
# Combine indeces
if bc_top:
indeces = match_indeces + const_indeces + nonconst_indeces
else:
indeces = nonconst_indeces + const_indeces + match_indeces
return sparse_perm(indeces, len(indeces)).T.tocsr()


def right_permutation(zbasis, problem):
def right_permutation(zbasis, problem, interleave_subbases):
"""
Right permutation inverting variable nesting:
Input: Variables > Subbases > modes
Output: Modes > Subbases > Variables
Right permutation acting on variables.
Input ordering:
Variables > Subbases > Modes
Output ordering with interleave_subbases=True:
Modes > Subbases > Variables
Output ordering with interleave_subbases=False:
Subbases > Modes > Variables
"""
# Compute list heirarchy of indeces
# Compute hierarchy or input variable indeces
i = 0
L0 = []
for var in problem.variables:
Expand All @@ -428,12 +466,21 @@ def right_permutation(zbasis, problem):
L1max = len(L0)
L2max = max(len(L1) for L1 in L0)
L3max = max(len(L2) for L1 in L0 for L2 in L1)
for n3 in range(L3max):
if interleave_subbases:
for n3 in range(L3max):
for n2 in range(L2max):
for n1 in range(L1max):
try:
indeces.append(L0[n1][n2][n3])
except IndexError:
continue
else:
for n2 in range(L2max):
for n1 in range(L1max):
try:
indeces.append(L0[n1][n2][n3])
except IndexError:
continue
for n3 in range(L3max):
for n1 in range(L1max):
try:
indeces.append(L0[n1][n2][n3])
except IndexError:
continue
return sparse_perm(indeces, len(indeces)).tocsr()

9 changes: 9 additions & 0 deletions dedalus/core/problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
from ..tools.exceptions import SymbolicParsingError
from ..tools.exceptions import UnsupportedEquationError

from ..tools.config import config
BC_TOP = lambda: config['matrix construction'].getboolean('BC_TOP')
INTERLEAVE_SUBBASES = lambda: config['matrix construction'].getboolean('INTERLEAVE_SUBBASES')
STORE_EXPANDED_MATRICES = lambda: config['matrix construction'].getboolean('STORE_EXPANDED_MATRICES')

import logging
logger = logging.getLogger(__name__.split('.')[-1])

Expand Down Expand Up @@ -113,6 +118,10 @@ def __init__(self, domain, variables, ncc_cutoff=1e-6, max_ncc_terms=None, entry
self.ncc_kw = {'cutoff': ncc_cutoff, 'max_terms': max_ncc_terms}
self.entry_cutoff = entry_cutoff
self.coupled = domain.bases[-1].coupled
# Matrix construction config options
self.BC_TOP = BC_TOP()
self.INTERLEAVE_SUBBASES = INTERLEAVE_SUBBASES()
self.STORE_EXPANDED_MATRICES = STORE_EXPANDED_MATRICES()

@property
def nvars_const(self):
Expand Down
16 changes: 14 additions & 2 deletions dedalus/core/timesteppers.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def step(self, solver, dt):
pencils = solver.pencils
evaluator = solver.evaluator
state = solver.state
STORE_EXPANDED_MATRICES = solver.problem.STORE_EXPANDED_MATRICES

evaluator_kw = {}
evaluator_kw['world_time'] = world_time = solver.get_world_time()
Expand Down Expand Up @@ -152,7 +153,12 @@ def step(self, solver, dt):
state.data.fill(0)
for p in pencils:
if update_LHS:
np.copyto(p.LHS.data, a0*p.M_exp.data + b0*p.L_exp.data) # CREATES TEMPORARY
if STORE_EXPANDED_MATRICES:
np.copyto(p.LHS.data, a0*p.M_exp.data + b0*p.L_exp.data) # CREATES TEMPORARY
else:
p.LHS = (a0*p.M + b0*p.L) @ p.pre_right
# Remove old solver reference before building new solver
p.LHS_solver = None
p.LHS_solver = solver.matsolver(p.LHS, solver)
pRHS = RHS.get_pencil(p)
pX = p.LHS_solver.solve(pRHS)
Expand Down Expand Up @@ -524,6 +530,7 @@ def step(self, solver, dt):
pencils = solver.pencils
evaluator = solver.evaluator
state = solver.state
STORE_EXPANDED_MATRICES = solver.problem.STORE_EXPANDED_MATRICES

evaluator_kw = {}
evaluator_kw['world_time'] = world_time = solver.get_world_time()
Expand Down Expand Up @@ -587,7 +594,12 @@ def step(self, solver, dt):
for p in pencils:
# Construct LHS(n,i)
if update_LHS:
np.copyto(p.LHS.data, p.M_exp.data + (k*H[i,i])*p.L_exp.data) # CREATES TEMPORARY
if STORE_EXPANDED_MATRICES:
np.copyto(p.LHS.data, p.M_exp.data + (k*H[i,i])*p.L_exp.data) # CREATES TEMPORARY
else:
p.LHS = (p.M + (k*H[i,i])*p.L) @ p.pre_right
# Remove old solver reference before building new solver
p.LHS_solvers[i] = None
p.LHS_solvers[i] = solver.matsolver(p.LHS, solver)
pRHS = RHS.get_pencil(p)
pX = p.LHS_solvers[i].solve(pRHS) # CREATES TEMPORARY
Expand Down
19 changes: 18 additions & 1 deletion dedalus/dedalus.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,30 @@
# Use variable-length all-to-all routine
ALLTOALLV = False

[matrix construction]

# Dirichlet preconditioning default
DIRICHLET_PRECONDITIONING = False

# Put BC rows at the top of the matrix
# Set to True when using Dirichlet preconditioning
BC_TOP = False

# Interleave subbasis modes
# Set to True when using Dirichlet preconditioning
INTERLEAVE_SUBBASES = False

# Store expanded LHS matrices
# May speed up matrix factorization at the expense of extra memory
STORE_EXPANDED_MATRICES = True

[linear algebra]

# Default sparse matrix solver for single solves
MATRIX_SOLVER = SuperLUNaturalSpsolve

# Default sparse matrix factorizer for repeated solves
MATRIX_FACTORIZER = SuperLUNaturalFactorized
MATRIX_FACTORIZER = SuperLUNaturalFactorizedTranspose

[memory]

Expand Down
11 changes: 11 additions & 0 deletions dedalus/libraries/matsolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,17 @@ def solve(self, vector):
return self.LU.solve(vector)


@add_solver
class SuperluNaturalFactorizedTranspose(SparseSolver):
"""SuperLU+NATURAL LU factorized solve."""

def __init__(self, matrix, solver=None):
self.LU = spla.splu(matrix.T.tocsc(), permc_spec='NATURAL')

def solve(self, vector):
return self.LU.solve(vector, trans='T')


@add_solver
class SuperluColamdFactorized(SparseSolver):
"""SuperLU+COLAMD LU factorized solve."""
Expand Down
1 change: 0 additions & 1 deletion examples/ivp/2d_rayleigh_benard/rayleigh_benard.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@

# 2D Boussinesq hydrodynamics
problem = de.IVP(domain, variables=['p','b','u','w','bz','uz','wz'])
problem.meta['p','b','u','w']['z']['dirichlet'] = True
problem.parameters['P'] = (Rayleigh * Prandtl)**(-1/2)
problem.parameters['R'] = (Rayleigh / Prandtl)**(-1/2)
problem.parameters['F'] = F = 1
Expand Down

0 comments on commit e1541b9

Please sign in to comment.