From 39b14012966ecc8de57c15924f3454f82d3bb45f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 30 Jun 2024 18:35:02 -0700 Subject: [PATCH 1/9] feat(reactant): compile `single_train_step!` using `Reactant` [skip ci] --- Project.toml | 3 ++ ext/LuxReactantExt/LuxReactantExt.jl | 19 +++++++++++ ext/LuxReactantExt/optimizers.jl | 11 ++++++ ext/LuxReactantExt/train.jl | 50 ++++++++++++++++++++++++++++ ext/LuxReactantExt/utils.jl | 11 ++++++ src/Lux.jl | 12 ++++--- src/helpers/training.jl | 8 +++++ src/transform/reactant.jl | 23 +++++++++++++ test/qa_tests.jl | 3 +- 9 files changed, 134 insertions(+), 6 deletions(-) create mode 100644 ext/LuxReactantExt/LuxReactantExt.jl create mode 100644 ext/LuxReactantExt/optimizers.jl create mode 100644 ext/LuxReactantExt/train.jl create mode 100644 ext/LuxReactantExt/utils.jl create mode 100644 src/transform/reactant.jl diff --git a/Project.toml b/Project.toml index 52263c94f..a3ba12deb 100644 --- a/Project.toml +++ b/Project.toml @@ -46,6 +46,7 @@ FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" @@ -60,6 +61,7 @@ LuxMPIExt = "MPI" LuxMPINCCLExt = ["CUDA", "MPI", "NCCL"] LuxReverseDiffExt = ["FunctionWrappers", "ReverseDiff"] LuxSimpleChainsExt = "SimpleChains" +LuxReactantExt = ["Enzyme", "Reactant"] LuxTrackerExt = "Tracker" LuxZygoteExt = "Zygote" @@ -96,6 +98,7 @@ NNlib = "0.9.21" Optimisers = "0.3.3" Preferences = "1.4.3" Random = "1.10" +Reactant = "0.2" Reexport = "1.2.2" ReverseDiff = "1.15" SIMDTypes = "0.1" diff --git a/ext/LuxReactantExt/LuxReactantExt.jl b/ext/LuxReactantExt/LuxReactantExt.jl new file mode 100644 index 000000000..838310bd7 --- /dev/null +++ b/ext/LuxReactantExt/LuxReactantExt.jl @@ -0,0 +1,19 @@ +module LuxReactantExt + +using Adapt: adapt +using ArgCheck: @argcheck +using ConcreteStructs: @concrete +using Enzyme: Enzyme, Active, Const, Duplicated +using Functors: fmapstructure, fmap +using Optimisers: Optimisers, Descent, Leaf +using Random: AbstractRNG, Xoshiro +using Reactant: Reactant +using Lux: Lux, LuxEltypeAdaptor, AutoReactant +using Lux.Experimental: TrainingBackendCache, TrainState +using LuxCore: LuxCore, AbstractExplicitLayer + +include("utils.jl") +include("train.jl") +include("optimizers.jl") + +end diff --git a/ext/LuxReactantExt/optimizers.jl b/ext/LuxReactantExt/optimizers.jl new file mode 100644 index 000000000..c90692a58 --- /dev/null +++ b/ext/LuxReactantExt/optimizers.jl @@ -0,0 +1,11 @@ +function simple_optimizers_apply!(ps, gs, leaf::Leaf{<:Descent}) + @. 
ps -= leaf.rule.eta * gs +end + +function simple_optimizers_apply!(::Descent, st_opt, ps, gs) + Lux.recursive_map(simple_optimizers_apply!, ps, gs, st_opt) +end + +function simple_optimizers_apply!(opt, st_opt, ps, gs) + throw(ArgumentError("Optimizer $(typeof(opt)) not yet supported.")) +end diff --git a/ext/LuxReactantExt/train.jl b/ext/LuxReactantExt/train.jl new file mode 100644 index 000000000..e130e344f --- /dev/null +++ b/ext/LuxReactantExt/train.jl @@ -0,0 +1,50 @@ +function Lux.Experimental.single_train_step!( + ::AutoReactant, obj_fn::F, data, ts::TrainState) where {F} + data = Lux.__make_reactant_array(data) + ps = Lux.__make_reactant_array(ts.parameters) + st = Lux.__make_reactant_array(ts.states) + st_opt = Lux.__make_reactant_array(ts.optimizer_state) + + @info "First call to Reactant. Compiling the training function!" # TODO: Remove + + compiled_grad_and_step! = Reactant.compile( + internal_grad_and_step!, (obj_fn, ts.model, ps, st, st_opt, data, ts.optimizer)) + + loss, st_updated, stats = compiled_grad_and_step!( + obj_fn, ts.model, ps, st, st_opt, data, ts.optimizer) + + cache = TrainingBackendCache{:Reactant, false}(nothing, (; compiled_grad_and_step!)) + ts_new = Lux.Experimental.TrainState( + cache, obj_fn, ts.model, ps, st_updated, ts.optimizer, st_opt, ts.step + 1) + + return nothing, loss, stats, ts_new +end + +function internal_grad_and_step!( + obj_fn::F, model, ps, st, st_opt, data, optimizer) where {F} + dps = Lux.recursive_make_zero(ps) + + _, (loss, st_updated, stats) = Enzyme.autodiff( + Enzyme.ReverseWithPrimal, obj_fn, Active, Const(model), + Duplicated(ps, dps), Const(st), Const(data)) + + simple_optimizers_apply!(optimizer, st_opt, ps, dps) # ps & st_opt are updated in-place + + return loss, st_updated, stats +end + +function Lux.Experimental.single_train_step!(::AutoReactant, obj_fn::F, data, + ts::TrainState{<:TrainingBackendCache{:Reactant}, F}) where {F} + data = Lux.__make_reactant_array(data) + + @info "We have already compiled the function!" 
# TODO: Remove + + loss, st_updated, stats = ts.cache.extras.compiled_grad_and_step!( + obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer_state, data, ts.optimizer) + + ts_new = Lux.Experimental.TrainState( + ts.cache, obj_fn, ts.model, ts.parameters, st_updated, + ts.optimizer, ts.optimizer_state, ts.step + 1) + + return nothing, loss, stats, ts_new +end diff --git a/ext/LuxReactantExt/utils.jl b/ext/LuxReactantExt/utils.jl new file mode 100644 index 000000000..44e396af9 --- /dev/null +++ b/ext/LuxReactantExt/utils.jl @@ -0,0 +1,11 @@ +@inline Lux.__make_reactant_array(x::Reactant.RArray) = x +@inline function Lux.__make_reactant_array(x::AbstractArray) + hasmethod(Reactant.ArrayToConcrete, Tuple{typeof(x)}) && + return Reactant.ConcreteRArray(x) + return __make_tracer(x) +end +@inline Lux.__make_reactant_array(x) = __make_tracer(x) + +@inline function __make_tracer(x) + return Reactant.make_tracer(IdDict(), x, (), Reactant.ArrayToConcrete) +end diff --git a/src/Lux.jl b/src/Lux.jl index 972b8aa41..c38158fbb 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -52,6 +52,12 @@ include("contrib/contrib.jl") # Pretty Printing include("layers/display.jl") +# Transform to and from other frameworks +include("transform/types.jl") +include("transform/flux.jl") +include("transform/simplechains.jl") +include("transform/reactant.jl") + # Layer Implementations include("layers/basic.jl") include("layers/containers.jl") @@ -75,11 +81,6 @@ include("helpers/size_propagator.jl") include("autodiff/api.jl") include("autodiff/autodiff.jl") -# Transform to and from other frameworks -include("transform/types.jl") -include("transform/flux.jl") -include("transform/simplechains.jl") - # Distributed Training include("distributed/backend.jl") include("distributed/public_api.jl") @@ -105,6 +106,7 @@ export Training export jacobian_vector_product, vector_jacobian_product export batched_jacobian export AutoEnzyme, AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote +export AutoReactant export BinaryCrossEntropyLoss, BinaryFocalLoss, CrossEntropyLoss, DiceCoeffLoss, FocalLoss, HingeLoss, HuberLoss, KLDivergenceLoss, L1Loss, L2Loss, MAELoss, MSELoss, MSLELoss, diff --git a/src/helpers/training.jl b/src/helpers/training.jl index 51fdb1a48..b85d4d3ef 100644 --- a/src/helpers/training.jl +++ b/src/helpers/training.jl @@ -237,6 +237,10 @@ Perform a single training step. Computes the gradients using [`compute_gradients updates the parameters using [`apply_gradients!`](@ref). All backends supported via [`compute_gradients`](@ref) are supported here. +## Additional Backends + + - [`AutoReactant`](@ref): Compiles the training loop to MLIR/XLA via `Reactant.jl`. + ## Return Returned values are the same as [`compute_gradients`](@ref). Note that despite the `!`, @@ -253,6 +257,10 @@ Perform a single training step. Computes the gradients using [`compute_gradients updates the parameters using [`apply_gradients`](@ref). All backends supported via [`compute_gradients`](@ref) are supported here. +## Additional Backends + + - [`AutoReactant`](@ref): Compiles the training loop to MLIR/XLA via `Reactant.jl`. + In most cases you should use [`single_train_step!`](@ref) instead of this function. ## Return diff --git a/src/transform/reactant.jl b/src/transform/reactant.jl new file mode 100644 index 000000000..d3903fcc5 --- /dev/null +++ b/src/transform/reactant.jl @@ -0,0 +1,23 @@ +""" + AutoReactant() + +Compile the training loop to MLIR/XLA via `Reactant.jl`. 
+ +This has been added to Lux very recently and is under-going rapid development. Currently, +only a limited subset of Lux models can be compiled via `Reactant.jl`. If you encounter any +issues, please report them on the `Lux.jl` or `Reactant.jl` GitHub repository. +""" +struct AutoReactant end + +""" + __make_reactant_array(x) + +Converts `x` to a `Reactant.ConcreteRArray` if it is not already one. +""" +function __make_reactant_array end + +@inline function __make_reactant_array(nt::NamedTuple{names}) where {names} + return NamedTuple{names}(map(__make_reactant_array, values(nt))) +end +@inline __make_reactant_array(t::Tuple) = map(__make_reactant_array, t) +@inline __make_reactant_array(x::AbstractExplicitLayer) = x diff --git a/test/qa_tests.jl b/test/qa_tests.jl index 074f464b0..c1e64e45e 100644 --- a/test/qa_tests.jl +++ b/test/qa_tests.jl @@ -12,7 +12,8 @@ end @testitem "Explicit Imports: Quality Assurance" setup=[SharedTestSetup] tags=[:others] begin # Load all trigger packages - import Lux, ComponentArrays, ReverseDiff, SimpleChains, Tracker, Zygote, Enzyme + import Lux, ComponentArrays, ReverseDiff, Flux, SimpleChains, Tracker, Zygote, Enzyme, + Reactant using ExplicitImports # Skip our own packages From 196045ffbf2c1d847a17e1d5707a3c66996e42d0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 10 Jul 2024 20:47:14 -0700 Subject: [PATCH 2/9] refactor(reactant): avoid potential name conflict with ADTypes [skip ci] --- ext/LuxReactantExt/LuxReactantExt.jl | 2 +- ext/LuxReactantExt/train.jl | 14 +++++++------- ext/LuxReactantExt/utils.jl | 10 ++++------ src/Lux.jl | 7 +++++-- src/compilers.jl | 15 +++++++++++++++ src/helpers/training.jl | 6 +++--- src/transform/reactant.jl | 23 ----------------------- 7 files changed, 35 insertions(+), 42 deletions(-) create mode 100644 src/compilers.jl delete mode 100644 src/transform/reactant.jl diff --git a/ext/LuxReactantExt/LuxReactantExt.jl b/ext/LuxReactantExt/LuxReactantExt.jl index 838310bd7..ea5f975a2 100644 --- a/ext/LuxReactantExt/LuxReactantExt.jl +++ b/ext/LuxReactantExt/LuxReactantExt.jl @@ -8,7 +8,7 @@ using Functors: fmapstructure, fmap using Optimisers: Optimisers, Descent, Leaf using Random: AbstractRNG, Xoshiro using Reactant: Reactant -using Lux: Lux, LuxEltypeAdaptor, AutoReactant +using Lux: Lux, LuxEltypeAdaptor, ReactantBackend using Lux.Experimental: TrainingBackendCache, TrainState using LuxCore: LuxCore, AbstractExplicitLayer diff --git a/ext/LuxReactantExt/train.jl b/ext/LuxReactantExt/train.jl index e130e344f..b5be41510 100644 --- a/ext/LuxReactantExt/train.jl +++ b/ext/LuxReactantExt/train.jl @@ -1,9 +1,9 @@ function Lux.Experimental.single_train_step!( - ::AutoReactant, obj_fn::F, data, ts::TrainState) where {F} - data = Lux.__make_reactant_array(data) - ps = Lux.__make_reactant_array(ts.parameters) - st = Lux.__make_reactant_array(ts.states) - st_opt = Lux.__make_reactant_array(ts.optimizer_state) + ::ReactantBackend, obj_fn::F, data, ts::TrainState) where {F} + data = __make_reactant_array(data) + ps = __make_reactant_array(ts.parameters) + st = __make_reactant_array(ts.states) + st_opt = __make_reactant_array(ts.optimizer_state) @info "First call to Reactant. Compiling the training function!" 
# TODO: Remove @@ -33,9 +33,9 @@ function internal_grad_and_step!( return loss, st_updated, stats end -function Lux.Experimental.single_train_step!(::AutoReactant, obj_fn::F, data, +function Lux.Experimental.single_train_step!(::ReactantBackend, obj_fn::F, data, ts::TrainState{<:TrainingBackendCache{:Reactant}, F}) where {F} - data = Lux.__make_reactant_array(data) + data = __make_reactant_array(data) @info "We have already compiled the function!" # TODO: Remove diff --git a/ext/LuxReactantExt/utils.jl b/ext/LuxReactantExt/utils.jl index 44e396af9..65fca703a 100644 --- a/ext/LuxReactantExt/utils.jl +++ b/ext/LuxReactantExt/utils.jl @@ -1,11 +1,9 @@ -@inline Lux.__make_reactant_array(x::Reactant.RArray) = x -@inline function Lux.__make_reactant_array(x::AbstractArray) +__make_reactant_array(x::Reactant.RArray) = x +function __make_reactant_array(x::AbstractArray) hasmethod(Reactant.ArrayToConcrete, Tuple{typeof(x)}) && return Reactant.ConcreteRArray(x) return __make_tracer(x) end -@inline Lux.__make_reactant_array(x) = __make_tracer(x) +__make_reactant_array(x) = __make_tracer(x) -@inline function __make_tracer(x) - return Reactant.make_tracer(IdDict(), x, (), Reactant.ArrayToConcrete) -end +__make_tracer(x) = Reactant.make_tracer(IdDict(), x, (), Reactant.ArrayToConcrete) diff --git a/src/Lux.jl b/src/Lux.jl index c38158fbb..d7c00dd7e 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -46,6 +46,9 @@ include("extended_ops.jl") # Training Helpers include("helpers/training.jl") +# Compilers +include("compilers.jl") + # Experimental include("contrib/contrib.jl") @@ -56,7 +59,6 @@ include("layers/display.jl") include("transform/types.jl") include("transform/flux.jl") include("transform/simplechains.jl") -include("transform/reactant.jl") # Layer Implementations include("layers/basic.jl") @@ -106,7 +108,8 @@ export Training export jacobian_vector_product, vector_jacobian_product export batched_jacobian export AutoEnzyme, AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote -export AutoReactant + +export ReactantBackend export BinaryCrossEntropyLoss, BinaryFocalLoss, CrossEntropyLoss, DiceCoeffLoss, FocalLoss, HingeLoss, HuberLoss, KLDivergenceLoss, L1Loss, L2Loss, MAELoss, MSELoss, MSLELoss, diff --git a/src/compilers.jl b/src/compilers.jl new file mode 100644 index 000000000..3901a75ec --- /dev/null +++ b/src/compilers.jl @@ -0,0 +1,15 @@ +abstract type AbstractCompilerBackend end + +""" + ReactantBackend() + +Compile Lux model and gradient computation to MLIR/XLA via `Reactant.jl`. + +This has been added to Lux very recently and is under-going rapid development. Currently, +only a limited subset of Lux models can be compiled via `Reactant.jl`. If you encounter any +issues, please report them on the `Lux.jl` or `Reactant.jl` GitHub repository. + +See [`Lux.Experimental.single_train_step!`](@ref) or +[`Lux.Experimental.single_train_step`](@ref) for information on how to use this backend. +""" +struct ReactantBackend <: AbstractCompilerBackend end diff --git a/src/helpers/training.jl b/src/helpers/training.jl index b85d4d3ef..251cb2894 100644 --- a/src/helpers/training.jl +++ b/src/helpers/training.jl @@ -239,14 +239,14 @@ updates the parameters using [`apply_gradients!`](@ref). All backends supported ## Additional Backends - - [`AutoReactant`](@ref): Compiles the training loop to MLIR/XLA via `Reactant.jl`. + - [`ReactantBackend`](@ref): Compiles the training loop to MLIR/XLA via `Reactant.jl`. ## Return Returned values are the same as [`compute_gradients`](@ref). 
Note that despite the `!`, only the parameters in `ts` are updated inplace. Users should be using the returned `ts` object for further training steps, else there is no caching and performance will be -suboptimal (and absolutely terrible for backends like `AutoReactant`). +suboptimal (and absolutely terrible for backends like `ReactantBackend`). """ function single_train_step! end @@ -259,7 +259,7 @@ updates the parameters using [`apply_gradients`](@ref). All backends supported v ## Additional Backends - - [`AutoReactant`](@ref): Compiles the training loop to MLIR/XLA via `Reactant.jl`. + - [`ReactantBackend`](@ref): Compiles the training loop to MLIR/XLA via `Reactant.jl`. In most cases you should use [`single_train_step!`](@ref) instead of this function. diff --git a/src/transform/reactant.jl b/src/transform/reactant.jl deleted file mode 100644 index d3903fcc5..000000000 --- a/src/transform/reactant.jl +++ /dev/null @@ -1,23 +0,0 @@ -""" - AutoReactant() - -Compile the training loop to MLIR/XLA via `Reactant.jl`. - -This has been added to Lux very recently and is under-going rapid development. Currently, -only a limited subset of Lux models can be compiled via `Reactant.jl`. If you encounter any -issues, please report them on the `Lux.jl` or `Reactant.jl` GitHub repository. -""" -struct AutoReactant end - -""" - __make_reactant_array(x) - -Converts `x` to a `Reactant.ConcreteRArray` if it is not already one. -""" -function __make_reactant_array end - -@inline function __make_reactant_array(nt::NamedTuple{names}) where {names} - return NamedTuple{names}(map(__make_reactant_array, values(nt))) -end -@inline __make_reactant_array(t::Tuple) = map(__make_reactant_array, t) -@inline __make_reactant_array(x::AbstractExplicitLayer) = x From b7740f359439150448af0cf25217dbf638bf41bd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 10 Jul 2024 20:57:43 -0700 Subject: [PATCH 3/9] refactor(reactant): uniform naming across extensions [skip ci] --- ext/LuxReactantExt/LuxReactantExt.jl | 2 +- ext/LuxReactantExt/{train.jl => training.jl} | 0 ext/LuxReactantExt/utils.jl | 8 +++++++- 3 files changed, 8 insertions(+), 2 deletions(-) rename ext/LuxReactantExt/{train.jl => training.jl} (100%) diff --git a/ext/LuxReactantExt/LuxReactantExt.jl b/ext/LuxReactantExt/LuxReactantExt.jl index ea5f975a2..c6a16e2b9 100644 --- a/ext/LuxReactantExt/LuxReactantExt.jl +++ b/ext/LuxReactantExt/LuxReactantExt.jl @@ -13,7 +13,7 @@ using Lux.Experimental: TrainingBackendCache, TrainState using LuxCore: LuxCore, AbstractExplicitLayer include("utils.jl") -include("train.jl") +include("training.jl") include("optimizers.jl") end diff --git a/ext/LuxReactantExt/train.jl b/ext/LuxReactantExt/training.jl similarity index 100% rename from ext/LuxReactantExt/train.jl rename to ext/LuxReactantExt/training.jl diff --git a/ext/LuxReactantExt/utils.jl b/ext/LuxReactantExt/utils.jl index 65fca703a..55156c652 100644 --- a/ext/LuxReactantExt/utils.jl +++ b/ext/LuxReactantExt/utils.jl @@ -4,6 +4,12 @@ function __make_reactant_array(x::AbstractArray) return Reactant.ConcreteRArray(x) return __make_tracer(x) end -__make_reactant_array(x) = __make_tracer(x) +function __make_reactant_array(x) + return Lux.recursive_map(x) do xₗ + hasmethod(Reactant.ArrayToConcrete, Tuple{typeof(xₗ)}) && + return Reactant.ConcreteRArray(xₗ) + return __make_tracer(xₗ) + end +end __make_tracer(x) = Reactant.make_tracer(IdDict(), x, (), Reactant.ArrayToConcrete) From d151da275e01e536b34a2ae6facb8cbac425f9ef Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 10 Jul 
2024 21:08:36 -0700 Subject: [PATCH 4/9] feat(training): add inference mode for `TrainState` [skip ci] --- src/helpers/training.jl | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/helpers/training.jl b/src/helpers/training.jl index 251cb2894..f77ce3981 100644 --- a/src/helpers/training.jl +++ b/src/helpers/training.jl @@ -14,7 +14,7 @@ using LuxCore: LuxCore, AbstractLuxLayer """ TrainState -Training State containing: +## Training State containing: - `model`: `Lux` model. - `parameters`: Trainable Variables of the `model`. @@ -23,7 +23,7 @@ Training State containing: - `optimizer_state`: Optimizer State. - `step`: Number of updates of the parameters made. -Internal fields: +## Internal fields: - `cache`: Cached values. Implementations are free to use this for whatever they want. - `objective_function`: Objective function might be cached. @@ -32,6 +32,12 @@ Internal fields: Constructing this object directly shouldn't be considered a stable API. Use the version with the Optimisers API. + +## Special Features + +To run inference using the current parameters and states simply call the TrainState with +the input data as `tstate(data)`. This will automatically set `Lux.testmode`. However, note +that `tstate.states` will not be updated with the new state. """ @concrete struct TrainState cache @@ -65,6 +71,10 @@ function TrainState(model::AbstractLuxLayer, ps, st, optimizer::Optimisers.Abstr return TrainState(nothing, nothing, model, ps, st, optimizer, st_opt, 0) end +function (tstate::TrainState)(data) + return first(tstate.model(data, tstate.parameters, Lux.testmode(tstate.states))) +end + @concrete struct TrainingBackendCache backend first_try <: StaticBool From 9741de1e6f06571c8c9af5c72757bd6e30ca26c0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 10 Jul 2024 21:39:30 -0700 Subject: [PATCH 5/9] feat(reactant): auto compile inference mode if possible [skip ci] --- ext/LuxReactantExt/training.jl | 56 +++++++++++++++++++++++----------- src/compilers.jl | 24 ++++++++++++--- 2 files changed, 58 insertions(+), 22 deletions(-) diff --git a/ext/LuxReactantExt/training.jl b/ext/LuxReactantExt/training.jl index b5be41510..3b49c34c5 100644 --- a/ext/LuxReactantExt/training.jl +++ b/ext/LuxReactantExt/training.jl @@ -1,11 +1,17 @@ function Lux.Experimental.single_train_step!( - ::ReactantBackend, obj_fn::F, data, ts::TrainState) where {F} + backend::ReactantBackend, obj_fn::F, data, ts::TrainState) where {F} data = __make_reactant_array(data) ps = __make_reactant_array(ts.parameters) st = __make_reactant_array(ts.states) st_opt = __make_reactant_array(ts.optimizer_state) - @info "First call to Reactant. Compiling the training function!" # TODO: Remove + compiled_inference = if backend.input_prototype !== nothing + Reactant.compile(first ∘ LuxCore.apply, + (ts.model, __make_reactant_array(backend.input_prototype), + ps, Lux.testmode(st))) + else + nothing + end compiled_grad_and_step! 
= Reactant.compile( internal_grad_and_step!, (obj_fn, ts.model, ps, st, st_opt, data, ts.optimizer)) @@ -13,13 +19,28 @@ function Lux.Experimental.single_train_step!( loss, st_updated, stats = compiled_grad_and_step!( obj_fn, ts.model, ps, st, st_opt, data, ts.optimizer) - cache = TrainingBackendCache{:Reactant, false}(nothing, (; compiled_grad_and_step!)) + cache = TrainingBackendCache{:Reactant, false}( + nothing, (; compiled_grad_and_step!, compiled_inference)) ts_new = Lux.Experimental.TrainState( cache, obj_fn, ts.model, ps, st_updated, ts.optimizer, st_opt, ts.step + 1) return nothing, loss, stats, ts_new end +function Lux.Experimental.single_train_step!(::ReactantBackend, obj_fn::F, data, + ts::TrainState{<:TrainingBackendCache{:Reactant}, F}) where {F} + data = __make_reactant_array(data) + + loss, st_updated, stats = ts.cache.extras.compiled_grad_and_step!( + obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer_state, data, ts.optimizer) + + ts_new = Lux.Experimental.TrainState( + ts.cache, obj_fn, ts.model, ts.parameters, st_updated, + ts.optimizer, ts.optimizer_state, ts.step + 1) + + return nothing, loss, stats, ts_new +end + function internal_grad_and_step!( obj_fn::F, model, ps, st, st_opt, data, optimizer) where {F} dps = Lux.recursive_make_zero(ps) @@ -33,18 +54,19 @@ function internal_grad_and_step!( return loss, st_updated, stats end -function Lux.Experimental.single_train_step!(::ReactantBackend, obj_fn::F, data, - ts::TrainState{<:TrainingBackendCache{:Reactant}, F}) where {F} - data = __make_reactant_array(data) - - @info "We have already compiled the function!" # TODO: Remove - - loss, st_updated, stats = ts.cache.extras.compiled_grad_and_step!( - obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer_state, data, ts.optimizer) - - ts_new = Lux.Experimental.TrainState( - ts.cache, obj_fn, ts.model, ts.parameters, st_updated, - ts.optimizer, ts.optimizer_state, ts.step + 1) - - return nothing, loss, stats, ts_new +function (tstate::TrainState{<:TrainingBackendCache{:Reactant}})(data) + data_reactant = __make_reactant_array(data) + compiled_inference = if tstate.cache.extras.compiled_inference !== nothing + tstate.cache.extras.compiled_inference + else + @warn "Inference function not compiled before. This will trigger compilation on \ + every inference call to `(::TrainState)(data)`. Please use \ + `ReactantBackend(; input_prototype = data)` to compile the inference \ + function on the first call to `single_train_step!` or \ + `single_train_step`." maxlog=1 + Reactant.compile(first ∘ LuxCore.apply, + (tstate.model, data_reactant, tstate.parameters, Lux.testmode(tstate.states))) + end + return compiled_inference( + tstate.model, data_reactant, tstate.parameters, Lux.testmode(tstate.states)) end diff --git a/src/compilers.jl b/src/compilers.jl index 3901a75ec..57725f03c 100644 --- a/src/compilers.jl +++ b/src/compilers.jl @@ -1,15 +1,29 @@ abstract type AbstractCompilerBackend end """ - ReactantBackend() + ReactantBackend(; input_prototype = nothing) Compile Lux model and gradient computation to MLIR/XLA via `Reactant.jl`. -This has been added to Lux very recently and is under-going rapid development. Currently, -only a limited subset of Lux models can be compiled via `Reactant.jl`. If you encounter any -issues, please report them on the `Lux.jl` or `Reactant.jl` GitHub repository. +!!! tip "Newly Added Feature!" + + This has been added to Lux very recently and is under-going rapid development. 
+ Currently, only a limited subset of Lux models can be compiled via `Reactant.jl`. If you + encounter any issues, please report them on the `Lux.jl` or `Reactant.jl` GitHub + repository. + +## Keyword Arguments + + - `input_prototype`: Input data representative of the data that will be used for + inference. If this is provided, we will compile the inference function with + `Reactant.jl` on the first call to [`Lux.Experimental.single_train_step!`](@ref) or + [`Lux.Experimental.single_train_step`](@ref). If this is not provided, we will have to + recompile the inference function on every call to `(::TrainState)(data)` and this will + be prohibitively expensive. See [`Lux.Experimental.single_train_step!`](@ref) or [`Lux.Experimental.single_train_step`](@ref) for information on how to use this backend. """ -struct ReactantBackend <: AbstractCompilerBackend end +@kwdef @concrete struct ReactantBackend <: AbstractCompilerBackend + input_prototype = nothing +end From 0ef2b6ac7f432021702331adb3630aaead80e8ae Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 14 Jul 2024 13:13:12 -0700 Subject: [PATCH 6/9] refactor(reactant): move optimisers into main pkg --- Project.toml | 2 +- ext/LuxReactantExt/LuxReactantExt.jl | 11 ++--------- ext/LuxReactantExt/optimizers.jl | 11 ----------- ext/LuxReactantExt/training.jl | 2 +- src/Lux.jl | 6 ++++++ src/helpers/simple_optimizers.jl | 14 ++++++++++++++ 6 files changed, 24 insertions(+), 22 deletions(-) delete mode 100644 ext/LuxReactantExt/optimizers.jl create mode 100644 src/helpers/simple_optimizers.jl diff --git a/Project.toml b/Project.toml index a3ba12deb..e84802277 100644 --- a/Project.toml +++ b/Project.toml @@ -69,7 +69,7 @@ LuxZygoteExt = "Zygote" ADTypes = "1.5" Adapt = "4" ArgCheck = "2.3" -ArrayInterface = "7.9" +ArrayInterface = "7.10" CUDA = "5.3.2" ChainRulesCore = "1.24" Compat = "4.15" diff --git a/ext/LuxReactantExt/LuxReactantExt.jl b/ext/LuxReactantExt/LuxReactantExt.jl index c6a16e2b9..64b21b38b 100644 --- a/ext/LuxReactantExt/LuxReactantExt.jl +++ b/ext/LuxReactantExt/LuxReactantExt.jl @@ -1,19 +1,12 @@ module LuxReactantExt -using Adapt: adapt -using ArgCheck: @argcheck -using ConcreteStructs: @concrete using Enzyme: Enzyme, Active, Const, Duplicated -using Functors: fmapstructure, fmap -using Optimisers: Optimisers, Descent, Leaf -using Random: AbstractRNG, Xoshiro using Reactant: Reactant -using Lux: Lux, LuxEltypeAdaptor, ReactantBackend +using Lux: Lux, ReactantBackend using Lux.Experimental: TrainingBackendCache, TrainState -using LuxCore: LuxCore, AbstractExplicitLayer +using LuxCore: LuxCore include("utils.jl") include("training.jl") -include("optimizers.jl") end diff --git a/ext/LuxReactantExt/optimizers.jl b/ext/LuxReactantExt/optimizers.jl deleted file mode 100644 index c90692a58..000000000 --- a/ext/LuxReactantExt/optimizers.jl +++ /dev/null @@ -1,11 +0,0 @@ -function simple_optimizers_apply!(ps, gs, leaf::Leaf{<:Descent}) - @. 
ps -= leaf.rule.eta * gs -end - -function simple_optimizers_apply!(::Descent, st_opt, ps, gs) - Lux.recursive_map(simple_optimizers_apply!, ps, gs, st_opt) -end - -function simple_optimizers_apply!(opt, st_opt, ps, gs) - throw(ArgumentError("Optimizer $(typeof(opt)) not yet supported.")) -end diff --git a/ext/LuxReactantExt/training.jl b/ext/LuxReactantExt/training.jl index 3b49c34c5..aa494359a 100644 --- a/ext/LuxReactantExt/training.jl +++ b/ext/LuxReactantExt/training.jl @@ -49,7 +49,7 @@ function internal_grad_and_step!( Enzyme.ReverseWithPrimal, obj_fn, Active, Const(model), Duplicated(ps, dps), Const(st), Const(data)) - simple_optimizers_apply!(optimizer, st_opt, ps, dps) # ps & st_opt are updated in-place + Lux.simple_optimizers_apply!(optimizer, st_opt, ps, dps) # ps & st_opt are updated in-place return loss, st_updated, stats end diff --git a/src/Lux.jl b/src/Lux.jl index d7c00dd7e..2a845cca4 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -14,7 +14,12 @@ using GPUArraysCore: @allowscalar using LossFunctions: LossFunctions using Markdown: @doc_str using NNlib: NNlib +<<<<<<< HEAD using Optimisers: Optimisers +======= +using Optimisers: Optimisers, Leaf, Descent +using Preferences: load_preference, has_preference +>>>>>>> f68ad624 (refactor(reactant): move optimisers into main pkg) using Random: Random, AbstractRNG using Static: StaticBool, StaticInt, StaticSymbol, True, False, static, known, dynamic using Reexport: Reexport, @reexport @@ -78,6 +83,7 @@ include("helpers/losses.jl") include("helpers/recursive_ops.jl") include("helpers/match_eltype.jl") include("helpers/size_propagator.jl") +include("helpers/simple_optimizers.jl") # AutoDiff include("autodiff/api.jl") diff --git a/src/helpers/simple_optimizers.jl b/src/helpers/simple_optimizers.jl new file mode 100644 index 000000000..cd9a32b08 --- /dev/null +++ b/src/helpers/simple_optimizers.jl @@ -0,0 +1,14 @@ +# These are meant to be used internally for compiling certain lux optiomization +function simple_optimizers_apply!(ps, gs, leaf::Leaf{<:Descent}) + @. ps -= leaf.rule.eta * gs +end + +for opt in (Descent,) + @eval function simple_optimizers_apply!(::$(opt), st_opt, ps, gs) + recursive_map(simple_optimizers_apply!, ps, gs, st_opt) + end +end + +function simple_optimizers_apply!(opt, st_opt, ps, gs) + throw(ArgumentError("Optimizer $(typeof(opt)) not yet supported.")) +end From 2f7aee5b096358313a3270ccb9fbc31fe2b97523 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 14 Jul 2024 13:34:09 -0700 Subject: [PATCH 7/9] fix: tstate inference should return the state --- ext/LuxReactantExt/training.jl | 10 +++++----- src/helpers/training.jl | 4 +--- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/ext/LuxReactantExt/training.jl b/ext/LuxReactantExt/training.jl index aa494359a..6cca06726 100644 --- a/ext/LuxReactantExt/training.jl +++ b/ext/LuxReactantExt/training.jl @@ -6,9 +6,9 @@ function Lux.Experimental.single_train_step!( st_opt = __make_reactant_array(ts.optimizer_state) compiled_inference = if backend.input_prototype !== nothing - Reactant.compile(first ∘ LuxCore.apply, + Reactant.compile(LuxCore.apply, (ts.model, __make_reactant_array(backend.input_prototype), - ps, Lux.testmode(st))) + ps, LuxCore.testmode(st))) else nothing end @@ -64,9 +64,9 @@ function (tstate::TrainState{<:TrainingBackendCache{:Reactant}})(data) `ReactantBackend(; input_prototype = data)` to compile the inference \ function on the first call to `single_train_step!` or \ `single_train_step`." 
maxlog=1 - Reactant.compile(first ∘ LuxCore.apply, - (tstate.model, data_reactant, tstate.parameters, Lux.testmode(tstate.states))) + Reactant.compile(LuxCore.apply, + (tstate.model, data_reactant, tstate.parameters, LuxCore.testmode(tstate.states))) end return compiled_inference( - tstate.model, data_reactant, tstate.parameters, Lux.testmode(tstate.states)) + tstate.model, data_reactant, tstate.parameters, LuxCore.testmode(tstate.states)) end diff --git a/src/helpers/training.jl b/src/helpers/training.jl index f77ce3981..6066ad258 100644 --- a/src/helpers/training.jl +++ b/src/helpers/training.jl @@ -71,9 +71,7 @@ function TrainState(model::AbstractLuxLayer, ps, st, optimizer::Optimisers.Abstr return TrainState(nothing, nothing, model, ps, st, optimizer, st_opt, 0) end -function (tstate::TrainState)(data) - return first(tstate.model(data, tstate.parameters, Lux.testmode(tstate.states))) -end +(ts::TrainState)(data) = ts.model(data, ts.parameters, Lux.testmode(ts.states)) @concrete struct TrainingBackendCache backend From 70b7a86b83d367c6fc775ea79351d8b57205a4ed Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 14 Jul 2024 13:36:38 -0700 Subject: [PATCH 8/9] chore: format suggestion Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- ext/LuxReactantExt/LuxReactantExt.jl | 2 +- ext/LuxReactantExt/training.jl | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/ext/LuxReactantExt/LuxReactantExt.jl b/ext/LuxReactantExt/LuxReactantExt.jl index 64b21b38b..36f0b5644 100644 --- a/ext/LuxReactantExt/LuxReactantExt.jl +++ b/ext/LuxReactantExt/LuxReactantExt.jl @@ -3,7 +3,7 @@ module LuxReactantExt using Enzyme: Enzyme, Active, Const, Duplicated using Reactant: Reactant using Lux: Lux, ReactantBackend -using Lux.Experimental: TrainingBackendCache, TrainState +using Lux.Training: TrainingBackendCache, TrainState using LuxCore: LuxCore include("utils.jl") diff --git a/ext/LuxReactantExt/training.jl b/ext/LuxReactantExt/training.jl index 6cca06726..1a237cb1e 100644 --- a/ext/LuxReactantExt/training.jl +++ b/ext/LuxReactantExt/training.jl @@ -1,4 +1,4 @@ -function Lux.Experimental.single_train_step!( +function Lux.Training.single_train_step!( backend::ReactantBackend, obj_fn::F, data, ts::TrainState) where {F} data = __make_reactant_array(data) ps = __make_reactant_array(ts.parameters) @@ -21,21 +21,20 @@ function Lux.Experimental.single_train_step!( cache = TrainingBackendCache{:Reactant, false}( nothing, (; compiled_grad_and_step!, compiled_inference)) - ts_new = Lux.Experimental.TrainState( + ts_new = Lux.Training.TrainState( cache, obj_fn, ts.model, ps, st_updated, ts.optimizer, st_opt, ts.step + 1) return nothing, loss, stats, ts_new end -function Lux.Experimental.single_train_step!(::ReactantBackend, obj_fn::F, data, +function Lux.Training.single_train_step!(::ReactantBackend, obj_fn::F, data, ts::TrainState{<:TrainingBackendCache{:Reactant}, F}) where {F} data = __make_reactant_array(data) loss, st_updated, stats = ts.cache.extras.compiled_grad_and_step!( obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer_state, data, ts.optimizer) - ts_new = Lux.Experimental.TrainState( - ts.cache, obj_fn, ts.model, ts.parameters, st_updated, + ts_new = Lux.Training.TrainState(ts.cache, obj_fn, ts.model, ts.parameters, st_updated, ts.optimizer, ts.optimizer_state, ts.step + 1) return nothing, loss, stats, ts_new @@ -65,7 +64,8 @@ function (tstate::TrainState{<:TrainingBackendCache{:Reactant}})(data) function on the first call to 
`single_train_step!` or \ `single_train_step`." maxlog=1 Reactant.compile(LuxCore.apply, - (tstate.model, data_reactant, tstate.parameters, LuxCore.testmode(tstate.states))) + (tstate.model, data_reactant, tstate.parameters, + LuxCore.testmode(tstate.states))) end return compiled_inference( tstate.model, data_reactant, tstate.parameters, LuxCore.testmode(tstate.states)) From c47a9d0bae785a815a9ff3a67e8481e2f436667d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 14 Sep 2024 22:44:45 -0400 Subject: [PATCH 9/9] fix: update to new API --- ext/LuxReactantExt/LuxReactantExt.jl | 4 ++- ext/LuxReactantExt/training.jl | 39 ++++++++++++++++------------ ext/LuxReactantExt/utils.jl | 15 ----------- 3 files changed, 25 insertions(+), 33 deletions(-) delete mode 100644 ext/LuxReactantExt/utils.jl diff --git a/ext/LuxReactantExt/LuxReactantExt.jl b/ext/LuxReactantExt/LuxReactantExt.jl index 36f0b5644..0a72884ef 100644 --- a/ext/LuxReactantExt/LuxReactantExt.jl +++ b/ext/LuxReactantExt/LuxReactantExt.jl @@ -2,11 +2,13 @@ module LuxReactantExt using Enzyme: Enzyme, Active, Const, Duplicated using Reactant: Reactant +using Static: Static, False +using Setfield: @set! + using Lux: Lux, ReactantBackend using Lux.Training: TrainingBackendCache, TrainState using LuxCore: LuxCore -include("utils.jl") include("training.jl") end diff --git a/ext/LuxReactantExt/training.jl b/ext/LuxReactantExt/training.jl index 1a237cb1e..1f4553fe5 100644 --- a/ext/LuxReactantExt/training.jl +++ b/ext/LuxReactantExt/training.jl @@ -1,13 +1,13 @@ function Lux.Training.single_train_step!( backend::ReactantBackend, obj_fn::F, data, ts::TrainState) where {F} - data = __make_reactant_array(data) - ps = __make_reactant_array(ts.parameters) - st = __make_reactant_array(ts.states) - st_opt = __make_reactant_array(ts.optimizer_state) + data = Reactant.to_rarray(data) + ps = Reactant.to_rarray(ts.parameters) + st = Reactant.to_rarray(ts.states) + st_opt = Reactant.to_rarray(ts.optimizer_state) compiled_inference = if backend.input_prototype !== nothing Reactant.compile(LuxCore.apply, - (ts.model, __make_reactant_array(backend.input_prototype), + (ts.model, Reactant.to_rarray(backend.input_prototype), ps, LuxCore.testmode(st))) else nothing @@ -19,25 +19,30 @@ function Lux.Training.single_train_step!( loss, st_updated, stats = compiled_grad_and_step!( obj_fn, ts.model, ps, st, st_opt, data, ts.optimizer) - cache = TrainingBackendCache{:Reactant, false}( - nothing, (; compiled_grad_and_step!, compiled_inference)) - ts_new = Lux.Training.TrainState( - cache, obj_fn, ts.model, ps, st_updated, ts.optimizer, st_opt, ts.step + 1) + cache = TrainingBackendCache(backend, False(), nothing, (; compiled_grad_and_step!, + compiled_inference)) + @set! ts.cache = cache + @set! ts.objective_function = obj_fn + @set! ts.parameters = ps + @set! ts.states = st_updated + @set! ts.optimizer_state = st_opt + @set! 
ts.step = ts.step + 1 - return nothing, loss, stats, ts_new + return nothing, loss, stats, ts # TODO: Return the gradients end function Lux.Training.single_train_step!(::ReactantBackend, obj_fn::F, data, - ts::TrainState{<:TrainingBackendCache{:Reactant}, F}) where {F} - data = __make_reactant_array(data) + ts::TrainState{<:TrainingBackendCache{<:ReactantBackend}, F}) where {F} + data = Reactant.to_rarray(data) loss, st_updated, stats = ts.cache.extras.compiled_grad_and_step!( obj_fn, ts.model, ts.parameters, ts.states, ts.optimizer_state, data, ts.optimizer) - ts_new = Lux.Training.TrainState(ts.cache, obj_fn, ts.model, ts.parameters, st_updated, - ts.optimizer, ts.optimizer_state, ts.step + 1) + @set! ts.objective_function = obj_fn + @set! ts.states = st_updated + @set! ts.step = ts.step + 1 - return nothing, loss, stats, ts_new + return nothing, loss, stats, ts # TODO: Return the gradients end function internal_grad_and_step!( @@ -53,8 +58,8 @@ function internal_grad_and_step!( return loss, st_updated, stats end -function (tstate::TrainState{<:TrainingBackendCache{:Reactant}})(data) - data_reactant = __make_reactant_array(data) +function (tstate::TrainState{<:TrainingBackendCache{<:ReactantBackend}})(data) + data_reactant = Reactant.to_rarray(data) compiled_inference = if tstate.cache.extras.compiled_inference !== nothing tstate.cache.extras.compiled_inference else diff --git a/ext/LuxReactantExt/utils.jl b/ext/LuxReactantExt/utils.jl deleted file mode 100644 index 55156c652..000000000 --- a/ext/LuxReactantExt/utils.jl +++ /dev/null @@ -1,15 +0,0 @@ -__make_reactant_array(x::Reactant.RArray) = x -function __make_reactant_array(x::AbstractArray) - hasmethod(Reactant.ArrayToConcrete, Tuple{typeof(x)}) && - return Reactant.ConcreteRArray(x) - return __make_tracer(x) -end -function __make_reactant_array(x) - return Lux.recursive_map(x) do xₗ - hasmethod(Reactant.ArrayToConcrete, Tuple{typeof(xₗ)}) && - return Reactant.ConcreteRArray(xₗ) - return __make_tracer(xₗ) - end -end - -__make_tracer(x) = Reactant.make_tracer(IdDict(), x, (), Reactant.ArrayToConcrete)
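
Usage sketch for the API added by this series, pieced together from the `ReactantBackend` docstring, the `TrainState` inference call, and the `Lux.Training.single_train_step!` methods in the patches above. The model, data, and `Descent(0.01f0)` learning rate are illustrative assumptions, only `Descent` is handled by `simple_optimizers_apply!` at this stage, and constructor/export names may differ in released versions of Lux.

```julia
using Lux, Optimisers, Random, Reactant  # loading Reactant activates LuxReactantExt

rng = Random.default_rng()

# Keep the model small: only a limited subset of Lux models compiles via Reactant so far.
model = Chain(Dense(2 => 32, tanh), Dense(32 => 1))
ps, st = Lux.setup(rng, model)

x = rand(Float32, 2, 64)
y = rand(Float32, 1, 64)

# Only `Descent` is supported by the simple optimizer path in this series.
tstate = Training.TrainState(model, ps, st, Descent(0.01f0))

# `input_prototype` lets the extension compile the inference path on the first training
# step instead of recompiling on every `tstate(x)` call.
backend = ReactantBackend(; input_prototype=x)

function train(tstate, backend, x, y; steps=100)
    loss = nothing
    for _ in 1:steps
        # Returns (grads, loss, stats, tstate); grads is `nothing` for this backend.
        _, loss, _, tstate = Training.single_train_step!(backend, MSELoss(), (x, y), tstate)
    end
    return tstate, loss
end

tstate, final_loss = train(tstate, backend, x, y)

# Inference with the cached compiled function; returns `(output, state)` like `LuxCore.apply`.
y_pred, _ = tstate(x)
```

The gradient-and-update step (and, when `input_prototype` is given, the inference path) is compiled once on the first `single_train_step!` call and stashed in `TrainingBackendCache`; carrying the returned `tstate` into subsequent steps is what avoids recompilation, which is why the docstrings stress using the returned state rather than the original one.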