From 118fb0b9d8b12f9f010ffe89d6198eea4371b2d5 Mon Sep 17 00:00:00 2001 From: Patrick Kofod Mogensen Date: Sat, 19 Dec 2020 14:48:16 +0100 Subject: [PATCH] Add tools to extract x, F, and J from the trace. --- src/NLsolve.jl | 2 ++ src/trace_tools.jl | 16 ++++++++++++++++ test/complex.jl | 3 +++ 3 files changed, 21 insertions(+) create mode 100644 src/trace_tools.jl diff --git a/src/NLsolve.jl b/src/NLsolve.jl index 0a63056..fd3c614 100644 --- a/src/NLsolve.jl +++ b/src/NLsolve.jl @@ -52,4 +52,6 @@ include("nlsolve/nlsolve.jl") include("nlsolve/utils.jl") include("nlsolve/fixedpoint.jl") +include("trace_tools.jl") + end # module diff --git a/src/trace_tools.jl b/src/trace_tools.jl new file mode 100644 index 0000000..b80f5fc --- /dev/null +++ b/src/trace_tools.jl @@ -0,0 +1,16 @@ +trace(r::SolverResults) = r.trace +function x_trace(r::SolverResults) + tr = trace(r).states + !haskey(tr[1].metadata, "x") && error("Trace does not contain x. To get a trace of x, run nlsolve() with extended_trace = true") + [ state.metadata["x"] for state in tr ] +end +function F_trace(r::SolverResults) + tr = trace(r).states + !haskey(tr[1].metadata, "f(x)") && error("Trace does not contain F. To get a trace of the residuals, run nlsolve() with extended_trace = true") + [ state.metadata["f(x)"] for state in tr ] +end +function J_trace(r::SolverResults) + tr = trace(r).states + !haskey(tr[1].metadata, "g(x)") && error("Trace does not contain J. To get a trace of the Jacobian, run nlsolve() with extended_trace = true") + [ state.metadata["g(x)"] for state in tr ] +end \ No newline at end of file diff --git a/test/complex.jl b/test/complex.jl index a0224b2..1fa8839 100644 --- a/test/complex.jl +++ b/test/complex.jl @@ -32,5 +32,8 @@ for linesearches in [BackTracking(),StrongWolfe(),HagerZhang(),MoreThuente()] #S @test sol.g_calls == sol_real.g_calls @test all(sol_real.trace[i].stepnorm == sol_real.trace[i].stepnorm for i in 2:sol.iterations) @test all(norm(sol.trace[i].metadata["f(x)"]) ≈ norm(sol_real.trace[i].metadata["f(x)"]) for i in 1:5) + NLsolve.x_trace(sol_real) + NLsolve.F_trace(sol_real) + NLsolve.J_trace(sol_real) end end