From 9396fd4e15755ba8bd77352b10bcf4d390b4475f Mon Sep 17 00:00:00 2001 From: Songchen Tan Date: Fri, 31 May 2024 15:27:18 -0400 Subject: [PATCH] Remove type piracy with Zygote --- Project.toml | 2 +- src/chainrules.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 976eca1..467605f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TaylorDiff" uuid = "b36ab563-344f-407b-a36a-4f200bebf99c" authors = ["Songchen Tan "] -version = "0.2.2" +version = "0.2.3" [deps] ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" diff --git a/src/chainrules.jl b/src/chainrules.jl index 92b4ed3..fd327c6 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -78,7 +78,7 @@ accum_sum(xs::AbstractArray{T}; dims = :) where {T <: TaylorScalar} = sum(xs, di TaylorNumeric{T <: TaylorScalar} = Union{T, AbstractArray{<:T}} -@adjoint function broadcasted(::typeof(+), xs::Union{Numeric, TaylorNumeric}...) +@adjoint function broadcasted(::typeof(+), xs::TaylorNumeric...) broadcast(+, xs...), ȳ -> (nothing, map(x -> unbroadcast(x, ȳ), xs)...) end