Skip to content

Commit

Permalink
test: start adding loss function tests
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 5, 2024
1 parent f511b37 commit a06b571
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 3 deletions.
5 changes: 3 additions & 2 deletions ext/LuxReactantExt/LuxReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ module LuxReactantExt

using Enzyme: Enzyme, Const, Duplicated, Active
using Optimisers: Optimisers
using Reactant: Reactant, @compile
using Reactant: Reactant, @compile, TracedRArray
using Setfield: @set!
using Static: False

using Lux: Lux, Training
using Lux: Lux, LuxOps, Training
using Lux.Training: TrainingBackendCache, ReactantBackend

include("overrides.jl")
include("training.jl")

end
Empty file added ext/LuxReactantExt/overrides.jl
Empty file.
44 changes: 44 additions & 0 deletions test/reactant/loss_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
@testitem "Compiled Loss Functions" tags=[:reactant] setup=[SharedTestSetup] begin
using Reactant, Lux

rng = StableRNG(123)

@testset "$(mode)" for (mode, atype, dev, ongpu) in MODES
if mode == "amdgpu"
@warn "Skipping AMDGPU tests for Reactant"
continue
end

if ongpu
Reactant.set_default_backend("gpu")
else
Reactant.set_default_backend("cpu")
end

@testset "xlogx & xlogy" begin
x = rand(rng, 10)
y = rand(rng, 10)
x_ra = Reactant.to_rarray(x)
y_ra = Reactant.to_rarray(y)

fn1(x) = LuxOps.xlogx.(x)
fn2(x, y) = LuxOps.xlogy.(x, y)

@test begin
fn1_compiled = @compile fn1(x_ra)
fn1(x) fn1_compiled(x_ra)
end

@test begin
fn2_compiled = @compile fn2(x_ra, y_ra)
fn2(x, y) fn2_compiled(x_ra, y_ra)
end broken=true
end

@testset "Regression Loss" begin end

@testset "Classification Loss" begin end

@testset "Other Losses" begin end
end
end
2 changes: 1 addition & 1 deletion test/reactant/training_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testitem "Reactant: Training Tests" tags=[:reactant] setup=[SharedTestSetup] begin
@testitem "Reactant: Training API" tags=[:reactant] setup=[SharedTestSetup] begin
using Reactant

@testset "$(mode)" for (mode, atype, dev, ongpu) in MODES
Expand Down

0 comments on commit a06b571

Please sign in to comment.