-
-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make frontend type stable #241
Changes from all commits
fc9fefc
a8679a4
a1fbb71
ded0128
20cd125
85d749c
3ff7c5e
d40d099
c1b0b1a
d8e8f53
c38b1f8
ce8be83
eaaa7d2
3a2edd9
e314c38
3311201
20b1f70
7a9f3a9
c39fba5
519b8c8
5d1bebd
95e27e2
9a39c80
74e4def
ded5ee0
e6e9a5b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,8 +18,12 @@ function Integrals.__solvebp_call(prob::IntegralProblem, alg::AbstractCubaAlgori | |
throw(ArgumentError("Cuba.jl only supports real-valued integrands")) | ||
# we could support other types by multiplying by the jacobian determinant at the end | ||
|
||
if prob.f isa BatchIntegralFunction | ||
nvec = min(maxiters, prob.f.max_batch) | ||
f = prob.f | ||
prototype = Integrals.get_prototype(prob) | ||
if f isa BatchIntegralFunction | ||
fsize = size(prototype)[begin:(end - 1)] | ||
ncomp = prod(fsize) | ||
nvec = min(maxiters, f.max_batch) | ||
# nvec == 1 in Cuba will change vectors to matrices, so we won't support it when | ||
# batching | ||
nvec > 1 || | ||
|
@@ -33,24 +37,21 @@ function Integrals.__solvebp_call(prob::IntegralProblem, alg::AbstractCubaAlgori | |
scale = x -> scale_x!(view(_x, :, 1:size(x, 2)), ub, lb, x) | ||
end | ||
|
||
if isinplace(prob) | ||
fsize = size(prob.f.integrand_prototype)[begin:(end - 1)] | ||
y = similar(prob.f.integrand_prototype, fsize..., nvec) | ||
ax = map(_ -> (:), fsize) | ||
f = function (x, dx) | ||
dy = @view(y[ax..., begin:(begin + size(dx, 2) - 1)]) | ||
prob.f(dy, scale(x), p) | ||
dx .= reshape(dy, :, size(dx, 2)) .* vol | ||
if isinplace(f) | ||
ax = ntuple(_ -> (:), length(fsize)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this type stable? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, it will be type stable because of constant propagation when |
||
_f = let y_ = similar(prototype, fsize..., nvec) | ||
function (u, _y) | ||
y = @view(y_[ax..., begin:(begin + size(_y, 2) - 1)]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why a view? I don't see why this isn't already the right size and just needs the reshape. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here I think I take a view from a buffer of size |
||
f(y, scale(u), p) | ||
_y .= reshape(y, size(_y)) .* vol | ||
end | ||
end | ||
else | ||
y = mid isa Number ? prob.f(typeof(mid)[], p) : | ||
prob.f(Matrix{typeof(mid)}(undef, length(mid), 0), p) | ||
fsize = size(y)[begin:(end - 1)] | ||
f = (x, dx) -> dx .= reshape(prob.f(scale(x), p), :, size(dx, 2)) .* vol | ||
_f = (u, y) -> y .= reshape(f(scale(u), p), size(y)) .* vol | ||
end | ||
ncomp = prod(fsize) | ||
else | ||
nvec = 1 | ||
ncomp = length(prototype) | ||
|
||
if mid isa Real | ||
scale = x -> scale_x(ub, lb, only(x)) | ||
|
@@ -59,58 +60,60 @@ function Integrals.__solvebp_call(prob::IntegralProblem, alg::AbstractCubaAlgori | |
scale = x -> scale_x!(_x, ub, lb, x) | ||
end | ||
|
||
if isinplace(prob) | ||
y = similar(prob.f.integrand_prototype) | ||
f = (x, dx) -> dx .= vec(prob.f(y, scale(x), p)) .* vol | ||
if isinplace(f) | ||
_f = let y = similar(prototype) | ||
(u, _y) -> begin | ||
f(y, scale(u), p) | ||
_y .= vec(y) .* vol | ||
end | ||
end | ||
else | ||
y = prob.f(mid, p) | ||
f = (x, dx) -> dx .= Iterators.flatten(prob.f(scale(x), p)) .* vol | ||
_f = (u, y) -> y .= Iterators.flatten(f(scale(u), p)) .* vol | ||
end | ||
ncomp = length(y) | ||
end | ||
|
||
if alg isa CubaVegas | ||
out = Cuba.vegas(f, ndim, ncomp; rtol = reltol, | ||
out = if alg isa CubaVegas | ||
Cuba.vegas(_f, ndim, ncomp; rtol = reltol, | ||
atol = abstol, nvec = nvec, | ||
maxevals = maxiters, | ||
flags = alg.flags, seed = alg.seed, minevals = alg.minevals, | ||
nstart = alg.nstart, nincrease = alg.nincrease, | ||
gridno = alg.gridno) | ||
elseif alg isa CubaSUAVE | ||
out = Cuba.suave(f, ndim, ncomp; rtol = reltol, | ||
Cuba.suave(_f, ndim, ncomp; rtol = reltol, | ||
atol = abstol, nvec = nvec, | ||
maxevals = maxiters, | ||
flags = alg.flags, seed = alg.seed, minevals = alg.minevals, | ||
nnew = alg.nnew, nmin = alg.nmin, flatness = alg.flatness) | ||
elseif alg isa CubaDivonne | ||
out = Cuba.divonne(f, ndim, ncomp; rtol = reltol, | ||
Cuba.divonne(_f, ndim, ncomp; rtol = reltol, | ||
atol = abstol, nvec = nvec, | ||
maxevals = maxiters, | ||
flags = alg.flags, seed = alg.seed, minevals = alg.minevals, | ||
key1 = alg.key1, key2 = alg.key2, key3 = alg.key3, | ||
maxpass = alg.maxpass, border = alg.border, | ||
maxchisq = alg.maxchisq, mindeviation = alg.mindeviation) | ||
elseif alg isa CubaCuhre | ||
out = Cuba.cuhre(f, ndim, ncomp; rtol = reltol, | ||
Cuba.cuhre(_f, ndim, ncomp; rtol = reltol, | ||
atol = abstol, nvec = nvec, | ||
maxevals = maxiters, | ||
flags = alg.flags, minevals = alg.minevals, key = alg.key) | ||
end | ||
|
||
# out.integral is a Vector{Float64}, but we want to return it to the shape of the integrand | ||
if prob.f isa BatchIntegralFunction | ||
if y isa AbstractVector | ||
val = out.integral[1] | ||
val = if f isa BatchIntegralFunction | ||
if prototype isa AbstractVector | ||
out.integral[1] | ||
else | ||
val = reshape(out.integral, fsize) | ||
reshape(out.integral, fsize) | ||
end | ||
else | ||
if y isa Real | ||
val = out.integral[1] | ||
elseif y isa AbstractVector | ||
val = out.integral | ||
if prototype isa Real | ||
out.integral[1] | ||
elseif prototype isa AbstractVector | ||
out.integral | ||
else | ||
val = reshape(out.integral, size(y)) | ||
reshape(out.integral, size(prototype)) | ||
end | ||
end | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,143 +13,141 @@ function Integrals.__solvebp_call(prob::IntegralProblem, | |
mid = (lb + ub) / 2 | ||
|
||
# we get to pick fdim or not based on the IntegralFunction and its output dimensions | ||
y = if prob.f isa BatchIntegralFunction | ||
isinplace(prob.f) ? prob.f.integrand_prototype : | ||
mid isa Number ? prob.f(eltype(mid)[], p) : | ||
prob.f(Matrix{eltype(mid)}(undef, length(mid), 0), p) | ||
else | ||
# we evaluate the oop function to decide whether the output should be vectorized | ||
isinplace(prob.f) ? prob.f.integrand_prototype : prob.f(mid, p) | ||
end | ||
f = prob.f | ||
prototype = Integrals.get_prototype(prob) | ||
|
||
@assert eltype(y)<:Real "Cubature.jl is only compatible with real-valued integrands" | ||
@assert eltype(prototype)<:Real "Cubature.jl is only compatible with real-valued integrands" | ||
|
||
if prob.f isa BatchIntegralFunction | ||
if y isa AbstractVector # this branch could be omitted since the following one should work similarly | ||
if isinplace(prob) | ||
if f isa BatchIntegralFunction | ||
if prototype isa AbstractVector # this branch could be omitted since the following one should work similarly | ||
if isinplace(f) | ||
# dx is a Vector, but we provide the integrand a vector of the same type as | ||
# y, which needs to be resized since the number of batch points changes. | ||
dy = similar(y) | ||
f = (x, dx) -> begin | ||
resize!(dy, length(dx)) | ||
prob.f(dy, x, p) | ||
dx .= dy | ||
_f = let y = similar(prototype) | ||
(u, v) -> begin | ||
resize!(y, length(v)) | ||
f(y, u, p) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why don't these There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think |
||
v .= y | ||
end | ||
end | ||
else | ||
f = (x, dx) -> (dx .= prob.f(x, p)) | ||
_f = (u, v) -> (v .= f(u, p)) | ||
end | ||
if mid isa Number | ||
if alg isa CubatureJLh | ||
val, err = Cubature.hquadrature_v(f, lb, ub; | ||
val, err = Cubature.hquadrature_v(_f, lb, ub; | ||
reltol = reltol, abstol = abstol, | ||
maxevals = maxiters) | ||
else | ||
val, err = Cubature.pquadrature_v(f, lb, ub; | ||
val, err = Cubature.pquadrature_v(_f, lb, ub; | ||
reltol = reltol, abstol = abstol, | ||
maxevals = maxiters) | ||
end | ||
else | ||
if alg isa CubatureJLh | ||
val, err = Cubature.hcubature_v(f, lb, ub; | ||
val, err = Cubature.hcubature_v(_f, lb, ub; | ||
reltol = reltol, abstol = abstol, | ||
maxevals = maxiters) | ||
else | ||
val, err = Cubature.pcubature_v(f, lb, ub; | ||
val, err = Cubature.pcubature_v(_f, lb, ub; | ||
reltol = reltol, abstol = abstol, | ||
maxevals = maxiters) | ||
end | ||
end | ||
elseif y isa AbstractArray | ||
bfsize = size(y)[begin:(end - 1)] | ||
bfdim = prod(bfsize) | ||
if isinplace(prob) | ||
elseif prototype isa AbstractArray | ||
fsize = size(prototype)[begin:(end - 1)] | ||
fdim = prod(fsize) | ||
if isinplace(f) | ||
# dx is a Matrix, but to provide a buffer of the same type as y, we make | ||
# would like to make views of a larger buffer, but CubatureJL doesn't set | ||
# a hard limit for max_batch, so we allocate a new buffer with the needed size | ||
f = (x, dx) -> begin | ||
dy = similar(y, bfsize..., size(dx, 2)) | ||
prob.f(dy, x, p) | ||
dx .= reshape(dy, bfdim, size(dx, 2)) | ||
_f = let fsize = fsize | ||
(u, v) -> begin | ||
y = similar(prototype, fsize..., size(v, 2)) | ||
f(y, u, p) | ||
v .= reshape(y, fdim, size(v, 2)) | ||
end | ||
end | ||
else | ||
f = (x, dx) -> (dx .= reshape(prob.f(x, p), bfdim, size(dx, 2))) | ||
_f = (u, v) -> (v .= reshape(f(u, p), fdim, size(v, 2))) | ||
end | ||
if mid isa Number | ||
if alg isa CubatureJLh | ||
val_, err = Cubature.hquadrature_v(bfdim, f, lb, ub; | ||
val_, err = Cubature.hquadrature_v(fdim, _f, lb, ub; | ||
reltol = reltol, abstol = abstol, | ||
maxevals = maxiters, error_norm = alg.error_norm) | ||
else | ||
val_, err = Cubature.pquadrature_v(bfdim, f, lb, ub; | ||
val_, err = Cubature.pquadrature_v(fdim, _f, lb, ub; | ||
reltol = reltol, abstol = abstol, | ||
maxevals = maxiters, error_norm = alg.error_norm) | ||
end | ||
else | ||
if alg isa CubatureJLh | ||
val_, err = Cubature.hcubature_v(bfdim, f, lb, ub; | ||
val_, err = Cubature.hcubature_v(fdim, _f, lb, ub; | ||
reltol = reltol, abstol = abstol, | ||
maxevals = maxiters, error_norm = alg.error_norm) | ||
else | ||
val_, err = Cubature.pcubature_v(bfdim, f, lb, ub; | ||
val_, err = Cubature.pcubature_v(fdim, _f, lb, ub; | ||
reltol = reltol, abstol = abstol, | ||
maxevals = maxiters, error_norm = alg.error_norm) | ||
end | ||
end | ||
val = reshape(val_, bfsize...) | ||
val = reshape(val_, fsize...) | ||
else | ||
error("BatchIntegralFunction integrands must be arrays for Cubature.jl") | ||
end | ||
else | ||
if y isa Real | ||
if prototype isa Real | ||
# no inplace in this case, since the integrand_prototype would be mutable | ||
f = x -> prob.f(x, p) | ||
_f = u -> f(u, p) | ||
if lb isa Number | ||
if alg isa CubatureJLh | ||
val, err = Cubature.hquadrature(f, lb, ub; | ||
val, err = Cubature.hquadrature(_f, lb, ub; | ||
reltol = reltol, abstol = abstol, | ||
maxevals = maxiters) | ||
else | ||
val, err = Cubature.pquadrature(f, lb, ub; | ||
val, err = Cubature.pquadrature(_f, lb, ub; | ||
reltol = reltol, abstol = abstol, | ||
maxevals = maxiters) | ||
end | ||
else | ||
if alg isa CubatureJLh | ||
val, err = Cubature.hcubature(f, lb, ub; | ||
val, err = Cubature.hcubature(_f, lb, ub; | ||
reltol = reltol, abstol = abstol, | ||
maxevals = maxiters) | ||
else | ||
val, err = Cubature.pcubature(f, lb, ub; | ||
val, err = Cubature.pcubature(_f, lb, ub; | ||
reltol = reltol, abstol = abstol, | ||
maxevals = maxiters) | ||
end | ||
end | ||
elseif y isa AbstractArray | ||
fsize = size(y) | ||
fdim = length(y) | ||
elseif prototype isa AbstractArray | ||
fsize = size(prototype) | ||
fdim = length(prototype) | ||
if isinplace(prob) | ||
dy = similar(y) | ||
f = (x, v) -> (prob.f(dy, x, p); v .= vec(dy)) | ||
_f = let y = similar(prototype) | ||
(u, v) -> (f(y, u, p); v .= vec(y)) | ||
end | ||
else | ||
f = (x, v) -> (v .= vec(prob.f(x, p))) | ||
_f = (u, v) -> (v .= vec(f(u, p))) | ||
end | ||
if mid isa Number | ||
if alg isa CubatureJLh | ||
val_, err = Cubature.hquadrature(fdim, f, lb, ub; | ||
val_, err = Cubature.hquadrature(fdim, _f, lb, ub; | ||
reltol = reltol, abstol = abstol, | ||
maxevals = maxiters, error_norm = alg.error_norm) | ||
else | ||
val_, err = Cubature.pquadrature(fdim, f, lb, ub; | ||
val_, err = Cubature.pquadrature(fdim, _f, lb, ub; | ||
reltol = reltol, abstol = abstol, | ||
maxevals = maxiters, error_norm = alg.error_norm) | ||
end | ||
else | ||
if alg isa CubatureJLh | ||
val_, err = Cubature.hcubature(fdim, f, lb, ub; | ||
val_, err = Cubature.hcubature(fdim, _f, lb, ub; | ||
reltol = reltol, abstol = abstol, | ||
maxevals = maxiters, error_norm = alg.error_norm) | ||
else | ||
val_, err = Cubature.pcubature(fdim, f, lb, ub; | ||
val_, err = Cubature.pcubature(fdim, _f, lb, ub; | ||
reltol = reltol, abstol = abstol, | ||
maxevals = maxiters, error_norm = alg.error_norm) | ||
end | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a standard across all other types because of the interaction with ensembles. Is there a reason to drop it here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems this is covered i nhttps://github.com//pull/244 ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since #135 I think we currently require all kwargs to be one of abstol, reltol, or maxiters. These include the problem kwargs in the other pr. I didn't make this decision, but there is a check for this and it looks like the mechanism for algorithm-specific kwargs is for there to be a place for them in the algorithm constructor. The reason for changing the documentation is then to be consistent with the implementation, since there have been various recent issues related to misleading docs.
When you say the other SciML packages forward extra kwargs to the solvers, do you mean that, for example in
QuadGKJL
the library call looks likequadgk(args... ; current_kwargs..., extra_kwargs...)
? The only disadvantage of this I can think of is that the user can't immediately swap between algorithms when using specialized kwargs. On the other hand, with what we currently have there is an extra maintenance burden of trying to expose all APIs of the solver in the algorithm struct for which we typically have no test coverage. Take for example thebuffer
parameter I added toQuadGKJL
andHCubatureJL
in this pr to cache the heap used internally by each algorithm (thesegbuf/buffer
keyword in their APIs).As long as we are OK reverting #135, I'd be happy to pass the extra kwargs onto the solvers and do it here so that one way or another the documentation is consistent with the implementation.