diff --git a/src/ReverseDiff.jl b/src/ReverseDiff.jl index 2580aae..caff556 100644 --- a/src/ReverseDiff.jl +++ b/src/ReverseDiff.jl @@ -23,7 +23,7 @@ using ChainRulesCore # Not all operations will be valid over all of these types, but that's okay; such cases # will simply error when they hit the original operation in the overloaded definition. -const ARRAY_TYPES = (:AbstractArray, :AbstractVector, :AbstractMatrix, :Array, :Vector, :Matrix) +const ARRAY_TYPES = (:AbstractArray, :AbstractVector, :AbstractMatrix, :Array, :Vector, :Matrix, :Diagonal) const REAL_TYPES = (:Bool, :Integer, :(Irrational{:ℯ}), :(Irrational{:π}), :Rational, :BigFloat, :BigInt, :AbstractFloat, :Real, :Dual) const SKIPPED_UNARY_SCALAR_FUNCS = Symbol[:isinf, :isnan, :isfinite, :iseven, :isodd, :isreal, :isinteger]