From 7404cd855fdd45cda78c360f3c5a757d33beab68 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Tue, 17 Sep 2024 12:05:35 -0400 Subject: [PATCH] Support `enumerate` in conditional generators/comprehensions --- src/lib/array.jl | 1 + test/lib/array.jl | 13 +++++++++++++ 2 files changed, 14 insertions(+) diff --git a/src/lib/array.jl b/src/lib/array.jl index 6d914d272..a14e8df86 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -260,6 +260,7 @@ end @adjoint function enumerate(xs) back(::AbstractArray{Nothing}) = nothing back(dy::NamedTuple{(:itr,)}) = tuple(dy.itr) + back(diys::AbstractArray{Union{Nothing, T}}) where T = (map(x -> x === nothing ? x : last(x), diys),) back(diys) = (map(last, diys),) enumerate(xs), back end diff --git a/test/lib/array.jl b/test/lib/array.jl index 8016c9541..08d1250af 100644 --- a/test/lib/array.jl +++ b/test/lib/array.jl @@ -118,6 +118,19 @@ end sum(v for (_, v) in d) end @test gradient(f_comprehension, w)[1] == ones(5) + + w = [randn(5); NaN] + function f_generator_conditional(w) + d = Dict{Int, Float64}(i => v for (i,v) in enumerate(w) if !isnan(v)) + sum(v for (_, v) in d) + end + @test gradient(f_generator_conditional, w)[1] == [ones(5); nothing] + + function f_comprehension_conditional(w) + d = Dict{Int, Float64}(i => v for (i,v) in enumerate(w) if !isnan(v)) + sum(v for (_, v) in d) + end + @test gradient(f_comprehension_conditional, w)[1] == [ones(5); nothing] end @testset "_reverse" begin