From cae15315570106f09fd9ac68638bae961d1b539f Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 29 Apr 2021 11:26:11 -0400 Subject: [PATCH] move branches into InplaceableThunk --- src/rulesets/LinearAlgebra/norm.jl | 58 ++++++++++++++++++------------ 1 file changed, 36 insertions(+), 22 deletions(-) diff --git a/src/rulesets/LinearAlgebra/norm.jl b/src/rulesets/LinearAlgebra/norm.jl index abd422387..f511aebe8 100644 --- a/src/rulesets/LinearAlgebra/norm.jl +++ b/src/rulesets/LinearAlgebra/norm.jl @@ -20,28 +20,36 @@ end function rrule(::typeof(norm), x::AbstractArray{<:Number}, p::Real) y = LinearAlgebra.norm(x, p) function norm_pullback_p(Δy) - ∂x = if isempty(x) || p == 0 - InplaceableThunk( - @thunk(zero.(x) .* (zero(y) * zero(real(Δy)))), - identity, - ) + ∂x = InplaceableThunk( + # out-of-place versions + if isempty(x) || p == 0 + @thunk(zero.(x) .* (zero(y) * zero(real(Δy)))) elseif p == 2 - InplaceableThunk( - @thunk(_norm2_back(x, y, Δy)), - dx -> _norm2_back!(dx, x, y, Δy), - ) + @thunk(_norm2_back(x, y, Δy)) elseif p == 1 - InplaceableThunk( - @thunk(_norm1_back(x, y, Δy)), - dx -> _norm1_back!(dx, x, y, Δy), - ) + @thunk(_norm1_back(x, y, Δy)) elseif p == Inf - _normInf_back(x, y, Δy) + @thunk(_normInf_back(x, y, Δy)) elseif p == -Inf - _normInf_back(x, y, Δy) + @thunk(_normInf_back(x, y, Δy)) else - _normp_back_x(x, p, y, Δy) + @thunk(_normp_back_x(x, p, y, Δy)) + end, + # in-place versions + if isempty(x) || p == 0 + identity + elseif p == 2 + dx -> _norm2_back!(dx, x, y, Δy) + elseif p == 1 + dx -> _norm1_back!(dx, x, y, Δy) + elseif p == Inf + dx -> dx .+= _normInf_back(x, y, Δy) # not really in-place! could perhaps be improved + elseif p == -Inf + dx -> dx .+= _normInf_back(x, y, Δy) + else + dx -> dx .+= _normp_back_x(x, p, y, Δy) end + ) ∂p = @thunk _normp_back_p(x, p, y, Δy) return (NO_FIELDS, ∂x, ∂p) end @@ -51,12 +59,18 @@ end function rrule(::typeof(norm), x::AbstractArray{<:Number}) y = LinearAlgebra.norm(x) function norm_pullback_2(Δy) - ∂x = if isempty(x) - zero.(x) .* (zero(y) * zero(real(Δy))) - else - InplaceableThunk( - @thunk(_norm2_back(x, y, Δy)), - dx -> _norm2_back!(dx, x, y, Δy), + ∂x = InplaceableThunk( + if isempty(x) + @thunk(zero.(x) .* (zero(y) * zero(real(Δy)))) + else + @thunk(_norm2_back(x, y, Δy)) + end + , + if isempty(x) + identity + else + dx -> _norm2_back!(dx, x, y, Δy) + end ) end return (NO_FIELDS, ∂x)