Skip to content

Commit

Permalink
add adaptivity set and misc fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Shreyas-Ekanathan committed Nov 9, 2024
1 parent 58baa98 commit 7afc20c
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 64 deletions.
21 changes: 7 additions & 14 deletions lib/OrdinaryDiffEqFIRK/src/firk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
end

mutable struct AdaptiveRadauCache{uType, cuType, tType, uNoUnitsType, rateType, JType, W1Type, W2Type,
UF, JC, F1, F2, #=F3,=# Tab, Tol, Dt, rTol, aTol, StepLimiter} <:
UF, JC, F1, F2, Tab, Tol, Dt, rTol, aTol, StepLimiter} <:
FIRKMutableCache
u::uType
uprev::uType
Expand Down Expand Up @@ -550,7 +550,6 @@ mutable struct AdaptiveRadauCache{uType, cuType, tType, uNoUnitsType, rateType,
jac_config::JC
linsolve1::F1 #real
linsolve2::Vector{F2} #complex
#linres2::Vector{F3}
rtol::rTol
atol::aTol
dtprev::Dt
Expand All @@ -569,10 +568,8 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
uf = UJacobianWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)

min = alg.min_stages
max = alg.max_stages

num_stages = min
num_stages = alg.min_stages

tabs = [BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))]
i = 9
Expand All @@ -590,6 +587,9 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
end

c_prime = Vector{typeof(t)}(undef, max) #time stepping
for i in 1 : max
c_prime[i] = zero(t)
end

dw1 = zero(u)
ubuff = zero(u)
Expand Down Expand Up @@ -636,14 +636,7 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
linsolve2 = [
init(LinearProblem(W2[i], _vec(cubuff[i]); u0 = _vec(dw2[i])), alg.linsolve, alias_A = true, alias_b = true,
assumptions = LinearSolve.OperatorAssumptions(true)) for i in 1 : (max - 1) ÷ 2]
#=
linres_tmp = dolinsolve(nothing, linsolve2[1]; A = W2[1], b = _vec(cubuff[1]), linu = _vec(dw2[1]))
linres2 = Vector{typeof(linres_tmp)}(undef , (max - 1) ÷ 2)
linres2[1] = linres_tmp
for i in 2 : (num_stages - 1) ÷ 2
linres2[i] = dolinsolve(nothing, linsolve2[1]; A = W2[1], b = _vec(cubuff[i]), linu = _vec(dw2[i]))
end
=#

rtol = reltol isa Number ? reltol : zero(reltol)
atol = reltol isa Number ? reltol : zero(reltol)

Expand All @@ -653,7 +646,7 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
J, W1, W2,
uf, tabs, κ, one(uToltype), 10000, tmp,
atmp, jac_config,
linsolve1, linsolve2, #=linres2,=# rtol, atol, dt, dt,
linsolve1, linsolve2, rtol, atol, dt, dt,
Convergence, alg.step_limiter!, num_stages, 1, 0.0)
end

15 changes: 6 additions & 9 deletions lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1601,7 +1601,7 @@ end
@unpack κ, cont, derivatives, z, w, c_prime = cache
@unpack dw1, ubuff, dw2, cubuff, dw = cache
@unpack ks, k, fw, J, W1, W2 = cache
@unpack tmp, atmp, jac_config, linsolve1, linsolve2, #=linres2,=# rtol, atol, step_limiter! = cache
@unpack tmp, atmp, jac_config, linsolve1, linsolve2, rtol, atol, step_limiter! = cache
@unpack internalnorm, abstol, reltol, adaptive = integrator.opts
alg = unwrap_alg(integrator, true)
@unpack maxiters = alg
Expand Down Expand Up @@ -1635,14 +1635,14 @@ end
c_prime[i] = c[i] * c_prime[num_stages]
end
for i in 1 : num_stages # collocation polynomial
z[i] = cont[num_stages] * (c_prime[i] - c[1] + 1) + cont[num_stages - 1]
@.. z[i] = cont[num_stages] * (c_prime[i] - c[1] + 1) + cont[num_stages - 1]
j = num_stages - 2
while j > 0
@.. z[i] *= (c_prime[i] - c[num_stages - j] + 1)
@.. z[i] += cont[j]
j = j - 1
end
z[i] = z[i] * c_prime[i]
@.. z[i] *= c_prime[i]
end
#mul!(w, TI, z)
for i in 1:num_stages
Expand Down Expand Up @@ -1672,7 +1672,7 @@ end

#mul!(fw, TI, ks)
for i in 1:num_stages
fw[i] = zero(u)
@.. fw[i] = zero(u)
for j in 1:num_stages
@.. fw[i] += TI[i,j] * ks[j]
end
Expand Down Expand Up @@ -1704,17 +1704,14 @@ end

cache.linsolve1 = linres.cache

linres2 = Vector{Any}(undef,(num_stages - 1) ÷ 2)

for i in 1 :(num_stages - 1) ÷ 2
@.. cubuff[i]=complex(
fw[2 * i] - αdt[i] * Mw[2 * i] + βdt[i] * Mw[2 * i + 1], fw[2 * i + 1] - βdt[i] * Mw[2 * i] - αdt[i] * Mw[2 * i + 1])
if needfactor
linres2[i] = dolinsolve(integrator, linsolve2[i]; A = W2[i], b = _vec(cubuff[i]), linu = _vec(dw2[i]))
cache.linsolve2[i] = dolinsolve(integrator, linsolve2[i]; A = W2[i], b = _vec(cubuff[i]), linu = _vec(dw2[i])).cache
else
linres2[i] = dolinsolve(integrator, linsolve2[i]; A = nothing, b = _vec(cubuff[i]), linu = _vec(dw2[i]))
cache.linsolve2[i] = dolinsolve(integrator, linsolve2[i]; A = nothing, b = _vec(cubuff[i]), linu = _vec(dw2[i])).cache
end
cache.linsolve2[i] = linres2[i].cache
end

integrator.stats.nsolve += (num_stages + 1) / 2
Expand Down
102 changes: 61 additions & 41 deletions lib/OrdinaryDiffEqFIRK/src/firk_tableaus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,9 @@ function BigRadauIIA5Tableau(T1, T2)
c[3] = big"1"

e = Vector{T1}(undef, 3)
e[1] = big"-10.0488093998274155624603295076470799145872107881988"
e[2] = big"1.38214273316074889579366284098041324792054412153223"
e[3] = big"-0.333333333333333333333333333333333333333333333333333"
e[1] = big"-10.0488093998274155624603295076470799145872107881988969663429493235855742140670683952596720105774938812433874028620997746246706860729547671304601625528869782"
e[2] = big"1.38214273316074889579366284098041324792054412153223029967628265691890754740040172859300534391082721457672073619543310795800401940628810046379349588622031217"
e[3] = big"-0.333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333644"

TI = Matrix{T1}(undef, 3, 3)
TI[1, 1] = big"4.32557989006315535102435095295614882731995158490590784287320458848019483341979047442263696495019938973156007686663488090615420049217658854859024016717169837"
Expand Down Expand Up @@ -180,11 +180,11 @@ function BigRadauIIA9Tableau(T1, T2)
c[5] = big"1.0"

e = Vector{T1}(undef, 5)
e[1] = big"-27.78093394406463730479"
e[2] = big"3.641478498049213152712"
e[3] = big"-1.252547721169118720491"
e[4] = big"0.5920031671845428725662"
e[5] = big"-0.2000000000000000000000"
e[1] = big"-27.7809339440646373047872078172168798923674228687740760060378492475924178050505976287227228556471699142365371740120443650701118024100678675823465762727483305"
e[2] = big"3.64147849804921315271165508774289722904088750334220956841022786858917594981395319605788667956024462601802006251583142928630101075351336314632135787805261686"
e[3] = big"-1.25254772116911872049065249430114914889315244289570569309128740586057170336299694248256681515155624683225624015343224399700466177251702555220815764199263189"
e[4] = big"0.592003167184542872566205223775131812219687808327572130718908784863813558599641375147402991238481535050773351649645179780815453429071529988233376036688329872"
e[5] = big"-0.199999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999997076"

TI = Matrix{T1}(undef, 5, 5)
TI[1, 1] = big"30.0415677215444016277146611632467970747634862837368422955138470463852339244593400023985957753164599415374157317627305099177616927640413043608408838747985125"
Expand Down Expand Up @@ -414,13 +414,13 @@ function BigRadauIIA13Tableau(T1, T2)
c[7] = big"1.0"

e = Vector{T1}(undef, 7)
e[1] = big"-54.37443689412861451458"
e[2] = big"7.000024004259186512041"
e[3] = big"-2.355661091987557192256"
e[4] = big"1.132289066106134386384"
e[5] = big"-0.6468913267673587118673"
e[6] = big"0.3875333853753523774248"
e[7] = big"-0.1428571428571428571429"
e[1] = big"-54.374436894128614514583710369683221528326818668136315170227649609831132483812209590903458627819914413600703287942266678601263304348350182019714004102122958"
e[2] = big"7.00002400425918651204068363735192307633403862621907697222411411256593188888314837387690262103761082115674234000933589934965063951414231971808906314491204573"
e[3] = big"-2.35566109198755719225604586775720723211163199654640573606711168106849118084357027539414093812951288166804790294091903523762277368547775099880612390898224076"
e[4] = big"1.13228906610613438638449290827978318662460499026070073842612187085281352278780837966549916347601259689966925986653914736463076068138934273474363230390185871"
e[5] = big"-0.646891326767358711867345222439989069591870662562921671446738173180691199552327090727940249497816198076028398716990245669520129053944261569921119452534594627"
e[6] = big"0.387533385375352377424782057105854424214534853623007724234120623518712309680007346340280888076477218145510846867158055651267664035097674992751409157682864641"
e[7] = big"-0.142857142857142857142857142857142857142857142857142857142857142857142857142857142857142857142857142857142857142857142857142857142857142857142857142857092806"

TI = Matrix{T1}(undef, 7, 7)
TI[1, 1] = big"258.131926319982229276108947425184471333411128774462923076434633414645220927977539758484670571338176678808837829326061674950321562391576244286310404028770676"
Expand Down Expand Up @@ -599,34 +599,54 @@ function adaptiveRadauTableau(T1, T2, num_stages::Int)

if (num_stages == 9)
e = Vector{BigFloat}(undef, 9)
e[1] = big"-0.133101731359431287515066981129913748644705107621439651956220186897253838380345034218235538734967567153163030284540660584311040323114847240173627907922903296"
e[2] = big"0.0754476228408557299650196603226967248368445025181771896522057250989188754588885465998346476425502117889420021664297319179240040109156780754680742172762707621"
e[3] = big"-0.0458369394236156144604575482137179697005739995740615341890112217655441769701945378217626766299683076189687755618065050383493055018324395934911567207485032988"
e[4] = big"0.0271430329153098694457979735602502142083095152399102869109830450899844979409229538982100527256348792152825816553434603418662939944133319974874915933773657075"
e[5] = big"-0.0156126300301219212217568535995825232086423550686814635293876744035364259647929167763641353639085929285192729729570945658304937255929114458885296622493040224"
e[6] = big"0.00890598154557403928205152521539967562877335780940124672915181111908317890891659158654221736499522823959933517986673010006749138291836676520080172845444352328"
e[7] = big"-0.00514824122639241252178399021479378841872099572255461304439292434131750195489022869965968028106854978547414579491205935930595041763060069987112580994637398395"
e[8] = big"0.00296533914055503317169967748114188676589522458557982039693426239853498956125735811263087631479968309978854200615027412311940897061471388689986239742919640848"
e[9] = big"-0.0010634368308888065260482548541946175520274736959410047497431569257848032902381738362547705844630238841535652230832162703806430112125115777122361837311714267"
e[1] = big"-89.8315397040376845865027298766511166861131537901479318008187013574099993398844876573472315778350373191126204142357525815115482293843777624541394691345885716"
e[2] = big"11.4742766094687721590222610299234578063148408248968597722844661019124491691448775794163842022854672278004372474682761156236829237591471118886342174262239472"
e[3] = big"-3.81419058476042873698615187248837320040477891376179026064712181641592908409919668221598902628694008903410444392769866137859041139561191341971835412426311966"
e[4] = big"1.81155300867853110911564243387531599775142729190474576183505286509346678884073482369609308584446518479366940471952219053256362416491879701351428578466580598"
e[5] = big"-1.03663781378817415276482837566889343026914084945266083480559060702535168750966084568642219911350874500410428043808038021858812311835772945467924877281164517"
e[6] = big"0.660865688193716483757690045578935452512421753840843511309717716369201467579470723336314286637650332622546110594223451602017981477424498704954672224534648119"
e[7] = big"-0.444189256280526730087023435911479370800996444567516110958885112499737452734669537494435549195615660656770091500773942469075264796140815048410568498349675229"
e[8] = big"0.290973163636905565556251162453264542120491238398561072912173321087011249774042707406397888774630179702057578431394918930648610404108923880955576205699885598"
e[9] = big"-0.111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111222795"
elseif (num_stages == 11)
e = Vector{BigFloat}(undef, 11)
e[1] = big"-134.152626015465044063378550835075318643291579891352838474367124350171545245813244797505763447327562609902792066283575334085390478517120485782603677022267543"
e[2] = big"17.0660253399060146849212356299749772423073416838121578997449942694355150369717420038613850964748566731121793290881077515821557030349184664685171028112845693"
e[3] = big"-5.63464089555106294823267450977601185069165875295372865523759287935369597689662768988715406731927279137711764532851201746616033935275093116699140897901326857"
e[4] = big"2.65398285960564394428637524662555134392389271086844331137910389226095922845489762567700560496915255196379049844894623384211693438658842276927416827629120392"
e[5] = big"-1.50753272514563441873424939425410006034401178578882643601844794171149654717227697249290904230103304153661631200445957060050895700394738491883951084826421405"
e[6] = big"0.960260572218344245935269463733859188992760928707230734981795807797858324380878500135029848170473080912207529262984056182004711806457345405466997261506487216"
e[7] = big"-0.658533932484491373507110339620843007350146695468297825313721271556868110859353953892288534787571420691760379406525738632649863532050280264983313133523641674"
e[8] = big"0.47189364490739958527881800092758816959227958959727295348380187162217987951960275929676019062173412149363239153353720640122975284789262792027244826613784432"
e[9] = big"-0.34181016557091711933253384050957887606039737751222218385118573305954222606860932803075900338195356026497059819558648780544900376040113065955083806288937526"
e[10] = big"0.233890408488838371854329668882967402012428680999899584289285425645726546573900943747784263972086087200538161975992991491742449181322441138528940521648041699"
e[11] = big"-0.0909090909090909090909090909090909090909090909090909090909090909090909090909090909090909090909090909090909090909090909090909090909090909090909090909093788951"
elseif (num_stages == 13)
e = Vector{BigFloat}(undef, 13)
e[1] = big"-187.337806666035250696387113105488477375830948862159770885826492736743460038872636916422100706359786154665214547894636085276885830138994748219148357620227002"
e[2] = big"23.775705048946302520021716862887025159493544949407763131913924588605891085865877529749667170060976683489861224477421212170329019074926368036881685518012728"
e[3] = big"-7.81823724708755833325842676798052630403951326380926053607036280237871312516353176794790424805918285990907426633641064901501063343970205708057561515795364672"
e[4] = big"3.66289388251066047904501665386587373682645522696191680651425553890800106379174431775463608296821504040006089759980653462003322200870566661322334735061646223"
e[5] = big"-2.06847094952801462392548700163367193433237251061765813625197254100990426184032443671875204952150187523419743001493620194301209589692419776688692360679336566"
e[6] = big"1.31105635982993157063104433803023633257356281733787535204132865785504258558244947718491624714070193102812968996631302993877989767202703509685785407541965509"
e[7] = big"-0.897988270828178667954874573865888835427640297795141000639881363403080887358272161865529150995401606679722232843051402663087372891040498351714982629218397165"
e[8] = big"0.648958340079591709325028357505725843500310779765000237611355105578356380892509437805732950287939403489669590070670546599339082534053791877148407548785389408"
e[9] = big"-0.485906120880156534303797908584178831869407602334908394589833216071089678420073112977712585616439120156658051446412515753614726507868506301824972455936531663"
e[10] = big"0.370151313405058266144090771980402238126294149688261261935258556082315591034906662511634673912342573394958760869036835172495369190026354174118335052418701339"
e[11] = big"-0.27934271062931554435643589252670994638477019847143394253283050767117135003630906657393675748475838251860910095199485920686192935009874559019443503474805827"
e[12] = big"0.195910097140006778096161342733266840441407888950433028972173797170889557600583114422425296743817444283872389581116632280572920821812614435192580036549169031"
e[13] = big"-0.0769230769230769230769230769230769230769230769230769230769230769230769230769230769230769230769230769230769230769230769230769230769230769230769230769254590189"
else
p = num_stages
eb = variables(:b, 1:num_stages + 1)
@variables y
zz = zeros(size(a, 1) + 1)
zz2 = zeros(size(a, 1))
eA = [zz'
zz2 a]
ec = [0; c]
constraints = map(Iterators.flatten(RootedTreeIterator(i) for i in 1:2*p-3)) do t
residual_order_condition(t, RungeKuttaMethod(eA, eb, ec))
e_sym = variables(:e, 1:num_stages)
constraints = map(Iterators.flatten(RootedTreeIterator(i) for i in 1:num_stages)) do t
residual_order_condition(t, RungeKuttaMethod(a, e_sym, c))
end
AA, bb, islinear = Symbolics.linear_expansion(Symbolics.substitute.(constraints, (eb[1]=>1/γ,)), eb[2:end])
AA = Float64.(map(unwrap, AA))
idxs = qr(AA', ColumnNorm()).p[1:num_stages]
@assert rank(AA[idxs, :]) == num_stages
@assert islinear
b_hat = Symbolics.expand.((AA \ -bb))
e = [Symbolics.symbolic_to_float(b_hat[i] - b[i]) for i in 1 : num_stages]
AA, bb, islinear = Symbolics.linear_expansion(constraints, e_sym[1:end])
AA = BigFloat.(map(unwrap, AA))
bb = BigFloat.(map(unwrap, bb))
A = vcat([zeros(num_stages -1); 1]', AA)
b_2 = vcat(-1/big(num_stages), -(num_stages)^2, -1, zeros(size(A, 1) - 3))
e = A \ b_2
end
RadauIIATableau{T1, T2}(T, TI, c, γ, α, β, e)
end

0 comments on commit 7afc20c

Please sign in to comment.