Skip to content

Commit

Permalink
Fix bug in broadcasting SparseAxisArray (#2929)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Mar 23, 2022
1 parent 7691ee2 commit 7fdd4d4
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 4 deletions.
19 changes: 15 additions & 4 deletions src/Containers/SparseAxisArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,7 @@ end
function Base.copy(
bc::Base.Broadcast.Broadcasted{BroadcastStyle{N,K}},
) where {N,K}
dict = Dict(
index => bc.f(_get_arg(bc.args, index)...) for
index in _indices(bc.args...)
)
dict = Dict(index => _getindex(bc, index) for index in _indices(bc.args...))
if isempty(dict) && dict isa Dict{Any,Any}
# If `dict` is empty (e.g., because there are no indices), then
# inference will produce a `Dict{Any,Any}`, and we won't have enough
Expand Down Expand Up @@ -230,6 +227,13 @@ function _indices(x::SparseAxisArray, args...)
return indices
end

function _indices(
bc::Base.Broadcast.Broadcasted{BroadcastStyle{N,K}},
args...,
) where {N,K}
return _indices(bc.args...)
end

"""
_get_arg(args::Tuple, index::Tuple)
Expand All @@ -247,6 +251,13 @@ _getindex(x::SparseAxisArray, index) = getindex(x, index...)
_getindex(x::Any, ::Any) = x
_getindex(x::Ref, ::Any) = x[]

function _getindex(
bc::Base.Broadcast.Broadcasted{BroadcastStyle{N,K}},
index,
) where {N,K}
return bc.f(_get_arg(bc.args, index)...)
end

@static if VERSION >= v"1.3"
# `broadcast_preserving_zero_d` calls `axes(A)` which calls `size(A)` which
# is not defined. When at least one argument is a `SparseAxisArray`, we can
Expand Down
21 changes: 21 additions & 0 deletions test/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1524,6 +1524,27 @@ function test_expr_index_vars()
return
end

function test_broadcasted_SparseAxisArray_constraint()
model = Model()
@variable(model, u[i = 1:2, i:3])
Containers.@container(c[i = 1:2, i:3], rand())
cons = @constraint(model, u .<= c)
@test cons isa Containers.SparseAxisArray
@test length(cons) == 5
return
end

function test_broadcasted_DenseAxisArray_constraint()
model = Model()
S = 1:2
@variable(model, u[S])
Containers.@container(c[S], rand())
cons = @constraint(model, u .<= c)
@test cons isa Containers.DenseAxisArray
@test length(cons) == 2
return
end

end # module

TestMacros.runtests()

0 comments on commit 7fdd4d4

Please sign in to comment.