From 0c35c52af3aa472211fdbce66ca6ab8558385b3d Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sat, 1 Jun 2024 10:51:32 -0700
Subject: [PATCH] Add single_train_step

---
 Project.toml                         |  6 ++-
 docs/src/api/Lux/contrib.md          |  2 +
 ext/LuxReactantExt/LuxReactantExt.jl | 18 +++++++
 ext/LuxReactantExt/train.jl          |  0
 ext/LuxReactantExt/utils.jl          | 20 +++++++
 src/Lux.jl                           | 20 ++++---
 src/contrib/training.jl              | 81 ++++++++++++++++++++++------
 src/transform/reactant.jl            | 23 ++++++++
 8 files changed, 145 insertions(+), 25 deletions(-)
 create mode 100644 ext/LuxReactantExt/LuxReactantExt.jl
 create mode 100644 ext/LuxReactantExt/train.jl
 create mode 100644 ext/LuxReactantExt/utils.jl
 create mode 100644 src/transform/reactant.jl

diff --git a/Project.toml b/Project.toml
index a588fa1ba..efff06afa 100644
--- a/Project.toml
+++ b/Project.toml
@@ -40,6 +40,7 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
 NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
+Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
@@ -57,6 +58,7 @@ LuxMLUtilsExt = "MLUtils"
 LuxMPIExt = "MPI"
 LuxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
 LuxOptimisersExt = "Optimisers"
+LuxReactantExt = ["Enzyme", "Reactant"]
 LuxReverseDiffExt = "ReverseDiff"
 LuxSimpleChainsExt = "SimpleChains"
 LuxTrackerExt = "Tracker"
@@ -102,6 +104,7 @@ Pkg = "1.10"
 PrecompileTools = "1.2"
 Preferences = "1.4.3"
 Random = "1.10"
+Reactant = "0.1.1"
 ReTestItems = "1.23.1"
 Reexport = "1.2.2"
 ReverseDiff = "1.15"
@@ -134,6 +137,7 @@ MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
 NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
+Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
 ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
@@ -144,4 +148,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Aqua", "ComponentArrays", "DynamicExpressions", "Enzyme", "ExplicitImports", "FiniteDifferences", "Flux", "ForwardDiff", "Logging", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "MLUtils", "MPI", "NCCL", "Optimisers", "Pkg", "ReTestItems", "ReverseDiff", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"]
+test = ["Aqua", "ComponentArrays", "DynamicExpressions", "Enzyme", "ExplicitImports", "FiniteDifferences", "Flux", "ForwardDiff", "Logging", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "MLUtils", "MPI", "NCCL", "Optimisers", "Pkg", "Reactant", "ReTestItems", "ReverseDiff", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"]
\ No newline at end of file
diff --git a/docs/src/api/Lux/contrib.md b/docs/src/api/Lux/contrib.md
index e0ce8f01d..20441b598 100644
--- a/docs/src/api/Lux/contrib.md
+++ b/docs/src/api/Lux/contrib.md
@@ -36,6 +36,8 @@ Lux.Experimental.TrainState
 Lux.Experimental.compute_gradients
 Lux.Experimental.apply_gradients
 Lux.Experimental.apply_gradients!
+Lux.Experimental.single_train_step
+Lux.Experimental.single_train_step!
 ```
 
 ## Parameter Freezing
diff --git a/ext/LuxReactantExt/LuxReactantExt.jl b/ext/LuxReactantExt/LuxReactantExt.jl
new file mode 100644
index 000000000..ea93c67d8
--- /dev/null
+++ b/ext/LuxReactantExt/LuxReactantExt.jl
@@ -0,0 +1,18 @@
+module LuxReactantExt
+
+using Adapt: adapt
+using ArgCheck: @argcheck
+using ConcreteStructs: @concrete
+using Enzyme: Enzyme, Active, Const, Duplicated
+using Functors: fmapstructure, fmap
+using Random: AbstractRNG, Xoshiro
+using Reactant: Reactant
+using Lux: Lux, LuxEltypeAdaptor, AutoReactant
+using LuxCore: LuxCore, AbstractExplicitLayer
+
+include("utils.jl")
+
+# compile the entire training loop
+include("train.jl")
+
+end
diff --git a/ext/LuxReactantExt/train.jl b/ext/LuxReactantExt/train.jl
new file mode 100644
index 000000000..e69de29bb
diff --git a/ext/LuxReactantExt/utils.jl b/ext/LuxReactantExt/utils.jl
new file mode 100644
index 000000000..61f66ffec
--- /dev/null
+++ b/ext/LuxReactantExt/utils.jl
@@ -0,0 +1,20 @@
+@inline Lux.__make_reactant_array(x::Reactant.RArray) = x
+@inline function Lux.__make_reactant_array(x::AbstractArray)
+    hasmethod(Reactant.ArrayToConcrete, Tuple{typeof(x)}) &&
+        return Reactant.ConcreteRArray(x)
+    return __make_tracer(x)
+end
+@inline Lux.__make_reactant_array(x) = __make_tracer(x)
+
+@inline function __make_tracer(x)
+    return Reactant.make_tracer(IdDict(), x, (), Reactant.ArrayToConcrete, nothing)
+end
+
+@inline function __try_similar_structure(x::AbstractArray, y::NamedTuple{()})
+    length(x) == 0 && return y
+    throw(DimensionMismatch(lazy"Expected empty array, got $(size(x))."))
+end
+@inline function __try_similar_structure(x::AbstractArray, y::AbstractArray)
+    return parent(x) !== x ? copy(x) : x  # unview arrays and such
+end
+@inline __try_similar_structure(x, y) = fmap(__try_similar_structure, x, y)
diff --git a/src/Lux.jl b/src/Lux.jl
index 92695f60e..e2cb13557 100644
--- a/src/Lux.jl
+++ b/src/Lux.jl
@@ -3,7 +3,8 @@ module Lux
 using PrecompileTools: @recompile_invalidations
 
 @recompile_invalidations begin
-    using ADTypes: AbstractADType, AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote
+    using ADTypes: AbstractADType, AutoEnzyme, AutoForwardDiff, AutoReverseDiff,
+                   AutoTracker, AutoZygote
     using Adapt: Adapt, adapt
     using ArgCheck: @argcheck
     using ArrayInterface: ArrayInterface
@@ -16,7 +17,7 @@
     using Markdown: @doc_str
     using OhMyThreads: tmapreduce
     using Preferences: @load_preference
-    using Random: Random, AbstractRNG
+    using Random: Random, AbstractRNG, Xoshiro
     using Reexport: @reexport
 
     using LuxCore, LuxLib, LuxDeviceUtils, WeightInitializers
@@ -48,6 +49,12 @@ const DISABLE_AUTOMATIC_NESTED_AD_SWITCH = @load_preference("DisableAutomaticNes
 # Utilities
 include("utils.jl")
 
+# Transform to and from other frameworks
+include("transform/types.jl")
+include("transform/flux.jl")
+include("transform/simplechains.jl")
+include("transform/reactant.jl")
+
 # Layer Implementations
 include("layers/basic.jl")
 include("layers/containers.jl")
@@ -72,11 +79,6 @@ include("helpers/compact.jl")
 include("helpers/autodiff.jl")
 include("helpers/nested_ad.jl")
 
-# Transform to and from other frameworks
-include("transform/types.jl")
-include("transform/flux.jl")
-include("transform/simplechains.jl")
-
 # Distributed Training
 include("distributed/backend.jl")
 include("distributed/public_api.jl")
@@ -103,13 +105,15 @@ export @compact, CompactLuxLayer
 export jacobian_vector_product, vector_jacobian_product
 export batched_jacobian
 
-export AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote
+export AutoEnzyme, AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote
+export AutoReactant
 
 export f16, f32, f64
 
 export transform
 export FromFluxAdaptor, FluxLayer
 export ToSimpleChainsAdaptor, SimpleChainsLayer
+export ToReactantAdaptor, ReactantLayer
 export DynamicExpressionsLayer
 
 export MPIBackend, NCCLBackend, DistributedUtils
diff --git a/src/contrib/training.jl b/src/contrib/training.jl
index c0f285992..95caa6890 100644
--- a/src/contrib/training.jl
+++ b/src/contrib/training.jl
@@ -13,6 +13,11 @@ Internal fields:
 
   - `cache`: Cached values. Implementations are free to use this for whatever they want.
  - `objective_function`: Objective function might be cached.
+
+!!! warning
+
+    Constructing this object directly shouldn't be considered a stable API. Use the
+    version with the Optimisers API.
 """
 @concrete struct TrainState{C, F}
     cache::C
@@ -27,8 +32,8 @@ end
 function Base.show(io::IO, ts::TrainState)
     println(io, "TrainState")
     println(io, "    model: ", ts.model)
-    println(io, "    parameters: ", Lux.parameterlength(ts.parameters))
-    println(io, "    states: ", Lux.statelength(ts.states))
+    println(io, "    # of parameters: ", Lux.parameterlength(ts.parameters))
+    println(io, "    # of states: ", Lux.statelength(ts.states))
     println(io, "    optimizer_state: ", ts.optimizer_state)
     print(io, "    step: ", ts.step)
     ts.cache !== nothing && print(io, "\n    cache: ", nameof(typeof(ts.cache)))
@@ -36,21 +41,24 @@ function Base.show(io::IO, ts::TrainState)
     print(io, "\n    objective_function: ", nameof(typeof(ts.objective_function)))
 end
 
-"""
-    apply_gradients(ts::TrainState, grads)
-
-Update the parameters stored in `ts` using the gradients `grads`.
-
+const APPLY_GRAD_DOCSTRING = """
 ## Arguments
 
   - `ts`: [`TrainState`](@ref) object.
   - `grads`: Gradients of the loss function wrt `ts.params`.
-  - `update_inplace`: Whether to update the parameters inplace or not.
 
 ## Returns
 
 Updated [`TrainState`](@ref) object.
 """
+
+"""
+    apply_gradients(ts::TrainState, grads)
+
+Update the parameters stored in `ts` using the gradients `grads`.
+
+$(APPLY_GRAD_DOCSTRING)
+"""
 function apply_gradients end
 
 """
@@ -59,14 +67,7 @@ function apply_gradients end
     apply_gradients!(ts::TrainState, grads)
 
 Update the parameters stored in `ts` using the gradients `grads`. This is an inplace
 version of [`apply_gradients`](@ref).
 
-## Arguments
-
-  - `ts`: [`TrainState`](@ref) object.
-  - `grads`: Gradients of the loss function wrt `ts.params`.
-
-## Returns
-
-Updated [`TrainState`](@ref) object.
+$(APPLY_GRAD_DOCSTRING)
 """
 function apply_gradients! end
@@ -146,3 +147,51 @@ end
 
     return wrapped_objective_function, st_updated, stats
 end
+
+"""
+    single_train_step!(backend, obj_fn::F, data, ts::TrainState)
+
+Perform a single training step. Computes the gradients using [`compute_gradients`](@ref) and
+updates the parameters using [`apply_gradients!`](@ref). All backends supported via
+[`compute_gradients`](@ref) are supported here.
+
+## Additional Backends
+
+  - [`AutoReactant`](@ref): Compiles the training loop to MLIR/XLA via `Reactant.jl`.
+
+## Return
+
+Returned values are the same as [`compute_gradients`](@ref). Note that despite the `!`,
+only the parameters in `ts` are updated inplace. Users should use the returned `ts`
+object for further training steps; otherwise there is no caching and performance will be
+suboptimal (and absolutely terrible for backends like `AutoReactant`).
+"""
+function single_train_step! end
+
+"""
+    single_train_step(backend, obj_fn::F, data, ts::TrainState)
+
+Perform a single training step. Computes the gradients using [`compute_gradients`](@ref) and
+updates the parameters using [`apply_gradients`](@ref). All backends supported via
+[`compute_gradients`](@ref) are supported here.
+
+## Additional Backends
+
+  - [`AutoReactant`](@ref): Compiles the training loop to MLIR/XLA via `Reactant.jl`.
+
+In most cases you should use [`single_train_step!`](@ref) instead of this function.
+
+## Return
+
+Returned values are the same as [`compute_gradients`](@ref).
+"""
+function single_train_step end
+
+for inplace in ("!", "")
+    step, apply_fn = Symbol(:single_train_step, inplace), Symbol(:apply_gradients, inplace)
+    @eval function $(step)(backend, obj_fn::F, data, ts::TrainState) where {F}
+        grads, loss, stats, ts = compute_gradients(backend, obj_fn, data, ts)
+        ts = $apply_fn(ts, grads)
+        return grads, loss, stats, ts
+    end
+end
diff --git a/src/transform/reactant.jl b/src/transform/reactant.jl
new file mode 100644
index 000000000..d3903fcc5
--- /dev/null
+++ b/src/transform/reactant.jl
@@ -0,0 +1,23 @@
+"""
+    AutoReactant()
+
+Compile the training loop to MLIR/XLA via `Reactant.jl`.
+
+This has been added to Lux very recently and is undergoing rapid development. Currently,
+only a limited subset of Lux models can be compiled via `Reactant.jl`. If you encounter any
+issues, please report them on the `Lux.jl` or `Reactant.jl` GitHub repositories.
+"""
+struct AutoReactant end
+
+"""
+    __make_reactant_array(x)
+
+Converts `x` to a `Reactant.ConcreteRArray` if it is not already one.
+"""
+function __make_reactant_array end
+
+@inline function __make_reactant_array(nt::NamedTuple{names}) where {names}
+    return NamedTuple{names}(map(__make_reactant_array, values(nt)))
+end
+@inline __make_reactant_array(t::Tuple) = map(__make_reactant_array, t)
+@inline __make_reactant_array(x::AbstractExplicitLayer) = x
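
Below is a minimal usage sketch of the API this patch adds, for illustration only; it is not part of the diff. The toy model, data, optimiser, and the `AutoZygote` backend are arbitrary choices, and the `TrainState` construction relies on the existing `Lux.Experimental` Optimisers-based constructor; only `single_train_step!`, its `(backend, obj_fn, data, ts)` signature, and its `(grads, loss, stats, ts)` return come from the patch above.

```julia
using Lux, Optimisers, Random, Zygote

rng = Xoshiro(0)
model = Chain(Dense(4 => 8, tanh), Dense(8 => 1))            # toy model (illustrative)
x, y = rand(rng, Float32, 4, 32), rand(rng, Float32, 1, 32)  # toy data (illustrative)

# Objective functions follow the `compute_gradients` contract:
# take (model, ps, st, data) and return (loss, updated_states, stats).
function mse_loss(model, ps, st, (x, y))
    y_pred, st_ = model(x, ps, st)
    return sum(abs2, y_pred .- y), st_, (;)
end

# Existing Optimisers-based TrainState constructor (unchanged by this patch).
ts = Lux.Experimental.TrainState(rng, model, Adam(0.01f0))

function train(ts, data; nsteps=100)
    for _ in 1:nsteps
        # One fused step: compute_gradients + apply_gradients!.
        # Keep using the returned `ts` so any cached objective function is reused.
        _, loss, _, ts = Lux.Experimental.single_train_step!(
            AutoZygote(), mse_loss, data, ts)
    end
    return ts
end

ts = train(ts, (x, y))
```

Swapping `AutoZygote()` for the `AutoReactant()` backend introduced in this patch would route the same call through the new `LuxReactantExt` extension (once `train.jl` is filled in), with the training step compiled to MLIR/XLA.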