Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
Rollback PackageExtensionCompat
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 4, 2023
1 parent e35a13a commit cf3d20c
Show file tree
Hide file tree
Showing 9 changed files with 87 additions and 19 deletions.
1 change: 1 addition & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ steps:
julia:
- "1.6"
- "1"
- "1.6"
- "nightly"
adjustments:
- with:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ jobs:
fail-fast: false
matrix:
version:
- "1.6"
- "1"
steps:
- uses: actions/checkout@v3
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
Expand All @@ -32,8 +32,8 @@ ForwardDiff = "0.10"
KernelAbstractions = "0.9"
LuxCUDA = "0.2, 0.3"
NNlib = "0.8, 0.9"
PackageExtensionCompat = "1"
Reexport = "1"
Requires = "1"
ReverseDiff = "1"
Tracker = "0.2"
julia = "1.6"
Expand Down
3 changes: 2 additions & 1 deletion ext/LuxLibForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module LuxLibForwardDiffExt

using ForwardDiff, LuxLib
isdefined(Base, :get_extension) ? (using ForwardDiff) : (using ..ForwardDiff)
using LuxLib

function LuxLib._dropout_fptype(x::AbstractArray{<:ForwardDiff.Dual})
return ForwardDiff.valtype(eltype(x))
Expand Down
3 changes: 2 additions & 1 deletion ext/LuxLibLuxCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module LuxLibLuxCUDAExt

using LuxCUDA, LuxLib
isdefined(Base, :get_extension) ? (using LuxCUDA) : (using ..LuxCUDA)
using LuxLib
import ChainRulesCore as CRC
import LuxLib: batchnorm, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅

Expand Down
13 changes: 11 additions & 2 deletions ext/LuxLibLuxCUDATrackerExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
module LuxLibLuxCUDATrackerExt

using NNlib, LuxCUDA, LuxLib, Tracker
import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal
if isdefined(Base, :get_extension)
using Tracker
import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal
using LuxCUDA
else
using ..Tracker
import ..Tracker: @grad,
data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal
using ..LuxCUDA
end
using LuxLib
import LuxLib: AA,
AV, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅, __is_tracked

Expand Down
37 changes: 26 additions & 11 deletions ext/LuxLibReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,33 @@
module LuxLibReverseDiffExt

using ChainRulesCore, LuxLib, NNlib, ReverseDiff
if isdefined(Base, :get_extension)
using ReverseDiff
import ReverseDiff: SpecialInstruction,
TrackedArray,
TrackedReal,
decrement_deriv!,
increment_deriv!,
track,
value,
special_reverse_exec!,
special_forward_exec!,
@grad_from_chainrules
else
using ..ReverseDiff
import ..ReverseDiff: SpecialInstruction,
TrackedArray,
TrackedReal,
decrement_deriv!,
increment_deriv!,
track,
value,
special_reverse_exec!,
special_forward_exec!,
@grad_from_chainrules
end
using ChainRulesCore, LuxLib
import ChainRulesCore as CRC
import LuxLib: AA, __is_tracked
import ReverseDiff: SpecialInstruction,
TrackedArray,
TrackedReal,
decrement_deriv!,
increment_deriv!,
track,
value,
special_reverse_exec!,
special_forward_exec!,
@grad_from_chainrules

# Patches: Needs upstreaming
@inline function increment_deriv!(t::Union{TrackedArray, TrackedReal}, ::NoTangent, i)
Expand Down
11 changes: 9 additions & 2 deletions ext/LuxLibTrackerExt.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
module LuxLibTrackerExt

using NNlib, LuxLib, Tracker
if isdefined(Base, :get_extension)
using Tracker
import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal
else
using ..Tracker
import ..Tracker: @grad,
data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal
end
using LuxLib
import LuxLib: AA,
AV, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅, __is_tracked
import ChainRulesCore as CRC
import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal

# NNlib: batched_mul
for T1 in (:AbstractArray, :TrackedArray), T2 in (:AbstractArray, :TrackedArray)
Expand Down
33 changes: 33 additions & 0 deletions src/LuxLib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,43 @@ using KernelAbstractions
import KernelAbstractions as KA

# Extensions
#=
using PackageExtensionCompat
function __init__()
@require_extensions
end
=#
if !isdefined(Base, :get_extension)
using Requires
end

function __init__()
@static if !isdefined(Base, :get_extension)
# Handling AD Packages
## Handling ForwardDiff
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin
include("../ext/LuxLibForwardDiffExt.jl")
end
## Handling Tracker
@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin
include("../ext/LuxLibTrackerExt.jl")
end
## Handling ReverseDiff
@require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
include("../ext/LuxLibReverseDiffExt.jl")
end

# Accelerator Support
## Handling CUDA
@require LuxCUDA="d0bbae9a-e099-4d5b-a835-1c6931763bda" begin
include("../ext/LuxLibLuxCUDAExt.jl")

@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin
include("../ext/LuxLibLuxCUDATrackerExt.jl")
end
end
end
end

include("utils.jl")

Expand Down

0 comments on commit cf3d20c

Please sign in to comment.