Skip to content

Commit

Permalink
merge fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jessemilzman committed Feb 20, 2024
2 parents eaa6a26 + be426dc commit d67cdcd
Showing 1 changed file with 88 additions and 4 deletions.
92 changes: 88 additions & 4 deletions experiments/tower_defense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,49 @@ function run_stage_1_breakout(;display_controls = 0, dr = 0.05)

end

function run_residuals()
dr = 0.01
ps = [1/3, 1/3, 1/3]
βs = [
[4.0, 2.0, 2.0],
[2.0, 4.0, 2.0],
[2.0, 2.0, 4.0]
]
world_1_residuals = calculate_residuals(ps, βs, 1; dr)
world_2_residuals = calculate_residuals(ps, βs, 2; dr)
world_3_residuals = calculate_residuals(ps, βs, 3; dr)

display_residuals(
[
world_1_residuals,
world_2_residuals,
world_3_residuals,
],
ps,
)
end

function calculate_residuals(ps, βs, world_idx; dr = 0.05)
@assert sum(ps) 1.0 "Prior distribution ps must be a probability distribution"
game, _ = build_stage_2(ps, βs)
rs = 0:dr:1
num_worlds = length(ps)
residuals = NaN * ones(Float64, Int(1 / dr + 1), Int(1 / dr + 1))
for (i, r1) in enumerate(rs)
for (j, r2) in enumerate(rs)
if r1 + r2 > 1
continue
end
r3 = 1 - r1 - r2
r = [r1, r2, r3]
_, residual = compute_stage_2(r, ps, βs, game; return_residual = true)
residuals[i, j] = residual
end
end

return residuals
end

function calculate_id_costs(ps, βs, world_idx; dr = 0.05, return_controls=0)
@assert sum(ps) 1.0 "Prior distribution ps must be a probability distribution"
game, _ = build_stage_2(ps, βs)
Expand Down Expand Up @@ -283,7 +326,7 @@ Output:
function display_stage_1_costs(costs, ps)
rs = 0:(1 / (size(costs[1])[1] - 1)):1
num_worlds = length(ps)
fig = Figure(size = (1500, 1000), title = "test")
fig = Figure(size = (1500, 800), title = "test")
axs = [
[
Axis3(
Expand Down Expand Up @@ -361,6 +404,42 @@ function display_stage_1_costs(costs, ps)
fig
end


function display_residuals(costs, ps)
rs = 0:(1 / (size(costs[1])[1] - 1)):1
num_worlds = length(ps)
fig = Figure(size = (1500, 500), title = "test")
axs = [
Axis3(
fig[1, world_idx],
aspect = (1, 1, 1),
perspectiveness = 0.5,
elevation = pi / 5,
azimuth = -π * (1 / 2 + 1 / 4),
zgridcolor = :grey,
ygridcolor = :grey,
xgridcolor = :grey;
xlabel = "r₁",
ylabel = "r₂",
zlabel = "Residual",
title = "World $world_idx",
# limits = (nothing, nothing, (0.01, 1)),
) for world_idx in 1:num_worlds
]
for world_idx in 1:num_worlds
hmap = surface!(
axs[world_idx],
rs,
rs,
costs[world_idx],
colormap = :viridis,
# colorrange = (0, 1),
)

end
fig
end

"""
Display surface of Stage 1's objective function, colored according to a player's action. Assumes number of worlds is 3.
Expand Down Expand Up @@ -720,18 +799,23 @@ Input:
Output:
x: decision variables of Stage 2 given r. BlockedArray with a block per player
"""
function compute_stage_2(r, ps, βs, game; initial_guess = nothing, verbose = false)
function compute_stage_2(r, ps, βs, game; initial_guess = nothing, verbose = false, return_residual = false)
n = length(ps) # assume n_signals = n_worlds + 1
n_players = 1 + n^2
var_dim = n # TODO: Change this to be more general

solution = solve(
game,
r;
initial_guess = isnothing(initial_guess) ? zeros(total_dim(game)) : initial_guess,
initial_guess = isnothing(initial_guess) ? 1/3 * ones(total_dim(game)) : initial_guess,
verbose = verbose,
return_primals = false,
)

BlockArray(solution.variables[1:(n_players * var_dim)], [n for _ in 1:n_players])
if return_residual
return BlockArray(solution.variables[1:(n_players * var_dim)], [n for _ in 1:n_players]),
solution.info.residual
else
return BlockArray(solution.variables[1:(n_players * var_dim)], [n for _ in 1:n_players])
end
end

0 comments on commit d67cdcd

Please sign in to comment.