Skip to content

Commit

Permalink
fixed gradient
Browse files Browse the repository at this point in the history
  • Loading branch information
jessemilzman committed Jan 17, 2024
1 parent 8eb94c3 commit c41d92d
Showing 1 changed file with 23 additions and 13 deletions.
36 changes: 23 additions & 13 deletions experiments/tower_defense_exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,7 @@ function solve_r(pws, ws; r_init = [1/3, 1/3, 1/3], iter_limit=50, target_error=
while cur_iter < iter_limit # TODO: Break if change from last iteration is small
dJdr = compute_dJdr(r, x, pws, ws, game)
r_temp = r - α .* dJdr
# println("α .* dJdr = $(r-r_temp)")
# println("r - α .* dJdr = $r_temp")
# r_temp = max.(0, min.(1, r_temp)) # project onto [0,1]
r_temp = max.(0,r_temp) # project onto [0,1]
# println("max(...) = $r_temp")
r_temp = r_temp / sum(r_temp) # project onto (n-1) simplex
# println("normalized = $r_temp")
r = r_temp
r = project_onto_simplex(r_temp)
x = compute_stage_2(
r, pws, ws, game;
initial_guess=vcat(x, zeros(total_dim(game) - n_players * var_dim))
Expand All @@ -66,7 +59,7 @@ function solve_r(pws, ws; r_init = [1/3, 1/3, 1/3], iter_limit=50, target_error=
cur_iter += 1
println("$cur_iter: r = $r")
# println("x = $x \n")
print_state(x)
# print_state(x)
end
println("$cur_iter: r = $r")
if return_states
Expand All @@ -78,6 +71,17 @@ function solve_r(pws, ws; r_init = [1/3, 1/3, 1/3], iter_limit=50, target_error=
return r
end

"""
    project_onto_simplex(v; z=1.0)

Project the vector `v` onto the scaled simplex `{w : w .>= 0, sum(w) == z}`
using the sort-based procedure of Fig. 1 in Duchi et al. (2008).

# Arguments
- `v`: vector of real numbers to project.

# Keywords
- `z`: target sum of the projected vector; must be strictly positive.

# Returns
A new vector with entries `max(v[i] - θ, 0)`, where `θ` is the shift computed
from the sorted entries of `v`. An empty `v` yields an empty result.

# Throws
- `ArgumentError` if `z <= 0` (the number of active coordinates would be zero
  and `θ` would be computed by dividing by zero).
"""
function project_onto_simplex(v; z=1.0)
    z > 0 || throw(ArgumentError("z must be positive, got $z"))
    isempty(v) && return float.(v)

    μ = sort(v, rev=true)
    # Running sums of the sorted entries, computed once instead of re-summing
    # the prefix μ[1:j] for every candidate j.
    prefix = cumsum(μ)

    # First index at which the Duchi condition fails; ρ is the index before it
    # (or every index if the condition never fails). With z > 0 the condition
    # holds at j = 1 for finite input, so ρ >= 1.
    first_fail = findfirst(j -> μ[j] - (prefix[j] - z) / j <= 0, eachindex(μ))
    ρ = isnothing(first_fail) ? length(μ) : first_fail - 1

    θ = (prefix[ρ] - z) / ρ
    return [max(x - θ, 0) for x in v]
end


function print_state(x)
out = reshape(x,3,10)
Expand All @@ -92,7 +96,7 @@ end
"""
    J_2(u, v, w)

Attacker cost function.

Return `sum(w[ii]^(v[ii] - u[ii]))` over `ii = 1:length(w)`.
"""
function J_2(u, v, w)
    m = length(w)
    # Evaluate the sum exactly once, via an array comprehension.
    # NOTE(review): the comprehension form (rather than a generator) is kept on
    # purpose -- presumably required by the `gradient` call used elsewhere in
    # this file; confirm before switching back to a generator.
    return sum([w[ii]^(v[ii] - u[ii]) for ii in 1:m])
    # (u[w] - v[w]) # P2 only cares about a SINGLE direction.
end

Expand All @@ -119,7 +123,7 @@ function build_stage_2(pws, ws)
(x, θ) -> sum([J_1(x[Block(1)], x[Block(w_idx + n + 1)]) * p_w_k_0(w_idx, θ) for w_idx in 1:n]), # u|s¹=0 IPI
[(x, θ) -> J_2(x[Block(1)], x[Block(w_idx + n + 1)], ws[w_idx]) for w_idx in 1:n]..., # v|s¹=0 IPI
[(x, θ) -> J_1(x[Block(w_idx + 1)], x[Block(w_idx + 2 * n + 1)]) for w_idx in 1:n]..., # u|s¹={1,2,3} PI
[(x, θ) -> J_2(x[Block(w_idx + 1)], x[Block(w_idx + 2 * n + 1)], ws[w_idx]) for w_idx in 1:n]..., # v|s¹={1,2,3} PI
[(x, θ) -> J_2(x[Block(w_idx + 1)], x[Block(w_idx + 2 * n + 1)], ws[w_idx]) for w_idx in 1:n]... # v|s¹={1,2,3} PI
]

# equality constraints
Expand Down Expand Up @@ -152,8 +156,7 @@ Compute objective at Stage 1
"""
function compute_J(r, x, pws, ws)
    n = length(pws)

    # Sum of (1 - r[w]) * pws[w] * J_1(...) over w = 1:n, using block 1 and
    # block (w + n + 1) of x.
    j1_term = sum([
        (1 - r[w_idx]) * pws[w_idx] * J_1(x[Block(1)], x[Block(w_idx + n + 1)])
        for w_idx in 1:n
    ])

    # Sum of r[w] * pws[w] * J_2(...) over w = 1:n, using block (w + 1) and
    # block (w + 2n + 1) of x together with ws[w].
    j2_term = sum([
        r[w_idx] * pws[w_idx] *
        J_2(x[Block(w_idx + 1)], x[Block(w_idx + 2 * n + 1)], ws[w_idx])
        for w_idx in 1:n
    ])

    # Both terms are negated and combined in a single returned expression, so
    # neither can end up as a separate, discarded statement.
    return -j1_term - j2_term
end

"""
Expand All @@ -177,6 +180,13 @@ function compute_dJdr(r, x, pws, ws, game)
dJdx = compute_dJdx(r, x, pws, ws)
dJdr = gradient(r -> compute_J(r, x, pws, ws), r)[1]
dxdr = compute_dxdr(r, x, pws, ws, game)

dJdr_norm = norm_sqr(dJdx)
dxdr_norm = norm_sqr(dxdr)

println("dJdr = $dJdr_norm")
println("dxdr = $dxdr_norm")

n = length(pws)
for idx in 1:(1 + n^2)
dJdr += (dJdx[Block(idx)]' * dxdr[Block(idx)])'
Expand Down

0 comments on commit c41d92d

Please sign in to comment.