From cf3d20cdb1ac982e5772e4d1bc395b41de0657e6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 4 Jul 2023 14:41:12 -0400 Subject: [PATCH] Rollback PackageExtensionCompat --- .buildkite/pipeline.yml | 1 + .github/workflows/CI.yml | 1 + Project.toml | 4 ++-- ext/LuxLibForwardDiffExt.jl | 3 ++- ext/LuxLibLuxCUDAExt.jl | 3 ++- ext/LuxLibLuxCUDATrackerExt.jl | 13 ++++++++++-- ext/LuxLibReverseDiffExt.jl | 37 ++++++++++++++++++++++++---------- ext/LuxLibTrackerExt.jl | 11 ++++++++-- src/LuxLib.jl | 33 ++++++++++++++++++++++++++++++ 9 files changed, 87 insertions(+), 19 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 2f3f00f9..af5adefc 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -22,6 +22,7 @@ steps: julia: - "1.6" - "1" + - "1.6" - "nightly" adjustments: - with: diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index e91619f2..02ace9c5 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -18,6 +18,7 @@ jobs: fail-fast: false matrix: version: + - "1.6" - "1" steps: - uses: actions/checkout@v3 diff --git a/Project.toml b/Project.toml index b7dadd0b..d5fac92e 100644 --- a/Project.toml +++ b/Project.toml @@ -8,9 +8,9 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] @@ -32,8 +32,8 @@ ForwardDiff = "0.10" KernelAbstractions = "0.9" LuxCUDA = "0.2, 0.3" NNlib = "0.8, 0.9" -PackageExtensionCompat = "1" Reexport = "1" +Requires = "1" ReverseDiff = "1" Tracker = "0.2" julia = "1.6" diff --git a/ext/LuxLibForwardDiffExt.jl b/ext/LuxLibForwardDiffExt.jl index 03924f3d..3d25bf06 100644 --- a/ext/LuxLibForwardDiffExt.jl +++ b/ext/LuxLibForwardDiffExt.jl @@ -1,6 +1,7 @@ module LuxLibForwardDiffExt -using ForwardDiff, LuxLib +isdefined(Base, :get_extension) ? (using ForwardDiff) : (using ..ForwardDiff) +using LuxLib function LuxLib._dropout_fptype(x::AbstractArray{<:ForwardDiff.Dual}) return ForwardDiff.valtype(eltype(x)) diff --git a/ext/LuxLibLuxCUDAExt.jl b/ext/LuxLibLuxCUDAExt.jl index f6fff767..d5bae7c4 100644 --- a/ext/LuxLibLuxCUDAExt.jl +++ b/ext/LuxLibLuxCUDAExt.jl @@ -1,6 +1,7 @@ module LuxLibLuxCUDAExt -using LuxCUDA, LuxLib +isdefined(Base, :get_extension) ? (using LuxCUDA) : (using ..LuxCUDA) +using LuxLib import ChainRulesCore as CRC import LuxLib: batchnorm, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅ diff --git a/ext/LuxLibLuxCUDATrackerExt.jl b/ext/LuxLibLuxCUDATrackerExt.jl index 2ad881bb..34edf3de 100644 --- a/ext/LuxLibLuxCUDATrackerExt.jl +++ b/ext/LuxLibLuxCUDATrackerExt.jl @@ -1,7 +1,16 @@ module LuxLibLuxCUDATrackerExt -using NNlib, LuxCUDA, LuxLib, Tracker -import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal +if isdefined(Base, :get_extension) + using Tracker + import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal + using LuxCUDA +else + using ..Tracker + import ..Tracker: @grad, + data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal + using ..LuxCUDA +end +using LuxLib import LuxLib: AA, AV, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅, __is_tracked diff --git a/ext/LuxLibReverseDiffExt.jl b/ext/LuxLibReverseDiffExt.jl index 26491b6f..94620a2b 100644 --- a/ext/LuxLibReverseDiffExt.jl +++ b/ext/LuxLibReverseDiffExt.jl @@ -1,18 +1,33 @@ module LuxLibReverseDiffExt -using ChainRulesCore, LuxLib, NNlib, ReverseDiff +if isdefined(Base, :get_extension) + using ReverseDiff + import ReverseDiff: SpecialInstruction, + TrackedArray, + TrackedReal, + decrement_deriv!, + increment_deriv!, + track, + value, + special_reverse_exec!, + special_forward_exec!, + @grad_from_chainrules +else + using ..ReverseDiff + import ..ReverseDiff: SpecialInstruction, + TrackedArray, + TrackedReal, + decrement_deriv!, + increment_deriv!, + track, + value, + special_reverse_exec!, + special_forward_exec!, + @grad_from_chainrules +end +using ChainRulesCore, LuxLib import ChainRulesCore as CRC import LuxLib: AA, __is_tracked -import ReverseDiff: SpecialInstruction, - TrackedArray, - TrackedReal, - decrement_deriv!, - increment_deriv!, - track, - value, - special_reverse_exec!, - special_forward_exec!, - @grad_from_chainrules # Patches: Needs upstreaming @inline function increment_deriv!(t::Union{TrackedArray, TrackedReal}, ::NoTangent, i) diff --git a/ext/LuxLibTrackerExt.jl b/ext/LuxLibTrackerExt.jl index f4c28369..60cf6633 100644 --- a/ext/LuxLibTrackerExt.jl +++ b/ext/LuxLibTrackerExt.jl @@ -1,10 +1,17 @@ module LuxLibTrackerExt -using NNlib, LuxLib, Tracker +if isdefined(Base, :get_extension) + using Tracker + import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal +else + using ..Tracker + import ..Tracker: @grad, + data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal +end +using LuxLib import LuxLib: AA, AV, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅, __is_tracked import ChainRulesCore as CRC -import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal # NNlib: batched_mul for T1 in (:AbstractArray, :TrackedArray), T2 in (:AbstractArray, :TrackedArray) diff --git a/src/LuxLib.jl b/src/LuxLib.jl index 3ac9da33..99d38e55 100644 --- a/src/LuxLib.jl +++ b/src/LuxLib.jl @@ -11,10 +11,43 @@ using KernelAbstractions import KernelAbstractions as KA # Extensions +#= using PackageExtensionCompat function __init__() @require_extensions end +=# +if !isdefined(Base, :get_extension) + using Requires +end + +function __init__() + @static if !isdefined(Base, :get_extension) + # Handling AD Packages + ## Handling ForwardDiff + @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin + include("../ext/LuxLibForwardDiffExt.jl") + end + ## Handling Tracker + @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin + include("../ext/LuxLibTrackerExt.jl") + end + ## Handling ReverseDiff + @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin + include("../ext/LuxLibReverseDiffExt.jl") + end + + # Accelerator Support + ## Handling CUDA + @require LuxCUDA="d0bbae9a-e099-4d5b-a835-1c6931763bda" begin + include("../ext/LuxLibLuxCUDAExt.jl") + + @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin + include("../ext/LuxLibLuxCUDATrackerExt.jl") + end + end + end +end include("utils.jl")