From 70449302e590eea2d057454edd19ce64370990f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Arnstr=C3=B6m?= <55484604+darnstrom@users.noreply.github.com> Date: Sun, 17 Nov 2024 10:32:41 +0100 Subject: [PATCH] Clearer naming in generate code. Keep track of depth in BST (#8) --- src/codegen.jl | 53 +++++++++++++++++++++++++++++++++++++------------- src/io.jl | 13 +++++-------- src/tree.jl | 11 ++++++++--- src/utils.jl | 14 +++++++++++-- 4 files changed, 65 insertions(+), 26 deletions(-) diff --git a/src/codegen.jl b/src/codegen.jl index d24b5d3..abe7d3f 100644 --- a/src/codegen.jl +++ b/src/codegen.jl @@ -15,7 +15,7 @@ end function codegen(bst::BinarySearchTree; dir="codegen",fname="pdaqp", float_type="float", int_type="unsigned short") isdir(dir) || mkdir(dir) # Get number of outputs - nth,n_out = size(bst.feedbacks[1]).-(1,0) + nth,nz = size(bst.feedbacks[1]).-(1,0) # Concatenate feedbacks into one array feedbacks = reduce(hcat,bst.feedbacks) @@ -27,9 +27,9 @@ function codegen(bst::BinarySearchTree; dir="codegen",fname="pdaqp", float_type= write(fh, "typedef $float_type c_float;\n") write(fh, "typedef $int_type c_int;\n") - write(fh, "#define $(uppercase(fname))_N_PARAM $nth\n") - write(fh, "#define $(uppercase(fname))_N_OUT $n_out\n\n") - write(fh, "void $(fname)_evaluate(c_float* param, c_float* out);\n") + write(fh, "#define $(uppercase(fname))_N_PARAMETER $nth\n") + write(fh, "#define $(uppercase(fname))_N_SOLUTION $nz\n\n") + write(fh, "void $(fname)_evaluate(c_float* parameter, c_float* solution);\n") write(fh, "#endif // ifndef $hguard\n"); close(fh) @@ -43,7 +43,7 @@ function codegen(bst::BinarySearchTree; dir="codegen",fname="pdaqp", float_type= write_array(fsrc,bst.jump_list.-1,fname*"_jump_list","c_int") write(fsrc, """ -void $(fname)_evaluate(c_float* param, c_float* out){ +void $(fname)_evaluate(c_float* parameter, c_float* solution){ int i,j,disp; int id,next_id; c_float val; @@ -51,9 +51,9 @@ void $(fname)_evaluate(c_float* param, c_float* out){ next_id = $(fname)_jump_list[id]; while(next_id != 0){ // Compute halfplane value - disp = $(fname)_hp_list[id]*($(uppercase(fname))_N_PARAM+1); - for(i=0, val=0; i<$(uppercase(fname))_N_PARAM; i++) - val += param[i] * $(fname)_halfplanes[disp++]; + disp = $(fname)_hp_list[id]*($(uppercase(fname))_N_PARAMETER+1); + for(i=0, val=0; i<$(uppercase(fname))_N_PARAMETER; i++) + val += parameter[i] * $(fname)_halfplanes[disp++]; if(val <= $(fname)_halfplanes[disp])// positive branch id = next_id+1; else // negative branch @@ -61,14 +61,41 @@ void $(fname)_evaluate(c_float* param, c_float* out){ next_id = $(fname)_jump_list[id]; } // Leaf node reached -> evaluate affine function - disp = $(fname)_hp_list[id]*($(uppercase(fname))_N_PARAM+1)*$(uppercase(fname))_N_OUT; - for(i=0; i < $(uppercase(fname))_N_OUT; i++){ - for(j=0, val=0; j < $(uppercase(fname))_N_PARAM; j++) - val += param[j] * $(fname)_feedbacks[disp++]; + disp = $(fname)_hp_list[id]*($(uppercase(fname))_N_PARAMETER+1)*$(uppercase(fname))_N_SOLUTION; + for(i=0; i < $(uppercase(fname))_N_SOLUTION; i++){ + for(j=0, val=0; j < $(uppercase(fname))_N_PARAMETER; j++) + val += parameter[j] * $(fname)_feedbacks[disp++]; val += $(fname)_feedbacks[disp++]; - out[i] = val; + solution[i] = val; } } """) close(fsrc) + + # Write simple example + fex = open(joinpath(dir,"example.c"), "w") + write(fex, """ +#include "$(fname).h" +#include + +int main(){ + c_float solution[$nz]; + c_float parameter[$nth]; + int i; + // Initialize parameter + for(i=0; i< $nth; i++) + parameter[i] = 0; + + // Get the solution at the parameter + $(fname)_evaluate(parameter,solution); + + printf("For the parameter\\n"); + for(i=0; i< $nth; i++) + printf("%f\\n",parameter[i]); + printf("the solution is\\n"); + for(i=0; i< $nz; i++) + printf("%f\\n",solution[i]); +} + """) + close(fex) end diff --git a/src/io.jl b/src/io.jl index 2bdc211..fe92b2d 100644 --- a/src/io.jl +++ b/src/io.jl @@ -1,14 +1,11 @@ ## Printing function print_ws(ws,j) print("\r>> #$j\ - |Down: $(length(ws.Sdown))\ - |Up: $(length(ws.Sup))\ - |Fin: $(length(ws.F))| "); + |Pending : $(length(ws.Sdown)+length(ws.Sup))\ + |Finished: $(length(ws.F))| "); end function print_final(ws) - print("\r======= |\ - | Fin: $(length(ws.F)) \ - | # LPs: $(ws.nLPs) \ - | explored: $(length(ws.explored)) \ - ||======= \n"); + printstyled("\r======= \ + Solution with $(length(ws.F)) critical regions \ + ======= \n",color=:light_magenta); end diff --git a/src/tree.jl b/src/tree.jl index 1888c45..7b66548 100644 --- a/src/tree.jl +++ b/src/tree.jl @@ -3,16 +3,18 @@ struct BinarySearchTree feedbacks::Vector{Matrix{Float64}} hp_list::Vector{Int} jump_list::Vector{Int} + depth::Int end function isnonempty(A,b) d = DAQP.Model(); - DAQP.settings(d,Dict(:fval_bound=>size(A,1)-1)) # Cannot be outside box + DAQP.settings(d,Dict(:fval_bound=>size(A,1)-1,:zero_tol=>1e-7)) # Cannot be outside box DAQP.setup(d,zeros(0,0),zeros(0),A,b,A_rowmaj=true); x,fval,exitflag,info = DAQP.solve(d); # TODO add guard to do this check return exitflag == 1 end + function get_halfplanes(CRs) nreg = length(CRs) nreg == 0 && return nothing @@ -122,8 +124,11 @@ function build_tree(sol::Solution) U = [N0] get_fbid = s->Set{Int}(fb_ids[collect(s)]) split_objective = x-> max(length.(get_fbid.(x))...) + + depth = 0 while !isempty(U) reg_ids, branches, self_id = pop!(U) + depth = max(depth,length(branches)) hp_ids = reduce(∪,Set(first.(reg2hp[i])) for i in reg_ids); hp_ids = collect(setdiff!(hp_ids,first(b) for b in branches)) @@ -135,7 +140,7 @@ function build_tree(sol::Solution) min_ids = findall(==(min_val),vals) hp_ids = hp_ids[min_ids] if length(branches) > 0 && min_val > 1# Compute the actual split - splits = tuple.(classify_regions(sol.CRs,hps,reg2hp;reg_ids = reg_ids, hp_ids = hp_ids, branches=branches)...) + splits = tuple.(classify_regions(sol.CRs,hps,reg2hp;reg_ids,hp_ids,branches)...) vals =[split_objective(s) for s in splits] min_val,min_id = findmin(vals) hp_id = hp_ids[min_id] # TODO add tie-breaker... @@ -188,7 +193,7 @@ function build_tree(sol::Solution) # Denormalize hps = denormalize(hps,sol.scaling,sol.translation;hps=true) fbs = [denormalize(f,sol.scaling,sol.translation) for f in fbs] - return BinarySearchTree(hps,fbs,hp_list,jump_list) + return BinarySearchTree(hps,fbs,hp_list,jump_list,depth) end function evaluate(bst::BinarySearchTree,θ) diff --git a/src/utils.jl b/src/utils.jl index b3aa13e..584fda4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -228,9 +228,19 @@ end ## Compute AS0 function compute_AS0(mpLDP,Θ) # Center in box is zero -> dtot = d[end,:] - _,_,_,info= DAQP.quadprog(zeros(0,0),zeros(0),mpLDP.M,mpLDP.d[end,:]); + _,_,exitflag,info= DAQP.quadprog(zeros(0,0),zeros(0),mpLDP.M,mpLDP.d[end,:]); + if exitflag == 1 + return findall(abs.(info.λ).> 0) + end + # Solve lifted feasibility problem in (x,θ)-space to find initial point + x,_,exitflag,info= DAQP.quadprog(zeros(0,0),zeros(0),[-mpLDP.d[1:end-1,:]' mpLDP.M],mpLDP.d[end,:]); + if exitflag != 1 + @warn "There is no parameter that makes the problem feasible" + return nothing + end + θ = x[1:mpLDP.n_theta] + _,_,exitflag,info= DAQP.quadprog(zeros(0,0),zeros(0),mpLDP.M,mpLDP.d'*[θ;1]); return findall(abs.(info.λ).> 0) - # TODO add backup if this fails end ## Get CRs function get_critical_regions(sol::Solution)