Skip to content

Commit

Permalink
added parameter to choose whose cost function is being graphed on the…
Browse files Browse the repository at this point in the history
… z-axis with the controls, via cost_player parameter.
  • Loading branch information
jessemilzman committed Feb 25, 2024
1 parent f2ac036 commit b029c64
Showing 1 changed file with 40 additions and 32 deletions.
72 changes: 40 additions & 32 deletions experiments/tower_defense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,38 +101,41 @@ end
"""
Temp. script to calculate and plot surfaces for the terms in Stage 1's cost function
"""
function run_stage_1_breakout(;display_controls = 0, dr = 0.05)
function run_stage_1_breakout(;display_controls = 0, dr = 0.05, cost_player = 1)
# dr = 0.05
ps = [1/3, 1/3, 1/3]
βs = [
[3.0, 2.0, 2.0],
[2.0, 3., 2.0],
[2.0, 2.0, 3.0]
]
# initial_guess = vcat(repeat([0.9,0.05,0.05],4),repeat([0.1,0.5,0.4],6),(1/3)*ones(42))
# primal_guess = vcat(repeat([0.5,0.25,0.25],4), repeat([0.9,0.05,0.05,0.05,0.9,0.05,0.05,0.05,0.9],2))
# primal_guess = vcat([0.34,0.33,0.33,0.5,0.25,0.25,0.25,0.5,0.25,0.25,0.25,0.5], repeat([0.9,0.05,0.05,0.05,0.9,0.05,0.05,0.05,0.9],2))
primal_guess = (1/3)*ones(30)
initial_guess = vcat(primal_guess,(1/3)*ones(42))
#### Choose the initial guess for Stage 2 initialization
primal_guess = (1/3)*ones(30) ## Initialization frorm primes
initial_guess = vcat(primal_guess,(1/3)*ones(42)) ## concatenate, assume duals are 1/3



if (display_controls in [1,2])
world_1_misid_costs, world_1_misid_controls = calculate_misid_costs(ps, βs, 1; dr, return_controls=display_controls, initial_guess=initial_guess)
world_2_misid_costs, world_2_misid_controls = calculate_misid_costs(ps, βs, 2; dr, return_controls=display_controls, initial_guess=initial_guess)
world_3_misid_costs, world_3_misid_controls = calculate_misid_costs(ps, βs, 3; dr, return_controls=display_controls, initial_guess=initial_guess)
world_1_id_costs, world_1_id_controls = calculate_id_costs(ps, βs, 1; dr, return_controls=display_controls, initial_guess=initial_guess)
world_2_id_costs, world_2_id_controls = calculate_id_costs(ps, βs, 2; dr, return_controls=display_controls, initial_guess=initial_guess)
world_3_id_costs, world_3_id_controls = calculate_id_costs(ps, βs, 3; dr, return_controls=display_controls, initial_guess=initial_guess)
world_1_misid_costs, world_1_misid_controls = calculate_misid_costs(ps, βs, 1; dr, return_controls=display_controls, initial_guess=initial_guess, cost_player=cost_player)
world_2_misid_costs, world_2_misid_controls = calculate_misid_costs(ps, βs, 2; dr, return_controls=display_controls, initial_guess=initial_guess, cost_player=cost_player)
world_3_misid_costs, world_3_misid_controls = calculate_misid_costs(ps, βs, 3; dr, return_controls=display_controls, initial_guess=initial_guess, cost_player=cost_player)
world_1_id_costs, world_1_id_controls = calculate_id_costs(ps, βs, 1; dr, return_controls=display_controls, initial_guess=initial_guess, cost_player=cost_player)
world_2_id_costs, world_2_id_controls = calculate_id_costs(ps, βs, 2; dr, return_controls=display_controls, initial_guess=initial_guess, cost_player=cost_player)
world_3_id_costs, world_3_id_controls = calculate_id_costs(ps, βs, 3; dr, return_controls=display_controls, initial_guess=initial_guess, cost_player=cost_player)
else
world_1_misid_costs = calculate_misid_costs(ps, βs, 1; dr, initial_guess=initial_guess)
world_2_misid_costs = calculate_misid_costs(ps, βs, 2; dr, initial_guess=initial_guess)
world_3_misid_costs = calculate_misid_costs(ps, βs, 3; dr, initial_guess=initial_guess)
world_1_id_costs = calculate_id_costs(ps, βs, 1; dr, initial_guess=initial_guess)
world_2_id_costs = calculate_id_costs(ps, βs, 2; dr, initial_guess=initial_guess)
world_3_id_costs = calculate_id_costs(ps, βs, 3; dr, initial_guess=initial_guess)
world_1_misid_costs = calculate_misid_costs(ps, βs, 1; dr, initial_guess=initial_guess, cost_player=cost_player)
world_2_misid_costs = calculate_misid_costs(ps, βs, 2; dr, initial_guess=initial_guess, cost_player=cost_player)
world_3_misid_costs = calculate_misid_costs(ps, βs, 3; dr, initial_guess=initial_guess, cost_player=cost_player)
world_1_id_costs = calculate_id_costs(ps, βs, 1; dr, initial_guess=initial_guess, cost_player=cost_player)
world_2_id_costs = calculate_id_costs(ps, βs, 2; dr, initial_guess=initial_guess, cost_player=cost_player)
world_3_id_costs = calculate_id_costs(ps, βs, 3; dr, initial_guess=initial_guess, cost_player=cost_player)
end
# Normalize using maximum value across all worlds

maxormin = cost_player == 2 ? minimum : maximum

max_value =
maximum(
maxormin(
filter(
!isnan,
vcat(
Expand All @@ -145,6 +148,7 @@ function run_stage_1_breakout(;display_controls = 0, dr = 0.05)
),
),
)
max_value = (-1)^(cost_player+1)*max_value
world_1_misid_costs = [isnan(c) ? NaN : c / max_value for c in world_1_misid_costs]
world_2_misid_costs = [isnan(c) ? NaN : c / max_value for c in world_2_misid_costs]
world_3_misid_costs = [isnan(c) ? NaN : c / max_value for c in world_3_misid_costs]
Expand All @@ -170,7 +174,9 @@ function run_stage_1_breakout(;display_controls = 0, dr = 0.05)
world_2_misid_controls,
world_3_misid_controls,
],
ps, save_file="P"*string(display_controls)*"_"
ps,
save_file="P"*string(display_controls)*"_",
cost_player=cost_player
)
else
display_stage_1_costs(
Expand Down Expand Up @@ -231,7 +237,7 @@ function calculate_residuals(ps, βs, world_idx; dr = 0.05)
return residuals
end

function calculate_id_costs(ps, βs, world_idx; dr = 0.05, return_controls=0, initial_guess=nothing)
function calculate_id_costs(ps, βs, world_idx; dr = 0.05, return_controls=0, initial_guess=nothing, cost_player = 1)
@assert sum(ps) 1.0 "Prior distribution ps must be a probability distribution"
game, _ = build_stage_2(ps, βs)
rs = 0:dr:1
Expand All @@ -245,6 +251,7 @@ function calculate_id_costs(ps, βs, world_idx; dr = 0.05, return_controls=0, in
return_controls = 0
end
end
J = cost_player == 2 ? J_2 : J_1

for (i, r1) in enumerate(rs)
for (j, r2) in enumerate(rs)
Expand All @@ -257,7 +264,7 @@ function calculate_id_costs(ps, βs, world_idx; dr = 0.05, return_controls=0, in
id_cost =
r[world_idx] *
ps[world_idx] *
J_1(
J(
x[Block(world_idx + 1)],
x[Block(world_idx + 2 * num_worlds + 1)],
βs[world_idx],
Expand All @@ -277,7 +284,7 @@ function calculate_id_costs(ps, βs, world_idx; dr = 0.05, return_controls=0, in
end
end

function calculate_misid_costs(ps, βs, world_idx; dr = 0.05, return_controls = 0, initial_guess=nothing)
function calculate_misid_costs(ps, βs, world_idx; dr = 0.05, return_controls = 0, initial_guess=nothing, cost_player = 1)
@assert sum(ps) 1.0 "Prior distribution ps must be a probability distribution"
game, _ = build_stage_2(ps, βs)
rs = 0:dr:1
Expand All @@ -291,6 +298,7 @@ function calculate_misid_costs(ps, βs, world_idx; dr = 0.05, return_controls =
return_controls = 0
end
end
J = cost_player == 2 ? J_2 : J_1

for (i, r1) in enumerate(rs)
for (j, r2) in enumerate(rs)
Expand All @@ -302,7 +310,7 @@ function calculate_misid_costs(ps, βs, world_idx; dr = 0.05, return_controls =
x = compute_stage_2(r, ps, βs, game, initial_guess=initial_guess)
defender_signal_0 = x[Block(1)]
attacker_signal_0_world_idx = x[Block(world_idx + num_worlds + 1)]
misid_cost = J_1(defender_signal_0, attacker_signal_0_world_idx, βs[world_idx])
misid_cost = J(defender_signal_0, attacker_signal_0_world_idx, βs[world_idx])
misid_cost = (1 - r[world_idx]) * ps[world_idx] * misid_cost # weight by p(w_k|s¹=0)
misid_costs[i, j] = misid_cost
if (return_controls > 0)
Expand Down Expand Up @@ -454,11 +462,11 @@ Input:
Output:
fig: Figure with simplex heatmap
"""
function display_stage_1_costs_controls(costs, controls, ps; save_file = "")
function display_stage_1_costs_controls(costs, controls, ps; save_file = "", cost_player=1)
rs = 0:(1 / (size(costs[1])[1] - 1)):1
num_worlds = length(ps)
fig = Figure(size = (1500, 1000), title = "test")
max_value = 1.0
ylims = cost_player == 2 ? (-1.0,0.0) : (0.01, 1.0) ## either graph from y=0,1 (for normalized cost for P1), or else y=-1,0 (for P2)
axs = [
[
Axis3(
Expand All @@ -474,7 +482,7 @@ function display_stage_1_costs_controls(costs, controls, ps; save_file = "")
ylabel = "r₂",
zlabel = "Cost",
title = "W$world_idx, S$world_idx",
limits = (nothing, nothing, (0.01, max_value)),
limits = (nothing, nothing, ylims),
) for world_idx in 1:num_worlds
],
[
Expand All @@ -491,7 +499,7 @@ function display_stage_1_costs_controls(costs, controls, ps; save_file = "")
ylabel = "r₂",
zlabel = "Cost",
title = "W$world_idx, S0",
limits = (nothing, nothing, (0.01, max_value)),
limits = (nothing, nothing, ylims),
) for world_idx in 1:num_worlds
],
]
Expand All @@ -505,8 +513,8 @@ function display_stage_1_costs_controls(costs, controls, ps; save_file = "")
rs[jj],
costs[world_idx][ii,jj],
color = colors[ii,jj],
colormap = :viridis,
colorrange = (0, max_value),
# colormap = :viridis,
# colorrange = (0, max_value),
)
end
end
Expand All @@ -526,8 +534,8 @@ function display_stage_1_costs_controls(costs, controls, ps; save_file = "")
rs[jj],
costs[world_idx+num_worlds][ii,jj],
color = colors[ii,jj],
colormap = :viridis,
colorrange = (0, max_value),
# colormap = :viridis,
# colorrange = (0, max_value),
)
end
end
Expand Down

0 comments on commit b029c64

Please sign in to comment.