update softmax gradients

mcabbott committed Feb 25, 2022
1 parent ced78ee · commit 67c24cc

Showing 4 changed files with 30 additions and 7 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/CI.yml
@@ -20,8 +20,9 @@ jobs:
       fail-fast: false
       matrix:
         version:
-          - '1'
           - '1.3'
+          - '1.6' # LTS
+          - '1'
           - 'nightly'
         os:
           - ubuntu-latest
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,6 +1,6 @@
 name = "Tracker"
 uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
-version = "0.2.19"
+version = "0.2.20"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -23,7 +23,7 @@ DiffRules = "1.4"
 ForwardDiff = "0.10"
 LogExpFunctions = "0.3"
 MacroTools = "0.5"
-NNlib = "0.6, 0.7, 0.8"
+NNlib = "0.7.18, 0.8" # 0.7.18 is the last version which supports Julia 1.3
 NaNMath = "0.3, 1"
 Requires = "0.5, 1.0"
 SpecialFunctions = "0.10, 1, 2"
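For reference (not part of the diff): in Pkg's compat syntax, each comma-separated entry is a caret specifier, so the new bound allows NNlib versions in [0.7.18, 0.8.0) and [0.8.0, 0.9.0), dropping 0.6 and anything older than 0.7.18. A quick way to check this at the REPL, assuming the semi-internal `Pkg.Types.semver_spec` helper is available:

```julia
using Pkg: Types  # semver_spec is not public API; this is a sketch

spec = Types.semver_spec("0.7.18, 0.8")
@show v"0.7.17" in spec  # false: excluded by the new lower bound
@show v"0.7.18" in spec  # true
@show v"0.8.2"  in spec  # true
@show v"0.9.0"  in spec  # false: would need a new compat entry
```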
24 changes: 22 additions & 2 deletions src/lib/array.jl
@@ -471,11 +471,31 @@ import NNlib: DenseConvDims, DepthwiseConvDims, PoolDims
 
 softmax(xs::TrackedArray; dims=1) = track(softmax, xs; dims=dims)
 
-@grad softmax(xs; dims=1) = softmax(data(xs); dims=dims), Δ -> (nobacksies(:softmax, ∇softmax(data(Δ), data(xs); dims=dims)),)
+if isdefined(NNlib, :∇softmax_data) # use new form to avoid a depwarn, but only possible Julia 1.6+
+  @eval @grad function softmax(xs; dims=1)
+    y = softmax(data(xs); dims=dims)
+    y, Δ -> (nobacksies(:softmax, NNlib.∇softmax_data(data(Δ), data(y); dims=dims)),)
+  end
+else
+  @eval @grad function softmax(xs; dims=1) # TODO delete this when dropping Julia 1.3 (and increase NNlib bound)
+    y = softmax(data(xs); dims=dims)
+    y, Δ -> (nobacksies(:softmax, ∇softmax(data(Δ), data(xs), data(y); dims=dims)),)
+  end
+end
 
 logsoftmax(xs::TrackedArray; dims=1) = track(logsoftmax, xs; dims=dims)
 
-@grad logsoftmax(xs; dims=1) = logsoftmax(data(xs); dims=dims), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs); dims=dims)),)
+if isdefined(NNlib, :∇logsoftmax_data) # use new form to avoid a depwarn, but only possible Julia 1.6+
+  @eval @grad function logsoftmax(xs; dims=1)
+    y = logsoftmax(data(xs); dims=dims)
+    y, Δ -> (nobacksies(:logsoftmax, NNlib.∇logsoftmax_data(data(Δ), data(y); dims=dims)),)
+  end
+else
+  @eval @grad function logsoftmax(xs; dims=1)
+    y = logsoftmax(data(xs); dims=dims)
+    y, Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs), data(y); dims=dims)),)
+  end
+end
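The substance of the commit: both pullbacks now close over the saved forward output `y` instead of the input `xs`, and on NNlib versions that define it they call `∇softmax_data(Δ, y; dims)`, avoiding the deprecation warning from the older `∇softmax` form. As a reminder of why the output alone suffices, here is a minimal sketch of the closed forms involved; the `naive_*` helpers are illustrative stand-ins, not NNlib API, though they should agree with it:

```julia
using NNlib: softmax, logsoftmax

# For y = softmax(x; dims), the vector-Jacobian product with upstream Δ is
#   Δx = y .* (Δ .- sum(y .* Δ; dims))
# which needs only the forward output y, not the input x.
naive_∇softmax(Δ, y; dims=1) = y .* (Δ .- sum(y .* Δ; dims=dims))

# For y = logsoftmax(x; dims), exp.(y) recovers softmax(x; dims), so
#   Δx = Δ .- exp.(y) .* sum(Δ; dims)
naive_∇logsoftmax(Δ, y; dims=1) = Δ .- exp.(y) .* sum(Δ; dims=dims)

x, Δ = randn(Float32, 3, 4), randn(Float32, 3, 4)
y = softmax(x; dims=1)
naive_∇softmax(Δ, y)  # should match NNlib.∇softmax_data(Δ, y) where defined
```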
6 changes: 4 additions & 2 deletions test/tracker.jl
@@ -10,7 +10,9 @@ using Random
 
 gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
 gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
-@testset "Tracker" begin
+
+@testset "Tracker" begin # overall testset, rest of the file
+
 @test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
 @test gradtest((x, W) -> σ.(W*x), 5, (2,5))
 @test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
@@ -478,4 +480,4 @@ end
 @test size(y) == (5, 3)
 end
 
-end #testset
+end # overall testset
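For context, `gradtest` above funnels everything through `gradcheck`, which is defined earlier in this file and not shown in the hunk. A sketch of what such a check typically does, with `ngradient` as a hypothetical central-difference helper rather than the file's exact definition:

```julia
using Tracker

# Hypothetical finite-difference gradient of a scalar function f at arrays xs.
function ngradient(f, xs::AbstractArray...)
  grads = zero.(xs)
  for (x, Δ) in zip(xs, grads), i in eachindex(x)
    δ = sqrt(eps())
    tmp = x[i]
    x[i] = tmp - δ/2
    y1 = f(xs...)
    x[i] = tmp + δ/2
    y2 = f(xs...)
    x[i] = tmp  # restore the perturbed entry
    Δ[i] = (y2 - y1) / δ
  end
  grads
end

# Compare Tracker's reverse-mode gradients against finite differences.
gradcheck(f, xs...) =
  all(isapprox.(ngradient(f, xs...),
                Tracker.data.(Tracker.gradient(f, xs...));
                rtol = 1e-5, atol = 1e-5))
```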
