Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
rveltz committed Oct 6, 2024
2 parents 2e64b72 + c773565 commit 6e0bc59
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 103 deletions.
10 changes: 5 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Expand All @@ -29,14 +28,15 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"

[weakdeps]
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"

[extensions]
GLMakieExt = "GLMakie"
MakieExt = "Makie"
JLD2Ext = "JLD2"
PlotsExt = "Plots"
PlotsExt = ["Plots", "RecipesBase"]

[compat]
Accessors = "0.1"
Expand All @@ -47,7 +47,7 @@ DataStructures = "0.17, 0.18"
DocStringExtensions = "^0.8, ^0.9"
FastGaussQuadrature = "^0.4, ^0.5, 1"
ForwardDiff = "^0.10"
GLMakie = "0.10"
Makie = "^0.21"
IterativeSolvers = "0.8.4, 0.8.5, ^0.9"
JLD2 = "0.4, 0.5"
KrylovKit = "^0.7, ^0.8"
Expand Down
2 changes: 1 addition & 1 deletion examples/SH3d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ const BK = BifurcationKit

Makie.inline!(true)

contour3dMakie(x; k...) = GLMakie.contour(x; k...)
contour3dMakie(x; k...) = Makie.contour(x; k...)
contour3dMakie(x::AbstractVector; k...) = contour3dMakie(reshape(x,Nx,Ny,Nz); k...)
contour3dMakie(ax, x; k...) = (contour(ax, x; k...))
contour3dMakie(ax, x::AbstractVector; k...) = contour3dMakie(ax, reshape(x,Nx,Ny,Nz); k...)
Expand Down
9 changes: 5 additions & 4 deletions ext/GLMakieExt/GLMakieExt.jl → ext/MakieExt/MakieExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module GLMakieExt
using GLMakie, BifurcationKit
module MakieExt
using Makie, BifurcationKit
import BifurcationKit: _plot_backend,
plot,
plot,
plot!,
hasbranch,
plot_branch_cont,
Expand All @@ -22,13 +22,14 @@ module GLMakieExt
get_color,
colorbif,
get_plot_backend,
set_plot_backend!,
BK_Makie,
plotAllDCBranch,
plot_DCont_branch
include("plot.jl")

function __init__()
_plot_backend[] = BK_Makie()
set_plot_backend!(BK_Makie())
return nothing
end
end
174 changes: 84 additions & 90 deletions ext/GLMakieExt/plot.jl → ext/MakieExt/plot.jl
Original file line number Diff line number Diff line change
@@ -1,141 +1,135 @@
using GLMakie: Point2f0
using Makie: Point2f0

function GLMakie.convert_arguments(::PointBased, contres::AbstractBranchResult, vars = nothing, applytoY = identity, applytoX = identity)
function Makie.convert_arguments(::PointBased, contres::AbstractBranchResult, vars = nothing, applytoY = identity, applytoX = identity)
ind1, ind2 = get_plot_vars(contres, vars)
return ([Point2f0(i, j) for (i, j) in zip(map(applytoX, getproperty(contres.branch, ind1)), map(applytoY, getproperty(contres.branch, ind2)))],)
end

function plot!(ax1, contres::AbstractBranchResult;
plotfold = false,
plotstability = true,
plotspecialpoints = true,
putspecialptlegend = true,
filterspecialpoints = false,
vars = nothing,
linewidthunstable = 1.0,
linewidthstable = 3.0linewidthunstable,
plotcirclesbif = true,
branchlabel = "",
applytoY = identity,
applytoX = identity)

function isplit(x::AbstractVector{T}, indices::AbstractVector{<:Integer}, splitval::Bool = true) where {T<:Real}
# Adapt behavior for CairoMakie only
if !isempty(indices) && isdefined(Main, :CairoMakie) && Makie.current_backend() == Main.CairoMakie
xx = similar(x, length(x) + 2 * (length(indices)))
for (i, ind) in enumerate(indices)
if ind == first(indices)
xx[1:ind] .= @views x[1:ind]
else
xx[(2*(i-1)).+(indices[i-1]+1:ind)] .= @views x[(indices[i-1]+1:ind)]
end
if !splitval
xx[2*(i-1)+ind] = x[ind-1]
end
# Add a NaN is necessary, otherwise continue with same value as before (useful for linewidth)
xx[2*(i-1)+ind+1] = splitval ? NaN : x[ind-1]
# Repeat last value before NaN, but adapt for linewidth
xx[2*(i-1)+ind+2] = splitval ? x[ind] : x[ind+1]
end
# Fill the rest of the extended array
xx[last(indices)+2*length(indices)+1:end] .= @views x[last(indices)+1:end]
return xx
else
return x
end
end

function plot!(ax1, contres::AbstractBranchResult; plotfold = false, plotstability = true, plotspecialpoints = true, putspecialptlegend = true, filterspecialpoints = false, vars = nothing, linewidthunstable = 1.0, linewidthstable = 3.0linewidthunstable, plotcirclesbif = true, branchlabel = nothing, applytoY = identity, applytoX = identity)

# names for axis labels
ind1, ind2 = get_plot_vars(contres, vars)
xlab, ylab = get_axis_labels(ind1, ind2, contres)

# stability linewidth
linewidth = linewidthunstable
indices = [sp.idx for sp in contres.specialpoint if sp.type !== :endpoint]
# isplit required to work with CairoMakie due to change of linewidth for stability
if _hasstability(contres) && plotstability
linewidth = map(x -> isodd(x) ? linewidthstable : linewidthunstable, contres.stable)
end
if branchlabel == ""
lines!(ax1, map(applytoX, getproperty(contres.branch, ind1)), map(applytoY, getproperty(contres.branch, ind2)); linewidth)
else
lines!(ax1, map(applytoX, getproperty(contres.branch, ind1)), map(applytoY, getproperty(contres.branch, ind2)), linewidth = linewidth, label = branchlabel)
linewidth = isplit(map(x -> x ? linewidthstable : linewidthunstable, contres.stable), indices, false)
end
xbranch = isplit(map(applytoX, getproperty(contres.branch, ind1)), indices)
ybranch = isplit(map(applytoY, getproperty(contres.branch, ind2)), indices)
lines!(ax1, xbranch, ybranch, linewidth = linewidth, label = branchlabel)
ax1.xlabel = xlab
ax1.ylabel = ylab

# display bifurcation points
bifpt = filter(x -> (x.type != :none) && (x.type != :endpoint) && (plotfold || x.type != :fold) && (x.idx <= length(contres)-1), contres.specialpoint)
bifpt = filter(x -> (x.type != :none) && (x.type != :endpoint) && (plotfold || x.type != :fold) && (x.idx <= length(contres) - 1), contres.specialpoint)
if length(bifpt) >= 1 && plotspecialpoints #&& (ind1 == :param)
if filterspecialpoints == true
bifpt = filterBifurcations(bifpt)
end
scatter!(ax1,
[applytoX(getproperty(contres[pt.idx], ind1)) for pt in bifpt],
[applytoY(getproperty(contres[pt.idx], ind2)) for pt in bifpt];
marker = map(x -> (x.status == :guess) && (plotcirclesbif==false) ? :rect : :circle, bifpt),
markersize = 10,
color = map(x -> get_color(x.type), bifpt),
)
end

scatter!(ax1, [applytoX(getproperty(contres[pt.idx], ind1)) for pt in bifpt], [applytoY(getproperty(contres[pt.idx], ind2)) for pt in bifpt]; marker = map(x -> (x.status == :guess) && (plotcirclesbif == false) ? :rect : :circle, bifpt), markersize = 10, color = map(x -> get_color(x.type), bifpt))
end

# add legend for bifurcation points
if putspecialptlegend && length(bifpt) >= 1
bps = unique(x -> x.type, [pt for pt in bifpt if (pt.type != :none && (plotfold || pt.type != :fold))])
(length(bps) == 0) && return
for pt in bps
scatter!(ax1,
[applytoX(getproperty(contres[pt.idx], ind1))],
[applytoY(getproperty(contres[pt.idx], ind2))];
color = get_color(pt.type),
markersize = 10,
label = "$(pt.type)")
scatter!(ax1, [applytoX(getproperty(contres[pt.idx], ind1))], [applytoY(getproperty(contres[pt.idx], ind2))]; color = get_color(pt.type), markersize = 10, label = "$(pt.type)")
end
GLMakie.axislegend(ax1, merge = true, unique = true)
Makie.axislegend(ax1, merge = true, unique = true)
end
ax1
end

function plot_branch_cont(contres::ContResult,
state,
iter,
plotuserfunction;
plotfold = false,
plotstability = true,
plotspecialpoints = true,
putspecialptlegend = true,
filterspecialpoints = false,
linewidthunstable = 1.0,
linewidthstable = 3.0linewidthunstable,
plotcirclesbif = true,
applytoY = identity,
applytoX = identity)
function plot_branch_cont(contres::ContResult, state, iter, plotuserfunction; plotfold = false, plotstability = true, plotspecialpoints = true, putspecialptlegend = true, filterspecialpoints = false, linewidthunstable = 1.0, linewidthstable = 3.0linewidthunstable, plotcirclesbif = true, applytoY = identity, applytoX = identity)
sol = getsolution(state)
if length(contres) == 0; return ; end

if length(contres) == 0
return
end

# names for axis labels
ind1, ind2 = get_plot_vars(contres, nothing)
xlab, ylab = get_axis_labels(ind1, ind2, contres)

# stability linewidth
linewidth = linewidthunstable
if _hasstability(contres) && plotstability
linewidth = map(x -> isodd(x) ? linewidthstable : linewidthunstable, contres.stable)
end

fig = Figure(size = (1200, 700))
ax1 = fig[1:2, 1] = Axis(fig, xlabel = String(xlab), ylabel = String(ylab), tellheight = true)

ax2 = fig[1, 2] = Axis(fig, xlabel = "step [$(state.step)]", ylabel = String(xlab))
lines!(ax2, contres.step, contres.param, linewidth = linewidth)

if compute_eigenelements(iter)
eigvals = contres.eig[end].eigenvals
ax_ev = fig[3, 1:2] = Axis(fig, xlabel = "", ylabel = "")
scatter!(ax_ev, real.(eigvals), imag.(eigvals), strokewidth = 0, markersize = 10, color = :black)
# add stability boundary
maxIm = maximum(imag, eigvals)
minIm = minimum(imag, eigvals)
if maxIm-minIm < 1e-6
if maxIm - minIm < 1e-6
maxIm, minIm = 1, -1
end
lines!(ax_ev, [0, 0], [maxIm, minIm], color = :blue, linewidth = linewidthunstable)
end

# plot arrow to indicate the order of computation
if length(contres) > 1
x = contres.branch[end].param
y = getproperty(contres.branch,1)[end]
y = getproperty(contres.branch, 1)[end]
u = contres.branch[end].param - contres.branch[end-1].param
v = getproperty(contres.branch,1)[end] - getproperty(contres.branch,1)[end-1]
GLMakie.arrows!(ax1, [x], [y], [u], [v], color = :green, arrowsize = 20,)
v = getproperty(contres.branch, 1)[end] - getproperty(contres.branch, 1)[end-1]
Makie.arrows!(ax1, [x], [y], [u], [v], color = :green, arrowsize = 20)
end

plot!(ax1, contres; plotfold, plotstability, plotspecialpoints, putspecialptlegend, filterspecialpoints, linewidthunstable, linewidthstable, plotcirclesbif, applytoY, applytoX)

if isnothing(plotuserfunction) == false
ax_perso = fig[2, 2] = Axis(fig, tellheight = true)
plotuserfunction(ax_perso, sol.u, sol.p; ax1 = ax1)
end

display(fig)
fig
end

function plot(contres::AbstractBranchResult; kP...)
if length(contres) == 0; return ;end
if length(contres) == 0
return
end

ind1, ind2 = get_plot_vars(contres, nothing)
xlab, ylab = get_axis_labels(ind1, ind2, contres)
Expand All @@ -150,17 +144,17 @@ end

plot(brdc::DCResult; kP...) = plot(brdc.branches...; kP...)

function plot(brs::AbstractBranchResult...;
branchlabel = ["$i" for i=1:length(brs)],
kP...)
if length(brs) == 0; return ;end
function plot(brs::AbstractBranchResult...; branchlabel = ["$i" for i = 1:length(brs)], kP...)
if length(brs) == 0
return
end
fig = Figure()
ax1 = fig[1, 1] = Axis(fig)

for (id, contres) in pairs(brs)
plot!(ax1, contres; branchlabel = branchlabel[id], kP...)
end
GLMakie.axislegend(ax1, merge = true, unique = true)
Makie.axislegend(ax1, merge = true, unique = true)
display(fig)
fig, ax1
end
Expand All @@ -186,14 +180,14 @@ function plot_periodic_potrap(outpof, n, M; ratio = 2)
@assert ratio > 0 "You need at least one component"
outpo = reshape(outpof[1:end-1], ratio * n, M)
if ratio == 1
heatmap(outpo[1:n,:]', ylabel="Time", color=:viridis)
heatmap(outpo[1:n, :]', ylabel = "Time", color = :viridis)
else
fig = GLMakie.Figure()
ax1 = Axis(fig[1,1], ylabel="Time")
ax2 = Axis(fig[1,2], ylabel="Time")
# GLMakie.heatmap!(ax1, rand(2,2))
GLMakie.heatmap!(ax1, outpo[1:n,:]')
GLMakie.heatmap!(ax2, outpo[n+2:end,:]')
fig = Makie.Figure()
ax1 = Axis(fig[1, 1], ylabel = "Time")
ax2 = Axis(fig[1, 2], ylabel = "Time")
# Makie.heatmap!(ax1, rand(2,2))
Makie.heatmap!(ax1, outpo[1:n, :]')
Makie.heatmap!(ax2, outpo[n+2:end, :]')
fig
end
end
Expand All @@ -211,7 +205,9 @@ end
####################################################################################################
# plot recipes for the bifurcation diagram
function plot(bd::BifDiagNode; code = (), level = (-Inf, Inf), k...)
if ~hasbranch(bd); return; end
if ~hasbranch(bd)
return
end

fig = Figure()
ax = fig[1, 1] = Axis(fig)
Expand All @@ -223,7 +219,9 @@ function plot(bd::BifDiagNode; code = (), level = (-Inf, Inf), k...)
end

function _plot_bifdiag_makie!(ax, bd::BifDiagNode; code = (), level = (-Inf, Inf), k...)
if ~hasbranch(bd); return; end
if ~hasbranch(bd)
return
end

_bd = get_branch(bd, code)
_plot_bifdiag_makie!(ax, _bd.child; code = (), level = level, k...)
Expand All @@ -236,16 +234,12 @@ end

function _plot_bifdiag_makie!(ax, bd::Vector{BifDiagNode}; code = (), level = (-Inf, Inf), k...)
for b in bd
_plot_bifdiag_makie!(ax, b; code, level, k... )
_plot_bifdiag_makie!(ax, b; code, level, k...)
end
end
####################################################################################################
plotAllDCBranch(branches) = plot(branches...)

function plot_DCont_branch(::BK_Makie,
branches,
nbrs::Int,
nactive::Int,
nstep::Int)
function plot_DCont_branch(::BK_Makie, branches, nbrs::Int, nactive::Int, nstep::Int)
plot(branches...)
end
6 changes: 4 additions & 2 deletions ext/PlotsExt/PlotsExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module PlotsExt
using Plots, BifurcationKit
import BifurcationKit: _plot_backend,
plot_branch_cont,
plot_branch_cont,
plot_periodic_potrap,
plot_periodic_shooting!,
plot_periodic_shooting,
Expand All @@ -18,6 +18,8 @@ module PlotsExt
filter_bifurcations,
get_color,
AbstractResult,
get_plot_backend,
set_plot_backend!,
BK_NoPlot, BK_Plots,
plotAllDCBranch,
plot_DCont_branch,
Expand All @@ -28,7 +30,7 @@ module PlotsExt
include("plot.jl")

function __init__()
_plot_backend[] = BK_Plots()
set_plot_backend!(BK_Plots())
return nothing
end
end
2 changes: 1 addition & 1 deletion src/BifurcationKit.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
module BifurcationKit
using Printf, Dates, LinearMaps, BlockArrays, RecipesBase, StructArrays
using Printf, Dates, LinearMaps, BlockArrays, StructArrays
using Reexport
@reexport using Accessors: setproperties, @set, @reset, PropertyLens, getall, set, @optic, IndexLens, ComposedOptic
using Parameters: @with_kw, @unpack, @with_kw_noshow
Expand Down
Loading

0 comments on commit 6e0bc59

Please sign in to comment.