Skip to content

Commit

Permalink
Clearer naming in generate code. Keep track of depth in BST (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
darnstrom authored Nov 17, 2024
1 parent e5239da commit 7044930
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 26 deletions.
53 changes: 40 additions & 13 deletions src/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ end
function codegen(bst::BinarySearchTree; dir="codegen",fname="pdaqp", float_type="float", int_type="unsigned short")
isdir(dir) || mkdir(dir)
# Get number of outputs
nth,n_out = size(bst.feedbacks[1]).-(1,0)
nth,nz = size(bst.feedbacks[1]).-(1,0)
# Concatenate feedbacks into one array
feedbacks = reduce(hcat,bst.feedbacks)

Expand All @@ -27,9 +27,9 @@ function codegen(bst::BinarySearchTree; dir="codegen",fname="pdaqp", float_type=

write(fh, "typedef $float_type c_float;\n")
write(fh, "typedef $int_type c_int;\n")
write(fh, "#define $(uppercase(fname))_N_PARAM $nth\n")
write(fh, "#define $(uppercase(fname))_N_OUT $n_out\n\n")
write(fh, "void $(fname)_evaluate(c_float* param, c_float* out);\n")
write(fh, "#define $(uppercase(fname))_N_PARAMETER $nth\n")
write(fh, "#define $(uppercase(fname))_N_SOLUTION $nz\n\n")
write(fh, "void $(fname)_evaluate(c_float* parameter, c_float* solution);\n")
write(fh, "#endif // ifndef $hguard\n");
close(fh)

Expand All @@ -43,32 +43,59 @@ function codegen(bst::BinarySearchTree; dir="codegen",fname="pdaqp", float_type=
write_array(fsrc,bst.jump_list.-1,fname*"_jump_list","c_int")

write(fsrc, """
void $(fname)_evaluate(c_float* param, c_float* out){
void $(fname)_evaluate(c_float* parameter, c_float* solution){
int i,j,disp;
int id,next_id;
c_float val;
id = 0;
next_id = $(fname)_jump_list[id];
while(next_id != 0){
// Compute halfplane value
disp = $(fname)_hp_list[id]*($(uppercase(fname))_N_PARAM+1);
for(i=0, val=0; i<$(uppercase(fname))_N_PARAM; i++)
val += param[i] * $(fname)_halfplanes[disp++];
disp = $(fname)_hp_list[id]*($(uppercase(fname))_N_PARAMETER+1);
for(i=0, val=0; i<$(uppercase(fname))_N_PARAMETER; i++)
val += parameter[i] * $(fname)_halfplanes[disp++];
if(val <= $(fname)_halfplanes[disp])// positive branch
id = next_id+1;
else // negative branch
id = next_id;
next_id = $(fname)_jump_list[id];
}
// Leaf node reached -> evaluate affine function
disp = $(fname)_hp_list[id]*($(uppercase(fname))_N_PARAM+1)*$(uppercase(fname))_N_OUT;
for(i=0; i < $(uppercase(fname))_N_OUT; i++){
for(j=0, val=0; j < $(uppercase(fname))_N_PARAM; j++)
val += param[j] * $(fname)_feedbacks[disp++];
disp = $(fname)_hp_list[id]*($(uppercase(fname))_N_PARAMETER+1)*$(uppercase(fname))_N_SOLUTION;
for(i=0; i < $(uppercase(fname))_N_SOLUTION; i++){
for(j=0, val=0; j < $(uppercase(fname))_N_PARAMETER; j++)
val += parameter[j] * $(fname)_feedbacks[disp++];
val += $(fname)_feedbacks[disp++];
out[i] = val;
solution[i] = val;
}
}
""")
close(fsrc)

# Write simple example
fex = open(joinpath(dir,"example.c"), "w")
write(fex, """
#include "$(fname).h"
#include <stdio.h>
int main(){
c_float solution[$nz];
c_float parameter[$nth];
int i;
// Initialize parameter
for(i=0; i< $nth; i++)
parameter[i] = 0;
// Get the solution at the parameter
$(fname)_evaluate(parameter,solution);
printf("For the parameter\\n");
for(i=0; i< $nth; i++)
printf("%f\\n",parameter[i]);
printf("the solution is\\n");
for(i=0; i< $nz; i++)
printf("%f\\n",solution[i]);
}
""")
close(fex)
end
13 changes: 5 additions & 8 deletions src/io.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
## Printing
function print_ws(ws,j)
print("\r>> #$j\
|Down: $(length(ws.Sdown))\
|Up: $(length(ws.Sup))\
|Fin: $(length(ws.F))| ");
|Pending : $(length(ws.Sdown)+length(ws.Sup))\
|Finished: $(length(ws.F))| ");
end
function print_final(ws)
print("\r======= |\
| Fin: $(length(ws.F)) \
| # LPs: $(ws.nLPs) \
| explored: $(length(ws.explored)) \
||======= \n");
printstyled("\r======= \
Solution with $(length(ws.F)) critical regions \
======= \n",color=:light_magenta);
end
11 changes: 8 additions & 3 deletions src/tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,18 @@ struct BinarySearchTree
feedbacks::Vector{Matrix{Float64}}
hp_list::Vector{Int}
jump_list::Vector{Int}
depth::Int
end

function isnonempty(A,b)
d = DAQP.Model();
DAQP.settings(d,Dict(:fval_bound=>size(A,1)-1)) # Cannot be outside box
DAQP.settings(d,Dict(:fval_bound=>size(A,1)-1,:zero_tol=>1e-7)) # Cannot be outside box
DAQP.setup(d,zeros(0,0),zeros(0),A,b,A_rowmaj=true);
x,fval,exitflag,info = DAQP.solve(d);
# TODO add guard to do this check
return exitflag == 1
end

function get_halfplanes(CRs)
nreg = length(CRs)
nreg == 0 && return nothing
Expand Down Expand Up @@ -122,8 +124,11 @@ function build_tree(sol::Solution)
U = [N0]
get_fbid = s->Set{Int}(fb_ids[collect(s)])
split_objective = x-> max(length.(get_fbid.(x))...)

depth = 0
while !isempty(U)
reg_ids, branches, self_id = pop!(U)
depth = max(depth,length(branches))

hp_ids = reduce(,Set(first.(reg2hp[i])) for i in reg_ids);
hp_ids = collect(setdiff!(hp_ids,first(b) for b in branches))
Expand All @@ -135,7 +140,7 @@ function build_tree(sol::Solution)
min_ids = findall(==(min_val),vals)
hp_ids = hp_ids[min_ids]
if length(branches) > 0 && min_val > 1# Compute the actual split
splits = tuple.(classify_regions(sol.CRs,hps,reg2hp;reg_ids = reg_ids, hp_ids = hp_ids, branches=branches)...)
splits = tuple.(classify_regions(sol.CRs,hps,reg2hp;reg_ids,hp_ids,branches)...)
vals =[split_objective(s) for s in splits]
min_val,min_id = findmin(vals)
hp_id = hp_ids[min_id] # TODO add tie-breaker...
Expand Down Expand Up @@ -188,7 +193,7 @@ function build_tree(sol::Solution)
# Denormalize
hps = denormalize(hps,sol.scaling,sol.translation;hps=true)
fbs = [denormalize(f,sol.scaling,sol.translation) for f in fbs]
return BinarySearchTree(hps,fbs,hp_list,jump_list)
return BinarySearchTree(hps,fbs,hp_list,jump_list,depth)
end

function evaluate(bst::BinarySearchTree,θ)
Expand Down
14 changes: 12 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,19 @@ end
## Compute AS0
function compute_AS0(mpLDP,Θ)
# Center in box is zero -> dtot = d[end,:]
_,_,_,info= DAQP.quadprog(zeros(0,0),zeros(0),mpLDP.M,mpLDP.d[end,:]);
_,_,exitflag,info= DAQP.quadprog(zeros(0,0),zeros(0),mpLDP.M,mpLDP.d[end,:]);
if exitflag == 1
return findall(abs.(info.λ).> 0)
end
# Solve lifted feasibility problem in (x,θ)-space to find initial point
x,_,exitflag,info= DAQP.quadprog(zeros(0,0),zeros(0),[-mpLDP.d[1:end-1,:]' mpLDP.M],mpLDP.d[end,:]);
if exitflag != 1
@warn "There is no parameter that makes the problem feasible"
return nothing
end
θ = x[1:mpLDP.n_theta]
_,_,exitflag,info= DAQP.quadprog(zeros(0,0),zeros(0),mpLDP.M,mpLDP.d'*[θ;1]);
return findall(abs.(info.λ).> 0)
# TODO add backup if this fails
end
## Get CRs
function get_critical_regions(sol::Solution)
Expand Down

2 comments on commit 7044930

@darnstrom
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/119614

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.0 -m "<description of version>" 70449302e590eea2d057454edd19ce64370990f3
git push origin v0.1.0

Please sign in to comment.