Skip to content

Commit

Permalink
Merge pull request #35 from CLeARoboticsLab/hmzh/utils-refactor
Browse files Browse the repository at this point in the history
Create type and structs to abstract dynamics and costs beyond LQ assumptions.
  • Loading branch information
hmzh-khn authored Nov 29, 2022
2 parents 980b227 + de30589 commit b2aab6f
Show file tree
Hide file tree
Showing 9 changed files with 196 additions and 122 deletions.
6 changes: 3 additions & 3 deletions example/CouplingExample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ function coupling_example()
0.1 0;
0 0;
0 0.1])
dyn = Dynamics(A, [B₁, B₂])
dyn = LinearDynamics(A, [B₁, B₂])

# Costs reflecting the preferences above.
Q₁ = zeros(8, 8)
Q₁[5, 5] = 1.0
Q₁[7, 7] = 1.0
c₁ = Cost(Q₁)
c₁ = QuadraticCost(Q₁)
add_control_cost!(c₁, 1, 1 * diagm([1, 1]))
add_control_cost!(c₁, 2, zeros(2, 2))

Expand All @@ -51,7 +51,7 @@ function coupling_example()
Q₂[7, 7] = 1.0
Q₂[3, 7] = -1.0
Q₂[7, 3] = -1.0
c₂ = Cost(Q₂)
c₂ = QuadraticCost(Q₂)
add_control_cost!(c₂, 2, 1 * diagm([1, 1]))
add_control_cost!(c₂, 1, zeros(2, 2))

Expand Down
59 changes: 59 additions & 0 deletions src/CostUtils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Utilities for managing quadratic and nonquadratic dynamics.
abstract type Cost end

# Every Cost is assumed to have the following functions defined on it:
# - quadraticize_costs(cost, t, x, us) - this function produces a QuadraticCost at time t given the state and controls
# - evaluate(cost, xs, us) - this function evaluates the cost of a trajectory given the states and controls

# We use this type by making a substruct of it which can then have certain functions defined for it.
abstract type NonQuadraticCost <: Cost end

# Cost for a single player.
# Form is: x^T_t Q^i x + \sum_j u^{jT}_t R^{ij} u^j_t.
# For simplicity, assuming that Q, R are time-invariant, and that dynamics are
# linear time-invariant, i.e. x_{t+1} = A x_t + \sum_i B^i u^i_t.
mutable struct QuadraticCost <: Cost
Q
Rs
end
QuadraticCost(Q) = QuadraticCost(Q, Dict{Int, Matrix{eltype(Q)}}())

# TODO(hamzah) Add better tests for the QuadraticCost struct and associated functions.

# Method to add R^{ij}s to a Cost struct.
export add_control_cost!
function add_control_cost!(c::QuadraticCost, other_player_idx, Rij)
c.Rs[other_player_idx] = Rij
end

function quadraticize_costs(cost::QuadraticCost, t, x, us)
return cost
end

# Evaluate cost on a state/control trajectory.
# - xs[:, time]
# - us[player][:, time]
function evaluate(c::QuadraticCost, xs, us)
horizon = last(size(xs))

total = 0.0
for tt in 1:horizon
total += xs[:, tt]' * c.Q * xs[:, tt]
total += sum(us[jj][:, tt]' * Rij * us[jj][:, tt] for (jj, Rij) in c.Rs)
end
return total
end


# TODO: Make the affine cost structure with homogenenized coordinates.
# struct AffineCost <: Cost end
# function quadraticize_costs(cost::AffineCost, t, x, us)
# function evaluate(c::AffineCost, xs, us)
# end

# Export all the cost types/structs.
export Cost, NonQuadraticCost, QuadraticCost

# Export all the cost types/structs.
export quadraticize_costs, evaluate

114 changes: 114 additions & 0 deletions src/DynamicsUtils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Utilities for managing linear and nonlinear dynamics.

# Every Dynamics is assumed to have the following functions defined on it:
# - linearize_dynamics(dyn, x, us) - this function linearizes the dynamics given the state and controls.
# - propagate_dynamics(cost, t, x, us) - this function propagates the dynamics to the next timestep.
# Every Dynamics struct must have a sys_info field of type SystemInfo.
abstract type Dynamics end

# A type that every nonlinear dynamics struct (unique per use case) can inherit from. These need to have the same
# functions as the Dynamics type.
abstract type NonlinearDynamics <: Dynamics end

# TODO(hamzah) Add better tests for the LinearDynamics struct and associated functions.
struct LinearDynamics <: Dynamics
A # state
Bs # controls
sys_info::SystemInfo
end
# Constructor for linear dynamics that auto-generates the system info.
LinearDynamics(A, Bs) = LinearDynamics(A, Bs, SystemInfo(length(Bs), last(size(A)), [last(size(Bs[i])) for i in 1:length(Bs)]))

function propagate_dynamics(dyn::LinearDynamics, t, x, us)
N = dyn.sys_info.num_agents
x_next = dyn.A * x
for i in 1:N
ui = reshape(us[i], dyn.sys_info.num_us[i], 1)
x_next += dyn.Bs[i] * ui
end
return x_next
end

function linearize_dynamics(dyn::LinearDynamics, x, us)
return dyn
end

# Export the types of dynamics.
export Dynamics, NonlinearDynamics, LinearDynamics

# Export the functionality each Dynamics requires.
export propagate_dynamics, linearize_dynamics


# Dimensionality helpers.
function xdim(dyn::Dynamics)
return dyn.sys_info.num_x
end

function udim(dyn::Dynamics)
return sum(dyn.sys_info.num_us)
end

function udim(dyn::Dynamics, player_idx)
return dyn.sys_info.num_us[player_idx]
end

export xdim, udim


# TODO(hamzah) Add better tests for the unroll_feedback, unroll_raw_controls functions.
# TODO(hamzah) Abstract the unroll_feedback, unroll_raw_controls functions to not assume linear feedback P.

# Function to unroll a set of feedback matrices from an initial condition.
# Output is a sequence of states xs[:, time] and controls us[player][:, time].
export unroll_feedback
function unroll_feedback(dyn::Dynamics, Ps, x₁)
@assert length(x₁) == xdim(dyn)

N = length(Ps)
@assert N == dyn.sys_info.num_agents

horizon = last(size(first(Ps)))

# Populate state/control trajectory.
xs = zeros(xdim(dyn), horizon)
xs[:, 1] = x₁
us = [zeros(udim(dyn, ii), horizon) for ii in 1:N]
for tt in 2:horizon
for ii in 1:N
us[ii][:, tt - 1] = -Ps[ii][:, :, tt - 1] * xs[:, tt - 1]
end

us_prev = [us[i][:, tt-1] for i in 1:N]
xs[:, tt] = propagate_dynamics(dyn, tt, xs[:, tt-1], us_prev)
end

# Controls at final time.
for ii in 1:N
us[ii][:, horizon] = -Ps[ii][:, :, horizon] * xs[:, horizon]
end

return xs, us
end

# As above, but replacing feedback matrices `P` with raw control inputs `u`.
export unroll_raw_controls
function unroll_raw_controls(dyn::Dynamics, us, x₁)
@assert length(x₁) == xdim(dyn)

N = length(us)
@assert N == dyn.sys_info.num_agents

horizon = last(size(first(us)))

# Populate state trajectory.
xs = zeros(xdim(dyn), horizon)
xs[:, 1] = x₁
us = [zeros(udim(dyn, ii), horizon) for ii in 1:N]
for tt in 2:horizon
us_prev = [us[i][:, tt-1] for i in 1:N]
xs[:, tt] = propagate_dynamics(dyn, tt, xs[:, tt-1], us_prev)
end

return xs
end
5 changes: 3 additions & 2 deletions src/LQNashFeedbackSolver.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# A helper function to compute P for all players at time t.
function compute_P_at_t(dyn_at_t::Dynamics, costs_at_t, Zₜ₊₁)
# TODO: Add the QuadraticCost modifier to this costs_at_t argument.
function compute_P_at_t(dyn_at_t::LinearDynamics, costs_at_t, Zₜ₊₁)

num_players = size(costs_at_t)[1]
num_states = xdim(dyn_at_t)
Expand Down Expand Up @@ -35,7 +36,7 @@ end
# Returns feedback matrices P[player][:, :, time]
export solve_lq_nash_feedback
function solve_lq_nash_feedback(
dyn::Dynamics, costs::AbstractArray{Cost}, horizon::Int)
dyn::LinearDynamics, costs::AbstractArray{QuadraticCost}, horizon::Int)

num_players = size(costs)[1]

Expand Down
8 changes: 3 additions & 5 deletions src/LQRFeedbackSolver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,15 @@ using LinearAlgebra

# Shorthand function for LTI dynamics and costs.
export solve_lqr_feedback
function solve_lqr_feedback(
dyn::Dynamics, costs::Cost, horizon::Int)
function solve_lqr_feedback(dyn::LinearDynamics, costs::QuadraticCost, horizon::Int)
dyns = [dyn for _ in 1:horizon]
costs = [costs for _ in 1:horizon]
return solve_lqr_feedback(dyns, costs, horizon)
end

# TODO(hamzah): Add interfaces for cases in which one of the arguments is passed in as a list, but the other is not.

export solve_lqr_feedback
function solve_lqr_feedback(dyns::AbstractArray{Dynamics}, costs::AbstractArray{Cost}, horizon::Int)
function solve_lqr_feedback(dyns::AbstractArray{LinearDynamics}, costs::AbstractArray{QuadraticCost}, horizon::Int)

# Ensure the number of dynamics and costs are the same as the horizon.
@assert(ndims(dyns) == 1 && size(dyns, 1) == horizon)
Expand Down Expand Up @@ -52,4 +50,4 @@ function solve_lqr_feedback(dyns::AbstractArray{Dynamics}, costs::AbstractArray{
end

return Ps
end
end
2 changes: 1 addition & 1 deletion src/LQStackelbergFeedbackSolver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ end
# Returns feedback matrices P[player][:, :, time]
export solve_lq_stackelberg_feedback
function solve_lq_stackelberg_feedback(
dyn::Dynamics, costs::AbstractArray{Cost}, horizon::Int, leader_idx::Int)
dyn::LinearDynamics, costs::AbstractArray{QuadraticCost}, horizon::Int, leader_idx::Int)

# TODO: Add checks for correct input lengths - they should match the horizon.
num_players = size(costs)[1]
Expand Down
2 changes: 2 additions & 0 deletions src/StackelbergControlHypothesesFiltering.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module StackelbergControlHypothesesFiltering

include("Utils.jl")
include("CostUtils.jl")
include("DynamicsUtils.jl")
include("LQNashFeedbackSolver.jl")
include("LQStackelbergFeedbackSolver.jl")
include("LQRFeedbackSolver.jl")
Expand Down
116 changes: 8 additions & 108 deletions src/Utils.jl
Original file line number Diff line number Diff line change
@@ -1,110 +1,10 @@
# Utilities for Assignment 3. You should not need to modify this file.

# Cost for a single player.
# Form is: x^T_t Q^i x + \sum_j u^{jT}_t R^{ij} u^j_t.
# For simplicity, assuming that Q, R are time-invariant, and that dynamics are
# linear time-invariant, i.e. x_{t+1} = A x_t + \sum_i B^i u^i_t.
mutable struct Cost
Q
Rs # Nonzero R^{ij} for Pi.
end

Cost(Q) = Cost(Q, Dict{Int, Matrix{eltype(Q)}}())
export Cost

# Method to add R^{ij}s to a Cost struct.
export add_control_cost!
function add_control_cost!(c::Cost, other_player_idx, Rij)
c.Rs[other_player_idx] = Rij
end

# Evaluate cost on a state/control trajectory.
# - xs[:, time]
# - us[player][:, time]
export evaluate
function evaluate(c::Cost, xs, us)
horizon = last(size(xs))

total = 0.0
for tt in 1:horizon
total += xs[:, tt]' * c.Q * xs[:, tt]
total += sum(us[jj][:, tt]' * Rij * us[jj][:, tt] for (jj, Rij) in c.Rs)
end

return total
# Utilities
struct SystemInfo
num_agents::Int
num_x::Int
num_us::AbstractArray{Int}
num_v::Int
end
SystemInfo(num_agents, num_x, num_us) = SystemInfo(num_agents, num_x, num_us, 0)

# Dynamics.
export Dynamics
struct Dynamics
A
Bs
end

export xdim
function xdim(dyn::Dynamics)
return first(size(dyn.A))
end

export udim
function udim(dyn::Dynamics)
return sum(last(size(B)) for B in dyn.Bs)
end

function udim(dyn::Dynamics, player_idx)
return last(size(dyn.Bs[player_idx]))
end

# Function to unroll a set of feedback matrices from an initial condition.
# Output is a sequence of states xs[:, time] and controls us[player][:, time].
export unroll_feedback
function unroll_feedback(dyn::Dynamics, Ps, x₁)
@assert length(x₁) == xdim(dyn)

N = length(Ps)
@assert N == length(dyn.Bs)

horizon = last(size(first(Ps)))

# Populate state/control trajectory.
xs = zeros(xdim(dyn), horizon)
xs[:, 1] = x₁
us = [zeros(udim(dyn, ii), horizon) for ii in 1:N]
for tt in 2:horizon
for ii in 1:N
us[ii][:, tt - 1] = -Ps[ii][:, :, tt - 1] * xs[:, tt - 1]
end

xs[:, tt] = dyn.A * xs[:, tt - 1] + sum(
dyn.Bs[ii] * us[ii][:, tt - 1] for ii in 1:N)
end

# Controls at final time.
for ii in 1:N
us[ii][:, horizon] = -Ps[ii][:, :, horizon] * xs[:, horizon]
end

return xs, us
end

# As above, but replacing feedback matrices `P` with raw control inputs `u`.
export unroll_raw_controls
function unroll_raw_controls(dyn::Dynamics, us, x₁)
@assert length(x₁) == xdim(dyn)

N = length(us)
@assert N == length(dyn.Bs)

horizon = last(size(first(us)))

# Populate state trajectory.
xs = zeros(xdim(dyn), horizon)
xs[:, 1] = x₁
us = [zeros(udim(dyn, ii), horizon) for ii in 1:N]
for tt in 2:horizon
xs[:, tt] = dyn.A * xs[:, tt - 1] + sum(
dyn.Bs[ii] * us[ii][:, tt - 1] for ii in 1:N)
end

return xs
end
export SystemInfo
Loading

0 comments on commit b2aab6f

Please sign in to comment.