Skip to content

Commit

Permalink
Merge pull request #81 from TidierOrg/legend
Browse files Browse the repository at this point in the history
fake legend works
  • Loading branch information
rdboyes authored Apr 17, 2024
2 parents 9ccbb22 + 622ac22 commit c76ab6d
Show file tree
Hide file tree
Showing 14 changed files with 200 additions and 55 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,7 @@ docs/src/examples/generated/
Manifest.toml

# IDE Specific files
.vscode
.vscode

# temp files
/scratch/
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ Makie Themes:

Colour Scales:

- `scale_color_manual()` - arguments should be given directly in order, accepts anything that can be parsed as a color by Colors.jl (named colors, hex values, etc.)
- `scale_color_manual()` - set `values = c(c1, c2, c3, ...)`, accepts anything that can be parsed as a color by Colors.jl (named colors, hex values, etc.)
- `scale_color_[discrete|continuous|binned]()` - set `palette =` a [ColorSchemes.jl palette](https://juliagraphics.github.io/ColorSchemes.jl/stable/catalogue/) as a string or symbol. Also accepts ColorSchemes.jl color scheme objects.

Additional Elements:
Expand Down
6 changes: 4 additions & 2 deletions src/TidierPlots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,22 @@ include("structs.jl")

include("attributes.jl")
include("addplots.jl")
include("aes.jl")
include("aes_ops.jl")
include("aes.jl")
include("attributes.jl")
include("draw.jl")
include("extract_aes.jl")
#include("facets.jl")
include("geom.jl")
include("ggplot.jl")
include("ggsave.jl")
include("labs.jl")
include("legend.jl")
include("patchwork.jl")
include("scales_colour.jl")
include("scales_numeric.jl")
include("themes.jl")
include("transforms.jl")
include("patchwork.jl")
include("show.jl")
include("util.jl")

Expand Down
4 changes: 3 additions & 1 deletion src/addplots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ function Base.:+(x::GGPlot, y...)::GGPlot
[i.opt for i in y if i isa AxisOptions]...),
theme,
merge(x.column_transformations,
[i.column_transformations for i in y if i isa AxisOptions]...)
[i.column_transformations for i in y if i isa AxisOptions]...),
merge(x.legend_options,
[i.legend_options for i in y if i isa AxisOptions]...)
)

return result
Expand Down
45 changes: 45 additions & 0 deletions src/attributes.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,47 @@
const _legend_geom_elements = Dict{String, DataType}(
"geom_bar" => MarkerElement,
"geom_col" => MarkerElement,
"geom_histogram" => MarkerElement,
"geom_point" => MarkerElement,
"geom_path" => LineElement,
"geom_line" => LineElement,
"geom_step" => LineElement,
"geom_smooth" => LineElement,
"geom_errorbar" => LineElement,
"geom_errorbarh" => LineElement,
"geom_violin" => MarkerElement,
"geom_boxplot" => MarkerElement,
"geom_contour" => LineElement,
"geom_tile" => MarkerElement,
"geom_text" => MarkerElement,
"geom_label" => MarkerElement,
"geom_density" => MarkerElement,
"geom_hline" => LineElement,
"geom_vline" => LineElement
);

const _legend_geom_symbols = Dict{String, Dict}(
"geom_bar" => Dict(:marker => :rect, :markersize => 12),
"geom_col" => Dict(:marker => :rect, :markersize => 12),
"geom_histogram" => Dict(:marker => :rect, :markersize => 12),
"geom_point" => Dict(:marker => :circle, :markersize => 12),
"geom_path" => Dict(:linestyle => nothing),
"geom_line" => Dict(:linestyle => nothing),
"geom_step" => Dict(:linestyle => nothing),
"geom_smooth" => Dict(:linestyle => nothing),
"geom_errorbar" => Dict(:linestyle => nothing),
"geom_errorbarh" => Dict(:linestyle => nothing),
"geom_violin" => Dict(:marker => :rect, :markersize => 12),
"geom_boxplot" => Dict(:marker => :rect, :markersize => 12),
"geom_contour" => Dict(:linestyle => nothing),
"geom_tile" => Dict(:marker => :rect, :markersize => 12),
"geom_text" => Dict(:marker => :x, :markersize => 12),
"geom_label" => Dict(:marker => :x, :markersize => 12),
"geom_density" => Dict(:marker => :rect, :markersize => 12),
"geom_hline" => Dict(:linestyle => nothing),
"geom_vline" => Dict(:linestyle => nothing)
);

const _ggplot_to_makie = Dict{String, String}(
"colour" => "color",
"shape" => "marker",
Expand Down Expand Up @@ -99,3 +143,4 @@ const _makie_expected_type = Dict{String, Type}(
"ymin" => Real,
"ymax" => Real,
);

17 changes: 13 additions & 4 deletions src/draw.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,21 @@ end

function draw_ggplot(plot::GGPlot)
axis = Makie.SpecApi.Axis(plot)
legend = build_legend(plot)

Makie.plot(
Makie.SpecApi.GridLayout(
axis
if isnothing(legend)
Makie.plot(
Makie.SpecApi.GridLayout(
axis
)
)
)
else
Makie.plot(
Makie.SpecApi.GridLayout(
[axis legend]
)
)
end
end

function draw_ggplot(plot_grid::GGPlotGrid)
Expand Down
2 changes: 2 additions & 0 deletions src/ggplot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ function ggplot(args...; kwargs...)
plot_data,
args_dict,
theme_ggplot2(),
Dict(),
Dict())
end

Expand All @@ -30,5 +31,6 @@ function ggplot(data::DataFrame, args...; kwargs...)
data,
args_dict,
theme_ggplot2(),
Dict(),
Dict())
end
4 changes: 2 additions & 2 deletions src/labs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ function labs(args...; kwargs...)
"subtitle" => :subtitle
)

return AxisOptions(Dict(ggplot_to_makie[k] => v for (k, v) in args_dict), Dict())
return AxisOptions(Dict(ggplot_to_makie[k] => v for (k, v) in args_dict), Dict(), Dict())
end

function labs(plot::GGPlot, args...; kwargs...)
Expand Down Expand Up @@ -46,7 +46,7 @@ function lims(args...; kwargs...)
end
end

return AxisOptions(lims_dict, Dict())
return AxisOptions(lims_dict, Dict(), Dict())
end

function lims(plot::GGPlot, args...; kwargs...)
Expand Down
97 changes: 97 additions & 0 deletions src/legend.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
function build_legend(plot::GGPlot)
# what aes will automatically have a legend built for them?
auto_legend = [:color]

palette_function = nothing

if haskey(plot.column_transformations, :color)
palette_function = plot.column_transformations[:color][2]
elseif !(any([haskey(geom.aes, "color") || haskey(geom.aes, "color") for geom in plot.geoms]) ||
haskey(plot.default_aes, "color") || haskey(plot.default_aes, "color"))

return nothing
end

legend = DataFrame(labels = String[], colors = Any[], options = Any[], element = Any[])
title = nothing

colorbar_kwargs = Dict()
lowlim = Inf
highlim = -Inf
colorbar = false

for geom in plot.geoms

all_aes = merge(plot.default_aes, geom.aes)

color_colname = haskey(all_aes, "colour") ? all_aes["colour"] :
haskey(all_aes, "color") ? all_aes["color"] :
nothing

plot_data = isnothing(geom.data) ? plot.data : geom.data

if isnothing(palette_function)
if eltype(plot_data[!, color_colname]) <: Union{AbstractString, AbstractChar, CategoricalValue}
plot = plot + scale_colour_manual(values = c(
RGB(0/255, 114/255, 178/255), # blue
RGB(230/255, 159/255, 0/255), # orange
RGB(0/255, 158/255, 115/255), # green
RGB(204/255, 121/255, 167/255), # reddish purple
RGB(86/255, 180/255, 233/255), # sky blue
RGB(213/255, 94/255, 0/255), # vermillion
RGB(240/255, 228/255, 66/255))) # yellow)
else
plot = plot + scale_colour_continuous(palette = :viridis)
end
palette_function = plot.column_transformations[:color][2]
end

if !isnothing(color_colname) && plot.legend_options[:color][:type] in ["manual", "discrete"]

plottable_data = palette_function(:color, [color_colname], plot_data)
labels = unique(plottable_data[:color].raw)

append!(legend, sort(DataFrame(labels = labels,
colors = unique(plottable_data[:color].makie_function(plottable_data[:color].raw)),
options = _legend_geom_symbols[geom.args["geom_name"]],
element = _legend_geom_elements[geom.args["geom_name"]]),
:labels))

title = get(plot.legend_options[:color], :name, titlecase(string(color_colname)))
end

if !isnothing(color_colname) && plot.legend_options[:color][:type] in ["continuous", "binned"]

plottable_data = palette_function(:color, [color_colname], plot_data)

colorbar_kwargs[:colormap] = plot.legend_options[:color][:type] == "continuous" ? Symbol(plot.legend_options[:color][:palette]) :
cgrad(Symbol(plot.legend_options[:color][:palette]), 5, categorical = true)

lowlim = min(minimum(plottable_data[:color].raw), lowlim)
highlim = max(maximum(plottable_data[:color].raw), highlim)

colorbar = true
title = get(plot.legend_options[:color], :name, titlecase(string(color_colname)))
end
end

#return legend

if nrow(legend) != 0
labels = String[]
elems = Any[]

for (k, v) in pairs(groupby(legend, :labels))
push!(elems, [l.element(color = l.colors; l.options...) for l in eachrow(v)])
push!(labels, string(v.labels[1]))
end

return Makie.SpecApi.Legend(elems, labels, title)
end

if (colorbar)
return Makie.SpecApi.Colorbar(;colorbar_kwargs..., limits = (lowlim, highlim), label = title)
end

return nothing
end
32 changes: 10 additions & 22 deletions src/scales_colour.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ function make_color_lookup_discrete(args_dict)
end

scheme = palette isa Symbol ? ColorSchemes.colorschemes[palette] :
palette isa String ? ColorSchemes.colorschemes[Symbol(palette)] :
palette isa ColorScheme ? palette : nothing
palette isa String ? ColorSchemes.colorschemes[Symbol(palette)] :
palette isa ColorScheme ? palette : nothing

if isnothing(scheme)
palette_type = typeof(palette)
Expand Down Expand Up @@ -90,7 +90,7 @@ function color_scale_to_ggoptions(args_dict::Dict)
function color_transform_fn(target::Symbol, source::Vector{Symbol}, data::DataFrame)
input = data[!, source[1]]

if typeof(input) <: Union{AbstractVector{String}, AbstractVector{Char}, CategoricalArray}
if eltype(input) <: Union{AbstractString, AbstractChar, CategoricalValue}

cat_array = CategoricalArray(input)

Expand All @@ -102,7 +102,7 @@ function color_scale_to_ggoptions(args_dict::Dict)
nothing
)
)
elseif typeof(input) <: Union{Vector{Int}, Vector{Float64}, Vector{Float32}}
elseif eltype(input) <: Union{Integer, AbstractFloat}
return Dict{Symbol, PlottableData}(
target => PlottableData(
input,
Expand All @@ -111,23 +111,10 @@ function color_scale_to_ggoptions(args_dict::Dict)
nothing
)
)
else # try to parse whatever it is as an int, error if not successful
try
int_array = parse.(Int, input)
catch
scale = args_dict[:scale]
@error "Column is not compatible with scale: $scale"
end

return Dict{Symbol, PlottableData}(
target => PlottableData(
int_array,
x -> lookup(x),
nothing,
nothing
)
)
end
else
scale = args_dict[:scale]
throw(@error "Column is not compatible with scale: $scale")
end
end
return color_transform_fn
end
Expand All @@ -136,7 +123,8 @@ function color_scale_to_ggoptions(args_dict::Dict)

return AxisOptions(
Dict(),
Dict(Symbol(args_dict[:scale]) => [Symbol(args_dict[:scale])]=>color_transform)
Dict(Symbol(args_dict[:scale]) => [Symbol(args_dict[:scale])]=>color_transform),
Dict(:color => args_dict) # pass the full args dict for use by legend
)
end

Expand Down
2 changes: 1 addition & 1 deletion src/scales_numeric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function continuous_scale_to_ggoptions(args_dict::Dict)
end

return AxisOptions(
Dict(Symbol(k) => v for (k, v) in options_dict), Dict()
Dict(Symbol(k) => v for (k, v) in options_dict), Dict(), Dict()
)

end
Expand Down
2 changes: 2 additions & 0 deletions src/structs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ struct GGPlot
axis_options::Dict
theme::Attributes
column_transformations::Dict
legend_options::Dict
end

struct Aesthetics
Expand All @@ -28,6 +29,7 @@ end
struct AxisOptions
opt::Dict{Symbol, Any}
column_transformations::Dict
legend_options::Dict
end

struct GGPlotGrid
Expand Down
11 changes: 5 additions & 6 deletions test/test_geoms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,12 @@
RGB(240/255, 228/255, 66/255), # yellow
]

t = ggplot(penguins) +
geom_violin(aes(x = :species, y = :bill_length_mm, color = :species))

m = Makie.plot(
Makie.SpecApi.GridLayout(
Makie.SpecApi.Axis(
[Makie.SpecApi.Axis(
plots = [
Makie.PlotSpec(
:Violin,
Expand All @@ -157,13 +160,9 @@
color = (x -> colours[x]).(levelcode.(cat_array))
)]; xticks = (unique(levelcode.(cat_array)),
unique(cat_array))
)
)
) TidierPlots.build_legend(t)])
)

t = ggplot(penguins) +
geom_violin(aes(x = :species, y = :bill_length_mm, color = :species))

@test plot_images_equal(t, m)

end
Expand Down
Loading

0 comments on commit c76ab6d

Please sign in to comment.