Skip to content

Commit

Permalink
feat: support logsoftmax
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 5, 2024
1 parent 6aab7f7 commit abc6a9e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
14 changes: 14 additions & 0 deletions ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,20 @@ function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where
return out ./= tmp
end

function NNlib.logsoftmax!(out::TracedRArray{T}, x::AbstractArray; dims=1) where {T}
max_ = NNlib.fast_maximum(x; dims)
# if all(isfinite, max_)
@fastmath out .= x .- max_
# else
# _zero, _minf, _inf = T(0), T(-Inf), T(Inf)
# @. out = ifelse(
# isequal(max_, _inf), ifelse(isequal(x, _inf), _zero, _minf), x - max_
# )
# end
@fastmath log_ = log.(sum(exp, out; dims))
return out .-= log_
end

function NNlib.conv(
x::AnyTracedRArray{T,N}, W::AnyTracedRArray{T}, cdims::DenseConvDims
) where {T,N}
Expand Down
2 changes: 2 additions & 0 deletions src/TracedRNumber.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,8 @@ end
# XXX: Enzyme-MLIR doesn't have `abs` adjoint defined
Base.abs2(x::TracedRNumber{<:Real}) = x^2

Base.log1p(x::TracedRNumber{T}) where {T} = log(x + one(T))

struct TypeCast{T<:ReactantPrimitives} <: Function end

(::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} = promote_to(TracedRNumber{T}, x)
Expand Down

0 comments on commit abc6a9e

Please sign in to comment.