Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement 'mmstd', 'plot_fit'; add tests #9

Merged
merged 2 commits into from
Jul 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions examples/GPR/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,21 @@ GPR.learn!(gprw) # fit GPR with "Const * RBF + White" kernel
#GPR.learn!(gprw, kernel = "matern") # "rbf" and "matern" are supported for now
#GPR.learn!(gprw, kernel = "matern", nu = 1) # Matern's parameter nu; 1.5 by def

mesh = minimum(gpr_data, dims=1)[1] : 0.01 : maximum(gpr_data, dims=1)[1]
mesh, mean, std = GPR.mmstd(gprw) # equispaced mesh; mean and std at mesh points
#mesh, mean, std = GPR.mmstd(gprw, mesh_n = 11) # by default, `mesh_n` is 1001

mean, std = GPR.predict(gprw, mesh, return_std = true)
#mean, std = GPR.predict(gprw, mesh, return_std = true) # your own mesh
#mean = GPR.predict(gprw, mesh) # `return_std` is false by default

################################################################################
# plot section #################################################################
################################################################################
plt.plot(gpr_data[:,1], gpr_data[:,2], "r.", ms = 6, label = "Data points")
plt.plot(mesh, mean, "k", lw = 2.5, label = "GPR mean")
plt.fill_between(mesh, mean - 2*std, mean + 2*std, alpha = 0.4, zorder = 10,
color = "k", label = "95% interval")
#GPR.plot_fit(gprw, plt, plot_95 = true) # by default, `plot_95` is false

# no legend by default, but you can specify yours in the following order:
# data points, subsample, mean, 95% interval
GPR.plot_fit(gprw, plt, plot_95 = true,
label = ["Points", "Training", "Mean", "95% interval"])

plt.legend()
plt.show()
Expand Down
140 changes: 140 additions & 0 deletions src/GPR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ Functions that operate on GPR.Wrap struct:
- subsample! (3 methods)
- learn! (1 method)
- predict (1 method)
- mmstd (1 method)
- plot_fit (1 method)

Do *not* set Wrap's variables except for `thrsh`; use setter functions!
"""
Expand Down Expand Up @@ -184,6 +186,144 @@ function predict(gprw::Wrap, x; return_std = false)
end
end

"""
Return mesh, mean and st. deviation over the whole data range

Computes min and max of `gprw.data` x-range and returns equispaced mesh with
`mesh_n` number of points, mean and st. deviation computed over that mesh

Parameters:
- gprw: an instance of GPR.Wrap
- mesh_n: number of mesh points (1001 by default)

Returns:
- (m, m, std): mesh, mean and st. deviation
"""
function mmstd(gprw::Wrap; mesh_n = 1001)
mesh = range(minimum(gprw.data, dims=1)[1],
maximum(gprw.data, dims=1)[1],
length = mesh_n)
return (mesh, predict(gprw, mesh, return_std = true)...)
end

"""
Plot mean (and 95% interval) along with data and subsample

The flag `plot_95` controls whether to plot 95% interval; `label` may provide
labels in the following order:
data points, subsample, GP mean, 95% interval (if requested)

If you're using Plots.jl and running it from a file rather than REPL, you need
to wrap the call:
`display(GPR.plot_fit(gprw, Plots))`

Parameters:
- gprw: an instance of GPR.Wrap
- plt: a module used for plotting (only PyPlot & Plots supported)
- plot_95: boolean flag, whether to plot 95% confidence interval
- label: a 3- or 4-tuple or vector of strings (no label by default)
"""
function plot_fit(gprw::Wrap, plt; plot_95 = false, label = nothing)
if !gprw.__data_set
println(warn("plot_fit"), "data is not set, nothing to plot")
return
end
is_pyplot = (Symbol(plt) == :PyPlot)
is_plots = (Symbol(plt) == :Plots)
if !is_pyplot && !is_plots
println(warn("plot_fit"), "only PyPlot & Plots are supported; not plotting")
return
end

# set `cols` Dict with colors of the plots
alpha_95 = 0.3 # alpha channel for shaded region, i.e. 95% interval
cols = Dict{String, Any}()
cols["mean"] = "black"
if is_pyplot
cols["data"] = "tab:gray"
cols["sub"] = "tab:red"
cols["shade"] = (0, 0, 0, alpha_95)
elseif is_plots
cols["data"] = "#7f7f7f" # tab10 gray
cols["sub"] = "#d62728" # tab10 red
end

# set keyword argument dictionaries for plotting functions
kwargs_data = Dict{Symbol, Any}()
kwargs_sub = Dict{Symbol, Any}()
kwargs_mean = Dict{Symbol, Any}()
kwargs_95 = Dict{Symbol, Any}()
kwargs_aux = Dict{Symbol, Any}()

kwargs_data[:color] = cols["data"]
kwargs_sub[:color] = cols["sub"]
kwargs_mean[:color] = cols["mean"]
kwargs_mean[:lw] = 2.5
if is_pyplot
kwargs_data[:ms] = 4
kwargs_sub[:ms] = 4
kwargs_95[:facecolor] = cols["shade"]
kwargs_95[:edgecolor] = cols["mean"]
kwargs_95[:lw] = 0.5
kwargs_95[:zorder] = 10
elseif is_plots
kwargs_data[:ms] = 2
kwargs_sub[:ms] = 2
kwargs_95[:color] = cols["mean"]
kwargs_95[:fillalpha] = alpha_95
kwargs_95[:lw] = 2.5
kwargs_95[:z] = 10
kwargs_aux[:color] = cols["mean"]
kwargs_aux[:lw] = 0.5
kwargs_aux[:label] = ""
end

if label != nothing
kwargs_data[:label] = label[1]
kwargs_sub[:label] = label[2]
kwargs_mean[:label] = label[3]
if is_pyplot
kwargs_95[:label] = label[4]
elseif is_plots
kwargs_95[:label] = label[3]
end
elseif is_plots
kwargs_data[:label] = ""
kwargs_sub[:label] = ""
kwargs_mean[:label] = ""
kwargs_95[:label] = ""
end


mesh, mean, std = mmstd(gprw)

# plot data, subsample and mean
if is_pyplot
plt.plot(gprw.data[:,1], gprw.data[:,2], "."; kwargs_data...)
plt.plot(gprw.subsample[:,1], gprw.subsample[:,2], "."; kwargs_sub...)
if plot_95
plt.fill_between(mesh,
mean - 1.96 * std,
mean + 1.96 * std;
kwargs_95...)
end
plt.plot(mesh, mean; kwargs_mean...)
elseif is_plots
plt.scatter!(gprw.data[:,1], gprw.data[:,2]; kwargs_data...)
plt.scatter!(gprw.subsample[:,1], gprw.subsample[:,2]; kwargs_sub...)
if plot_95
plt.plot!(mesh,
mean,
ribbon = (1.96 * std, 1.96 * std);
kwargs_95...)
plt.plot!(mesh, mean - 1.96 * std; kwargs_aux...)
plt.plot!(mesh, mean + 1.96 * std; kwargs_aux...)
else
plt.plot!(mesh, mean; kwargs_mean...)
end
end
end

################################################################################
# convenience functions ########################################################
################################################################################
Expand Down
Binary file added test/GPR/data/matern_05_mesh.npy
Binary file not shown.
Binary file modified test/GPR/data/matern_def_mean.npy
Binary file not shown.
Binary file modified test/GPR/data/matern_def_std.npy
Binary file not shown.
Binary file modified test/GPR/data/mesh.npy
Binary file not shown.
Binary file modified test/GPR/data/rbf_mean.npy
Binary file not shown.
Binary file modified test/GPR/data/rbf_std.npy
Binary file not shown.
14 changes: 11 additions & 3 deletions test/GPR/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const rbf_mean = NPZ.npzread(joinpath(data_dir, "rbf_mean.npy"))
const rbf_std = NPZ.npzread(joinpath(data_dir, "rbf_std.npy"))
const matern_def_mean = NPZ.npzread(joinpath(data_dir, "matern_def_mean.npy"))
const matern_def_std = NPZ.npzread(joinpath(data_dir, "matern_def_std.npy"))
const matern_05_mesh = NPZ.npzread(joinpath(data_dir, "matern_05_mesh.npy"))
const matern_05_mean = NPZ.npzread(joinpath(data_dir, "matern_05_mean.npy"))
const matern_05_std = NPZ.npzread(joinpath(data_dir, "matern_05_std.npy"))

Expand Down Expand Up @@ -88,9 +89,15 @@ thrsh = gprw.thrsh

ytrue = xmesh.^2
ypred = GPR.predict(gprw, xmesh)
@test isapprox(ytrue, ypred, atol=1e-5, norm=inf_norm)
@test isapprox(ytrue, ypred, atol=1e-4, norm=inf_norm)

mesh_, mean_, std_ = GPR.mmstd(gprw, mesh_n=101)
@test isapprox(mesh_, xmesh, atol=1e-8, norm=inf_norm)
@test isapprox(ypred, mean_, atol=1e-8, norm=inf_norm)

mean, std = GPR.predict(gprw, xmesh, return_std = true)
@test isapprox(mean, mean_, atol=1e-8, norm=inf_norm)
@test isapprox(std, std_, atol=1e-8, norm=inf_norm)
@test ndims(mean) == 1
@test ndims(std) == 1
@test size(std,1) == size(mean,1)
Expand All @@ -107,9 +114,10 @@ GPR.set_data!(gprw, gpr_data)
gprw.thrsh = -1
@testset "non-synthetic testing" begin
GPR.learn!(gprw)
mean, std = GPR.predict(gprw, gpr_mesh, return_std = true)
mesh, mean, std = GPR.mmstd(gprw)
@test size(gprw.data,1) == 800
@test size(gprw.subsample,1) == 800
@test isapprox(mesh, gpr_mesh, atol=1e-8, norm=inf_norm)
@test isapprox(mean, rbf_mean, atol=1e-3, norm=inf_norm)
@test isapprox(std, rbf_std, atol=1e-3, norm=inf_norm)

Expand All @@ -119,7 +127,7 @@ gprw.thrsh = -1
@test isapprox(std, matern_def_std, atol=1e-3, norm=inf_norm)

GPR.learn!(gprw, kernel = "matern", nu = 0.5)
mean, std = GPR.predict(gprw, gpr_mesh, return_std = true)
mean, std = GPR.predict(gprw, matern_05_mesh, return_std = true)
@test isapprox(mean, matern_05_mean, atol=1e-3, norm=inf_norm)
@test isapprox(std, matern_05_std, atol=1e-3, norm=inf_norm)
end
Expand Down