Skip to content

Commit

Permalink
renamed options cudss
Browse files Browse the repository at this point in the history
  • Loading branch information
sshin23 committed Mar 2, 2024
1 parent 7a66bf6 commit 770de36
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
8 changes: 4 additions & 4 deletions lib/MadNLPGPU/src/cudss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import CUDSS
import SparseArrays

@kwdef mutable struct CudssSolverOptions <: MadNLP.AbstractOptions
cudss_algorithm::MadNLP.LinearFactorization = MadNLP.BUNCHKAUFMAN
cudss_algorithm::MadNLP.LinearFactorization = MadNLP.LDL
end

mutable struct CUDSSSolver{T} <: MadNLP.AbstractLinearSolver{T}
Expand All @@ -28,7 +28,7 @@ function CUDSSSolver(
"G"
elseif opt.cudss_algorithm == MadNLP.CHOLESKY
"SPD"
elseif opt.cudss_algorithm == MadNLP.BUNCHKAUFMAN
elseif opt.cudss_algorithm == MadNLP.LDL
"S"
end

Expand Down Expand Up @@ -73,7 +73,7 @@ end

MadNLP.input_type(::Type{CUDSSSolver}) = :csc
MadNLP.default_options(::Type{CUDSSSolver}) = CudssSolverOptions()
MadNLP.is_inertia(M::CUDSSSolver) = (M.opt.cudss_algorithm (MadNLP.CHOLESKY, MadNLP.BUNCHKAUFMAN))
MadNLP.is_inertia(M::CUDSSSolver) = (M.opt.cudss_algorithm (MadNLP.CHOLESKY, MadNLP.LDL))
function inertia(M::CUDSSSolver)
n = size(M.tril, 1)
if M.opt.cudss_algorithm == MadNLP.CHOLESKY
Expand All @@ -83,7 +83,7 @@ function inertia(M::CUDSSSolver)
else
return (0, n, 0)
end
elseif M.opt.cudss_algorithm == MadNLP.BUNCHKAUFMAN
elseif M.opt.cudss_algorithm == MadNLP.LDL
# N.B.: cuDSS does not always return the correct inertia.
(k, l) = CUDSS.cudss_get(M.inner, "inertia")
k = min(n, k) # TODO: add safeguard for inertia
Expand Down
4 changes: 3 additions & 1 deletion src/LinearSolvers/linearsolvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,13 @@ struct SolveException <: Exception end
struct InertiaException <: Exception end
LinearSolverException=Union{SymbolicException,FactorizationException,SolveException,InertiaException}

@enum(LinearFactorization::Int,
@enum(
LinearFactorization::Int,
BUNCHKAUFMAN = 1,
LU = 2,
QR = 3,
CHOLESKY = 4,
LDL = 5,
)

# iterative solvers
Expand Down

0 comments on commit 770de36

Please sign in to comment.