Skip to content

Commit

Permalink
Merge pull request #8 from CLeARoboticsLab/fernando/zero-sum-subgame
Browse files Browse the repository at this point in the history
Fernando/zero sum subgame
  • Loading branch information
fernandopalafox authored Feb 20, 2024
2 parents 71ad7fd + 65b5073 commit be426dc
Showing 1 changed file with 213 additions and 33 deletions.
246 changes: 213 additions & 33 deletions experiments/tower_defense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,32 +87,137 @@ end
Temp. script to calculate and plot heatmap of Stage 1 cost function
"""
function run_visualization()
    # Forwarded as the `dr` keyword of `calculate_stage_1_costs`
    # (presumably the sampling step of the r-grid — confirm against that function).
    dr = 0.01
    # Uniform prior over the three worlds.
    ps = [1 / 3, 1 / 3, 1 / 3]
    # One β vector per world; in world k the k-th entry is the larger one.
    βs = [[2, 1, 1], [1, 2, 1], [1, 1, 2]]
    # Evaluate the Stage 1 cost exactly once, at the requested `dr`. The previous
    # revision also ran a second call without `dr` and discarded its result.
    Ks = calculate_stage_1_costs(ps, βs; dr)
    return display_surface(ps, Ks)
end

"""
    run_stage_1_breakout()

Temp. script to calculate and plot surfaces for the terms in Stage 1's cost function.

For every world it evaluates `calculate_id_costs` and `calculate_misid_costs` with the
hard-coded `dr`, `ps` and `βs` below, divides all six matrices by their common largest
non-NaN entry, and passes them to `display_stage_1_costs` (id costs first, then misid
costs). Returns whatever `display_stage_1_costs` returns.
"""
function run_stage_1_breakout()
    dr = 0.05
    ps = [1 / 3, 1 / 3, 1 / 3]
    βs = [
        [2.1, 2.0, 2.0],
        [2.0, 2.1, 2.0],
        [2.0, 2.0, 2.1],
    ]
    world_idxs = eachindex(ps)
    misid_costs = [calculate_misid_costs(ps, βs, k; dr) for k in world_idxs]
    id_costs = [calculate_id_costs(ps, βs, k; dr) for k in world_idxs]

    # Plot order expected by `display_stage_1_costs`: id costs, then misid costs.
    all_costs = vcat(id_costs, misid_costs)

    # Normalize using maximum value across all worlds (NaN entries are ignored when
    # searching for the maximum and stay NaN after the division).
    max_value = maximum(c for costs in all_costs for c in costs if !isnan(c))
    normalized_costs = [costs ./ max_value for costs in all_costs]

    return display_stage_1_costs(normalized_costs, ps)
end

"""
    run_residuals()

Temp. script: evaluate `calculate_residuals` for worlds 1, 2 and 3 with the hard-coded
`dr`, `ps` and `βs` below and hand the three resulting matrices to `display_residuals`.
Returns whatever `display_residuals` returns.
"""
function run_residuals()
    dr = 0.01
    ps = [1 / 3, 1 / 3, 1 / 3]
    βs = [
        [4.0, 2.0, 2.0],
        [2.0, 4.0, 2.0],
        [2.0, 2.0, 4.0],
    ]
    # One residual matrix per world index, in order 1, 2, 3.
    residuals_per_world = map(k -> calculate_residuals(ps, βs, k; dr), 1:3)
    return display_residuals(residuals_per_world, ps)
end

"""
    calculate_residuals(ps, βs, world_idx; dr = 0.05)

Sample the Stage 2 solver residual over a grid of `r = [r1, r2, 1 - r1 - r2]`.

`r1` and `r2` each range over `0:dr:1`. For every pair with `r1 + r2 <= 1` the Stage 2
game built from `ps` and `βs` is solved via `compute_stage_2(...; return_residual = true)`
and the returned residual is stored at `[i, j]`. Pairs with `r1 + r2 > 1` are left as
`NaN`.

# Arguments
- `ps`: prior over worlds; must sum to (approximately) 1.
- `βs`: per-world parameters forwarded to `build_stage_2` / `compute_stage_2`.
- `world_idx`: not read by this function; kept so the signature matches
  `calculate_id_costs` and `calculate_misid_costs`.

# Keywords
- `dr`: grid step for `r1` and `r2`.

# Returns
A `length(0:dr:1) × length(0:dr:1)` `Matrix{Float64}` of residuals.
"""
function calculate_residuals(ps, βs, world_idx; dr = 0.05)
    @assert sum(ps) ≈ 1.0 "Prior distribution ps must be a probability distribution"
    game, _ = build_stage_2(ps, βs)
    rs = 0:dr:1
    # Size the grid from the range itself: `Int(1 / dr + 1)` throws for any `dr` that
    # does not divide 1 exactly.
    residuals = fill(NaN, length(rs), length(rs))
    for (i, r1) in enumerate(rs), (j, r2) in enumerate(rs)
        # (r1, r2) must leave a non-negative r3; other cells stay NaN.
        r1 + r2 > 1 && continue
        r = [r1, r2, 1 - r1 - r2]
        _, residual = compute_stage_2(r, ps, βs, game; return_residual = true)
        residuals[i, j] = residual
    end

    return residuals
end

"""
    calculate_id_costs(ps, βs, world_idx; dr = 0.05)

Sample the weighted cost term `r[world_idx] * ps[world_idx] * J_1(...)` over a grid of
`r = [r1, r2, 1 - r1 - r2]`.

`r1` and `r2` each range over `0:dr:1`. For every pair with `r1 + r2 <= 1` the Stage 2
game built from `ps` and `βs` is solved with `compute_stage_2`, and `J_1` is evaluated on
blocks `world_idx + 1` and `world_idx + 2 * length(ps) + 1` of the solution together with
`βs[world_idx]`. Pairs with `r1 + r2 > 1` are left as `NaN`.

# Arguments
- `ps`: prior over worlds; must sum to (approximately) 1.
- `βs`: per-world parameters; `βs[world_idx]` is passed to `J_1`.
- `world_idx`: index of the world whose term is evaluated.

# Keywords
- `dr`: grid step for `r1` and `r2`.

# Returns
A `length(0:dr:1) × length(0:dr:1)` `Matrix{Float64}` of weighted costs.
"""
function calculate_id_costs(ps, βs, world_idx; dr = 0.05)
    @assert sum(ps) ≈ 1.0 "Prior distribution ps must be a probability distribution"
    game, _ = build_stage_2(ps, βs)
    rs = 0:dr:1
    num_worlds = length(ps)
    # Size the grid from the range itself: `Int(1 / dr + 1)` throws for any `dr` that
    # does not divide 1 exactly.
    id_costs = fill(NaN, length(rs), length(rs))
    for (i, r1) in enumerate(rs), (j, r2) in enumerate(rs)
        # (r1, r2) must leave a non-negative r3; other cells stay NaN.
        r1 + r2 > 1 && continue
        r = [r1, r2, 1 - r1 - r2]
        x = compute_stage_2(r, ps, βs, game)
        # NOTE(review): presumably the defender's and the attacker's decision blocks
        # for the signal that matches `world_idx` — confirm against the block layout
        # produced by `compute_stage_2`.
        defender_block = x[Block(world_idx + 1)]
        attacker_block = x[Block(world_idx + 2 * num_worlds + 1)]
        id_costs[i, j] =
            r[world_idx] * ps[world_idx] * J_1(defender_block, attacker_block, βs[world_idx])
    end

    return id_costs
end

function calculate_misid_costs(ps, βs, world_idx; dr = 0.05)
Expand All @@ -132,6 +237,7 @@ function calculate_misid_costs(ps, βs, world_idx; dr = 0.05)
defender_signal_0 = x[Block(1)]
attacker_signal_0_world_idx = x[Block(world_idx + num_worlds + 1)]
misid_cost = J_1(defender_signal_0, attacker_signal_0_world_idx, βs[world_idx])
misid_cost = (1 - r[world_idx]) * ps[world_idx] * misid_cost # weight by p(w_k|s¹=0)
misid_costs[i, j] = misid_cost
end
end
Expand All @@ -147,9 +253,91 @@ Input:
Output:
fig: Figure with simplex heatmap
"""
function display_stage_1_costs(costs, ps)
    # `costs` holds 2 * length(ps) square matrices: entries 1:num_worlds are drawn on
    # the first row of axes, entries num_worlds+1:2num_worlds on the second row.
    # Returns the assembled figure.
    rs = 0:(1 / (size(costs[1], 1) - 1)):1
    num_worlds = length(ps)
    num_rows = 2
    fig = Figure(size = (1500, 800), title = "test")

    # axs[row][world_idx]: one 3D axis per (row, world), all styled identically.
    axs = [
        [
            Axis3(
                fig[row, world_idx];
                aspect = (1, 1, 1),
                perspectiveness = 0.5,
                elevation = pi / 5,
                azimuth = -π * (1 / 2 + 1 / 4),
                zgridcolor = :grey,
                ygridcolor = :grey,
                xgridcolor = :grey,
                xlabel = "r₁",
                ylabel = "r₂",
                zlabel = "Cost",
                title = "World $world_idx",
                limits = (nothing, nothing, (0.01, 1)),
            ) for world_idx in 1:num_worlds
        ] for row in 1:num_rows
    ]

    # Draw row 1 first, then row 2, keeping the plot objects so the colorbar can be
    # attached to the last surface drawn.
    surfaces = [
        surface!(
            axs[row][world_idx],
            rs,
            rs,
            costs[world_idx + (row - 1) * num_worlds];
            colormap = :viridis,
            colorrange = (0, 1),
        ) for row in 1:num_rows for world_idx in 1:num_worlds
    ]

    # Every surface shares colorrange (0, 1), so one colorbar spanning both rows
    # describes all of them. Nothing to attach when there are no worlds.
    if !isempty(surfaces)
        Colorbar(
            fig[1:num_rows, num_worlds + 1],
            last(surfaces);
            label = "Cost",
            width = 15,
            ticksize = 15,
            tickalign = 1,
        )
    end
    return fig
end


function display_residuals(costs, ps)
rs = 0:(1 / (size(costs[1])[1] - 1)):1
num_worlds = length(costs)
num_worlds = length(ps)
fig = Figure(size = (1500, 500), title = "test")
axs = [
Axis3(
Expand All @@ -163,9 +351,9 @@ function display_misid_costs(costs, ps)
xgridcolor = :grey;
xlabel = "r₁",
ylabel = "r₂",
zlabel = "Misid. cost",
zlabel = "Residual",
title = "World $world_idx",
limits = (nothing, nothing, (0.01, 1)),
# limits = (nothing, nothing, (0.01, 1)),
) for world_idx in 1:num_worlds
]
for world_idx in 1:num_worlds
Expand All @@ -175,22 +363,9 @@ function display_misid_costs(costs, ps)
rs,
costs[world_idx],
colormap = :viridis,
colorrange = (0, 1),
# colorrange = (0, 1),
)
# text!(axs[world_idx], "$(round(ps[1], digits=2))", position = (0.9, 0.4, cost_min), font = "Bold")
# text!(axs[world_idx], "$(round(ps[2], digits=2))", position = (0.1, 0.95, cost_min), font = "Bold")
# text!(axs[world_idx], "$(round(ps[3], digits=2))", position = (0.2, 0.1, cost_min), font = "Bold")

if world_idx == num_worlds
Colorbar(
fig[1, world_idx + 1],
hmap;
label = "Misid. cost",
width = 15,
ticksize = 15,
tickalign = 1,
)
end
end
fig
end
Expand Down Expand Up @@ -248,23 +423,23 @@ function display_surface(ps, Ks)
fig[1, 1],
aspect = (1, 1, 1),
perspectiveness = 0.5,
elevation = pi / 20,
elevation = pi / 4,
azimuth = -π * (1 / 2 + 1 / 4),
zgridcolor = :grey,
ygridcolor = :grey,
xgridcolor = :grey;
xlabel = "r₁",
ylabel = "r₂",
zlabel = "K",
title = "Stage 1 cost as a function of r \n priors = $(round.(ps, digits=2))",
title = "Normalized stage 1 cost\n priors = $(round.(ps, digits=2))",
limits = (nothing, nothing, (0.01, 1)),
)
Ks_min = minimum(filter(!isnan, Ks))
hmap = surface!(ax, rs, rs, Ks, colorrange = (0, 1))
Colorbar(fig[1, 2], hmap; label = "K", width = 15, ticksize = 15, tickalign = 1)
text!(ax, "$(round(ps[1], digits=2))", position = (0.9, 0.4, 0.01), font = "Bold")
text!(ax, "$(round(ps[1], digits=2))", position = (0.9, 0.2, 0.01), font = "Bold")
text!(ax, "$(round(ps[2], digits=2))", position = (0.1, 0.95, 0.01), font = "Bold")
text!(ax, "$(round(ps[3], digits=2))", position = (0.2, 0.1, 0.01), font = "Bold")
text!(ax, "$(round(ps[3], digits=2))", position = (0.1, 0.2, 0.01), font = "Bold")
fig
end

Expand Down Expand Up @@ -428,18 +603,23 @@ Input:
Output:
x: decision variables of Stage 2 given r. BlockedArray with a block per player
"""
function compute_stage_2(
    r,
    ps,
    βs,
    game;
    initial_guess = nothing,
    verbose = false,
    return_residual = false,
)
    # Keywords:
    #   initial_guess   - starting point handed to `solve`; when `nothing`, a vector of
    #                     length `total_dim(game)` filled with 1/3 is used.
    #   verbose         - forwarded to `solve`.
    #   return_residual - when true, return `(x, solution.info.residual)` instead of
    #                     just `x`.
    n = length(ps) # assume n_signals = n_worlds + 1
    n_players = 1 + n^2
    var_dim = n # TODO: Change this to be more general

    guess = isnothing(initial_guess) ? fill(1 / 3, total_dim(game)) : initial_guess
    solution = solve(game, r; initial_guess = guess, verbose = verbose, return_primals = false)

    # Keep only the first n_players * var_dim solver variables and expose them as one
    # block of length var_dim per player.
    x = BlockArray(solution.variables[1:(n_players * var_dim)], fill(var_dim, n_players))
    return return_residual ? (x, solution.info.residual) : x
end

0 comments on commit be426dc

Please sign in to comment.