Skip to content

Commit

Permalink
implement epsilon-based floating point number comparisons in tests (#57)
Browse files Browse the repository at this point in the history
* implement eq-epsilon option

* implement support for eq-epsilon in dumbMnistTests and sgdTestCase

* apply eq-epsilon command line option to some test cases

* implement Precision5

UnkindPartition/tasty#337 (comment)

* implement assertClose

* replace Precision5 with assertClose

* make eq-epsilon a global variable using IORef and unsafePerformIO

This is almost identical to the current standard idiom for doing this: https://wiki.haskell.org/Top_level_mutable_state .

* add 2 comments

* implement assertCloseMulti

* replace @?= with assertCloseMulti

* simplify

* implement assertCloseElem

* replace elem with assertCloseElem

* fix msg

* add more comments

* polish up

* implement eqEpsilonDefault; rename assertCloseMulti

* whitespace fix

* implement hlint suggestions

* fix

* move eq-epsilon related stuff to a new file (TestCommonEqEpsilon.hs)

* move even more eq-epsilon related stuff to TestCommonEqEpsilon.hs

* implement support for eq-epsilon

* implement support for eq-epsilon

* whitespace fix

* implement (@?~) infix operator; generalize

* replace (@?=) with (@?~)

* replace assertClose with (@?~)

* fix cmdline syntax

* remove redundant brackets (hlint)

* fix

* fix warning

* plug in HUnit-approx/Test.HUnit.Approx

* define AssertClose class

* implement AssertClose instance for Traversable

* implement AssertClose instance ([a],a)

* relax TestSimpleDescent checks (with eq-epsilon)

* make it more generic

* refactor TestOutdated

* implement AssertClose instance for pairs

1. foldr on pairs does not work as expected: https://stackoverflow.com/questions/72798587/haskell-foldr-foldl-on-pairs
2. performance (skip asList)

* whitespace fix

* relax TestSingleGradient checks (with eq-epsilon)

* apply hlint

* resolve 2x TODO

* implement AssertClose instance for Storable vectors

* get rid of assertion message (not used)

* remove unused imports

Co-authored-by: Stanislaw Findeisen <[email protected]>
  • Loading branch information
sfindeisen and sfindeisen authored Jul 4, 2022
1 parent 060834d commit f81327b
Show file tree
Hide file tree
Showing 12 changed files with 218 additions and 60 deletions.
2 changes: 2 additions & 0 deletions horde-ad.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ library testLibrary

-- Modules exported by the library.
exposed-modules: TestCommon
TestCommonEqEpsilon
TestConditionalSynth
TestOutdated

Expand All @@ -203,6 +204,7 @@ library testLibrary
build-depends:
base
, deepseq
, HUnit-approx
, hmatrix
, horde-ad
, ilist
Expand Down
10 changes: 9 additions & 1 deletion test/ExtremelyLongTest.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@ module Main (main) where

import Prelude

import Data.Proxy
import qualified System.IO as SIO
import Test.Tasty
import Test.Tasty.Options
import Test.Tasty.Runners

import TestCommonEqEpsilon
import qualified TestConditionalSynth
import qualified TestOutdated

Expand All @@ -27,7 +31,11 @@ main = do
-- Limit interleaving of characters in parallel tests.
SIO.hSetBuffering SIO.stdout SIO.LineBuffering
SIO.hSetBuffering SIO.stderr SIO.LineBuffering
defaultMain tests
opts <- parseOptions (ingredients : defaultIngredients) tests
setEpsilonEq (lookupOption opts :: EqEpsilon)
defaultMainWithIngredients (ingredients : defaultIngredients) tests
where
ingredients = includingOptions [Option (Proxy :: Proxy EqEpsilon)]

tests :: TestTree
tests = testGroup "Tests" $
Expand Down
10 changes: 9 additions & 1 deletion test/MinimalTest.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@ module Main (main) where

import Prelude

import Data.Proxy
import qualified System.IO as SIO
import Test.Tasty
import Test.Tasty.Options
import Test.Tasty.Runners

import TestCommonEqEpsilon
#if defined(VERSION_ghc_typelits_natnormalise)
import qualified TestSimpleDescent
import qualified TestSingleGradient
Expand All @@ -21,7 +25,11 @@ main = do
-- Limit interleaving of characters in parallel tests.
SIO.hSetBuffering SIO.stdout SIO.LineBuffering
SIO.hSetBuffering SIO.stderr SIO.LineBuffering
defaultMain tests
opts <- parseOptions (ingredients : defaultIngredients) tests
setEpsilonEq (lookupOption opts :: EqEpsilon)
defaultMainWithIngredients (ingredients : defaultIngredients) tests
where
ingredients = includingOptions [Option (Proxy :: Proxy EqEpsilon)]

tests :: TestTree
tests = testGroup "Minimal test that doesn't require any dataset" $
Expand Down
10 changes: 9 additions & 1 deletion test/ShortTestForCI.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@ module Main (main) where

import Prelude

import Data.Proxy
import qualified System.IO as SIO
import Test.Tasty
import Test.Tasty.Options
import Test.Tasty.Runners

import TestCommonEqEpsilon
#if defined(VERSION_ghc_typelits_natnormalise)
import qualified TestMnistCNN
import qualified TestMnistFCNN
Expand All @@ -24,7 +28,11 @@ main = do
-- Limit interleaving of characters in parallel tests.
SIO.hSetBuffering SIO.stdout SIO.LineBuffering
SIO.hSetBuffering SIO.stderr SIO.LineBuffering
defaultMain tests
opts <- parseOptions (ingredients : defaultIngredients) tests
setEpsilonEq (lookupOption opts :: EqEpsilon)
defaultMainWithIngredients (ingredients : defaultIngredients) tests
where
ingredients = includingOptions [Option (Proxy :: Proxy EqEpsilon)]

tests :: TestTree
tests = testGroup "Short tests for CI" $
Expand Down
124 changes: 124 additions & 0 deletions test/common/TestCommonEqEpsilon.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
{-# LANGUAGE GeneralizedNewtypeDeriving, FlexibleInstances, UndecidableInstances #-}

module TestCommonEqEpsilon (EqEpsilon, setEpsilonEq, assertCloseElem, (@?~)) where

import Prelude
import Data.Typeable

import Data.IORef
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Storable as VS
import System.IO.Unsafe
import qualified Test.HUnit.Approx
import Test.Tasty.HUnit
import Test.Tasty.Options

-- | Tolerance used by the approximate floating point assertions of this
-- module. It wraps a 'Rational' so that a single value can be converted, via
-- 'fromRational', to whatever 'Fractional' type is being compared.
newtype EqEpsilon = EqEpsilon Rational
  deriving (Typeable, Num, Fractional)

-- | Registers 'EqEpsilon' as a tasty option named @eq-epsilon@.
instance IsOption EqEpsilon where
  defaultValue = EqEpsilon eqEpsilonDefault
  -- The value is read as a 'Double' so that the usual decimal and scientific
  -- notations (e.g. @1e-6@) are accepted. Negative, NaN and infinite inputs
  -- are rejected as parse errors: a negative tolerance can never be satisfied
  -- by @abs (a - b)@, and NaN/Infinity have no meaningful 'Rational' value.
  -- (NaN is rejected by the @x >= 0@ test, which is False for NaN.)
  parseValue s = case (safeRead s :: Maybe Double) of
    Just x | x >= 0 && not (isInfinite x) -> Just (EqEpsilon (toRational x))
    _ -> Nothing
  optionName = return "eq-epsilon"
  -- Show the default as a Double ("1.0e-6"); showing the Rational directly
  -- would print the confusing "1 % 1000000".
  optionHelp = return $
    "Epsilon to use for floating point comparisons: abs(a-b) < epsilon . Default: "
      ++ show (fromRational eqEpsilonDefault :: Double)

-- | Tolerance in effect until 'setEpsilonEq' is called: one millionth (1e-6).
-- It is the initial content of 'eqEpsilonRef'.
eqEpsilonDefault :: Rational
eqEpsilonDefault = 1 / 10 ^ (6 :: Int)

-- Ugly global epsilon used to compare floating point values.
--
-- This is top-level mutable state: the reference is created with
-- 'unsafePerformIO' and initialised to 'eqEpsilonDefault'. The NOINLINE
-- pragma is essential: if the definition were inlined at its use sites,
-- each of them could allocate its own IORef instead of sharing this one.
-- The reference is written by 'setEpsilonEq' and read by the assertions.
eqEpsilonRef :: IORef Rational
{-# NOINLINE eqEpsilonRef #-}
eqEpsilonRef = unsafePerformIO $ newIORef eqEpsilonDefault

-- | Replaces the global tolerance stored in 'eqEpsilonRef' with the given
-- one, using an atomic write. Intended to be called once, before any
-- assertion reads the tolerance.
setEpsilonEq :: EqEpsilon -> IO ()
setEpsilonEq newEpsilon =
  case newEpsilon of
    EqEpsilon tolerance -> atomicWriteIORef eqEpsilonRef tolerance

-- | Checks that the actual value lies within the global tolerance of the
-- expected value.
--
-- The tolerance is read from 'eqEpsilonRef' at the time the assertion runs
-- and converted to the type being compared; the comparison itself and the
-- failure message are delegated to 'Test.HUnit.Approx.assertApproxEqual'.
-- An empty prefix (@\"\"@) is passed through unchanged.
assertClose :: forall a. (Fractional a, Ord a, Show a, HasCallStack)
            => String -- ^ The message prefix
            -> a      -- ^ The expected value
            -> a      -- ^ The actual value
            -> Assertion
assertClose preface expected actual =
  readIORef eqEpsilonRef >>= \tolerance ->
    Test.HUnit.Approx.assertApproxEqual
      preface
      (fromRational tolerance)
      expected
      actual

-- | Asserts that the specified actual floating point value is close to at
-- least one of the expected values.
--
-- A candidate matches when @abs (candidate - actual) <= epsilon@, where
-- epsilon is the global tolerance read from 'eqEpsilonRef'. The bound is
-- inclusive so that this check agrees with 'assertClose' (in particular, an
-- exact match is still accepted when the tolerance is zero).
--
-- Fails with a message listing the actual value and all expected values when
-- no candidate matches, which includes the case of an empty expected list.
-- If the prefix is non-empty it is put on its own line before that message.
assertCloseElem :: forall a. (Fractional a, Ord a, Show a, HasCallStack)
                => String -- ^ The message prefix
                -> [a]    -- ^ The expected values
                -> a      -- ^ The actual value
                -> Assertion
assertCloseElem preface expected actual = do
  eqEpsilon <- readIORef eqEpsilonRef
  let isClose candidate = abs (candidate - actual) <= fromRational eqEpsilon
  -- Once a matching candidate is found there is nothing left to verify.
  if any isClose expected then pure () else assertFailure msg
  where
    msg = (if null preface then "" else preface ++ "\n") ++
          "wrong result: " ++ show actual ++ " is expected to be a member of " ++ show expected

-- | Compares two lists position by position with '@?~'.
--
-- Elements are checked from the front, so the first pair that is not close
-- fails the assertion. If one list runs out before the other, the assertion
-- fails with a message reporting both lengths; this happens only after the
-- common prefix has been compared successfully.
assertCloseList :: forall a. (AssertClose a, HasCallStack)
                => [a] -- ^ The expected value
                -> [a] -- ^ The actual value
                -> Assertion
assertCloseList expected actual = walk expected actual
  where
    -- Built lazily: the lengths are only computed if the message is needed.
    lengthMismatch :: String
    lengthMismatch =
      "expected " ++ show (length expected) ++ " elements, but got " ++ show (length actual)

    walk :: [a] -> [a] -> Assertion
    walk remainingExpected remainingActual =
      case (remainingExpected, remainingActual) of
        ([], []) -> pure ()
        (e : es, x : xs) -> x @?~ e >> walk es xs
        _ -> assertFailure lengthMismatch

-- | Collects the elements of any 'Foldable' container into a list, in the
-- order in which 'foldr' visits them. Note that for a pair the 'Foldable'
-- instance only sees the second component.
asList :: Foldable t => t a -> [a]
asList container = foldr (\element rest -> element : rest) [] container

-- | Types whose values can be asserted to be "approximately equal". Instances
-- are expected to implement a reflexive and symmetric relation; transitivity
-- is not required.
class AssertClose a where
  -- | Asserts that the actual value (first argument) is close to the
  -- expected value (second argument).
  (@?~) :: a -> a -> Assertion

-- | Fallback instance for plain numbers: the two values are compared with
-- 'assertClose' and no message prefix. Being OVERLAPPABLE, it gives way to
-- the more specific instances for pairs and containers.
instance {-# OVERLAPPABLE #-} (Fractional a, Ord a, Show a) => AssertClose a where
  (@?~) :: a -> a -> Assertion
  actual @?~ expected = assertClose "" expected actual

-- | Homogeneous pairs: the first components are compared, then the second
-- components.
instance (AssertClose a) => AssertClose (a,a) where
  (@?~) :: (a,a) -> (a,a) -> Assertion
  actual @?~ expected = do
    -- Lazy let-bindings: the pairs are only taken apart when a component
    -- is actually needed by a comparison.
    let (actualFst, actualSnd) = actual
        (expectedFst, expectedSnd) = expected
    actualFst @?~ expectedFst
    actualSnd @?~ expectedSnd

-- | Containers: both sides are flattened with 'asList' and then compared
-- element by element (including their lengths) by 'assertCloseList'.
instance {-# OVERLAPPABLE #-} (Traversable t, AssertClose a) => AssertClose (t a) where
  (@?~) :: t a -> t a -> Assertion
  actual @?~ expected = assertCloseList expectedElems actualElems
    where
      expectedElems = asList expected
      actualElems = asList actual

-- | A container paired with a single value: the single values are compared
-- first, then the containers are compared element by element via 'asList'
-- and 'assertCloseList'.
instance {-# OVERLAPPABLE #-} (Traversable t, AssertClose a) => AssertClose (t a, a) where
  (@?~) :: (t a, a) -> (t a, a) -> Assertion
  (actualElems, actualValue) @?~ (expectedElems, expectedValue) = do
    actualValue @?~ expectedValue
    assertCloseList (asList expectedElems) (asList actualElems)

-- | A storable vector paired with a single value: the single values are
-- compared first, then the vectors are converted to lists and compared
-- element by element with 'assertCloseList'.
instance (VS.Storable a, AssertClose a) => AssertClose (VS.Vector a, a) where
  (@?~) :: (VS.Vector a, a) -> (VS.Vector a, a) -> Assertion
  (actualVec, actualValue) @?~ (expectedVec, expectedValue) = do
    actualValue @?~ expectedValue
    assertCloseList (VG.toList expectedVec) (VG.toList actualVec)
3 changes: 2 additions & 1 deletion test/common/TestConditionalSynth.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import Test.Tasty.HUnit hiding (assert)

import HordeAd
import HordeAd.Tool.MnistFcnnVector
import TestCommonEqEpsilon

testTrees :: [TestTree]
testTrees = [ conditionalSynthTests
Expand Down Expand Up @@ -212,7 +213,7 @@ gradSmartTestCase prefix lossFunction seedSamples
(initialStateAdam parametersInit)
(_, values) =
unzip $ map (\t -> dReverse 1 (f t) parametersResult) testSamples
(sum values / 100) @?= expected
(sum values / 100) @?~ expected

conditionalSynthTests:: TestTree
conditionalSynthTests = do
Expand Down
5 changes: 3 additions & 2 deletions test/common/TestMnistCNN.hs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import HordeAd.Tool.MnistCnnShaped
import HordeAd.Tool.MnistData

import TestCommon
import TestCommonEqEpsilon

testTrees :: [TestTree]
testTrees = [ mnistCNNTestsShort
Expand Down Expand Up @@ -194,7 +195,7 @@ convMnistTestCaseCNN prefix epochs maxBatches trainWithLoss testLoss
runEpoch (succ n) res
res <- runEpoch 1 parameters0
let testErrorFinal = 1 - testLoss widthHidden testData res
testErrorFinal @?= expected
testErrorFinal @?~ expected


-- * Another flavour of the simplest possible convolutional net, based on
Expand Down Expand Up @@ -483,7 +484,7 @@ convMnistTestCaseCNNT prefix epochs maxBatches trainWithLoss ftest flen
let testErrorFinal = 1 - ftest proxy_kheight_minus_1 proxy_kwidth_minus_1
proxy_num_hidden proxy_out_channels
testData res
testErrorFinal @?= expected
testErrorFinal @?~ expected

mnistCNNTestsLong :: TestTree
mnistCNNTestsLong = testGroup "MNIST CNN long tests"
Expand Down
23 changes: 12 additions & 11 deletions test/common/TestMnistFCNN.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import HordeAd.Tool.MnistFcnnShaped
import HordeAd.Tool.MnistFcnnVector

import TestCommon
import TestCommonEqEpsilon

testTrees :: [TestTree]
testTrees = [ dumbMnistTests
Expand Down Expand Up @@ -78,7 +79,7 @@ sgdTestCase prefix trainDataIO trainWithLoss gamma expected =
trainData <- trainDataIO
sgdShow gamma (trainWithLoss widthHidden widthHidden2)
trainData vec
@?= expected
@?~ expected

mnistTestCase2
:: String
Expand Down Expand Up @@ -135,7 +136,7 @@ mnistTestCase2 prefix epochs maxBatches trainWithLoss widthHidden widthHidden2
res <- runEpoch 1 params0Init
let testErrorFinal =
1 - fcnnMnistTest0 widthHidden widthHidden2 testData res
testErrorFinal @?= expected
testErrorFinal @?~ expected

mnistTestCase2V
:: String
Expand Down Expand Up @@ -201,7 +202,7 @@ mnistTestCase2V prefix epochs maxBatches trainWithLoss widthHidden widthHidden2
res <- runEpoch 1 (params0Init, params1Init)
let testErrorFinal =
1 - fcnnMnistTest1 widthHidden widthHidden2 testData res
testErrorFinal @?= expected
testErrorFinal @?~ expected

fcnnMnistLossTanh :: DualMonad 'DModeGradient Double m
=> Int
Expand Down Expand Up @@ -278,7 +279,7 @@ mnistTestCase2L prefix epochs maxBatches trainWithLoss widthHidden widthHidden2
runEpoch (succ n) res
res <- runEpoch 1 parameters0
let testErrorFinal = 1 - fcnnMnistTest2 testData res
testErrorFinal @?= expected
testErrorFinal @?~ expected

mnistTestCase2T
:: Bool
Expand Down Expand Up @@ -344,7 +345,7 @@ mnistTestCase2T reallyWriteFile
when reallyWriteFile $
writeFile "walltimeLoss.txt" $ unlines $ map ppTime times
let testErrorFinal = 1 - fcnnMnistTest2 testData res
testErrorFinal @?= expected
testErrorFinal @?~ expected

mnistTestCase2D
:: Bool
Expand Down Expand Up @@ -418,7 +419,7 @@ mnistTestCase2D reallyWriteFile miniBatchSize decay
when reallyWriteFile $
writeFile "walltimeLoss.txt" $ unlines $ map ppTime times
let testErrorFinal = 1 - fcnnMnistTest2 testData res
testErrorFinal @?= expected
testErrorFinal @?~ expected

mnistTestCase2F
:: Bool
Expand Down Expand Up @@ -492,7 +493,7 @@ mnistTestCase2F reallyWriteFile miniBatchSize decay
when reallyWriteFile $
writeFile "walltimeLoss.txt" $ unlines $ map ppTime times
let testErrorFinal = 1 - fcnnMnistTest2 testData res
testErrorFinal @?= expected
testErrorFinal @?~ expected

mnistTestCase2S
:: forall widthHidden widthHidden2.
Expand Down Expand Up @@ -552,7 +553,7 @@ mnistTestCase2S proxy proxy2
res <- runEpoch 1 parametersInit
let testErrorFinal = 1 - fcnnMnistTestS @widthHidden @widthHidden2
testData res
testErrorFinal @?= expected
testErrorFinal @?~ expected

dumbMnistTests :: TestTree
dumbMnistTests = testGroup "Dumb MNIST tests"
Expand Down Expand Up @@ -619,14 +620,14 @@ dumbMnistTests = testGroup "Dumb MNIST tests"
params0 = V.replicate nParams0 0.1
testData <- loadMnistData testGlyphsPath testLabelsPath
(1 - fcnnMnistTest0 300 100 testData params0)
@?= 0.902
@?~ 0.902
, testCase "fcnnMnistTest2VV on 0.1 params0 300 100 width 10k testset" $ do
let (nParams0, nParams1, _, _) = fcnnMnistLen1 300 100
params0 = V.replicate nParams0 0.1
params1 = V.fromList $ map (`V.replicate` 0.1) nParams1
testData <- loadMnistData testGlyphsPath testLabelsPath
(1 - fcnnMnistTest1 300 100 testData (params0, params1))
@?= 0.902
@?~ 0.902
, testCase "fcnnMnistTest2LL on 0.1 params0 300 100 width 10k testset" $ do
let (nParams0, lParams1, lParams2, _) = fcnnMnistLen2 300 100
vParams1 = V.fromList lParams1
Expand All @@ -637,7 +638,7 @@ dumbMnistTests = testGroup "Dumb MNIST tests"
testData <- loadMnistData testGlyphsPath testLabelsPath
(1 - fcnnMnistTest2 testData
(params0, params1, params2, V.empty))
@?= 0.902
@?~ 0.902
, testProperty "Compare two forward derivatives and gradient for Mnist0" $
\seed seedDs ->
forAll (choose (1, 300)) $ \widthHidden ->
Expand Down
Loading

0 comments on commit f81327b

Please sign in to comment.