Skip to content

Commit

Permalink
Merge pull request #186 from Cambridge-ICCS/assert-precs
Browse files Browse the repository at this point in the history
Extend `assert` subroutines to account for different `real` types
  • Loading branch information
jwallwork23 authored Nov 8, 2024
2 parents b0e0311 + 38aa60c commit ff77d68
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 57 deletions.
8 changes: 4 additions & 4 deletions examples/1_SimpleNet/simplenet_infer_fortran.f90
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ program inference
use ftorch

! Import our tools module for testing utils
use ftorch_test_utils, only : assert_real_array_1d
use ftorch_test_utils, only : assert_allclose

implicit none

! Set working precision for reals
integer, parameter :: wp = sp

integer :: num_args, ix
character(len=128), dimension(:), allocatable :: args

Expand Down Expand Up @@ -55,7 +55,7 @@ program inference

! Check output tensor matches expected value
expected = [0.0, 2.0, 4.0, 6.0, 8.0]
test_pass = assert_real_array_1d(out_data, expected, test_name="SimpleNet", rtol=1e-5)
test_pass = assert_allclose(out_data, expected, test_name="SimpleNet", rtol=1e-5)

! Cleanup
call torch_delete(model)
Expand Down
4 changes: 2 additions & 2 deletions examples/2_ResNet18/resnet_infer_fortran.f90
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ program inference
use ftorch

! Import our tools module for testing utils
use ftorch_test_utils, only : assert_real
use ftorch_test_utils, only : assert_isclose

implicit none

Expand Down Expand Up @@ -102,7 +102,7 @@ subroutine main()
probability = maxval(probabilities)

! Check top probability matches expected value
test_pass = assert_real(probability, expected_prob, test_name="Check probability", rtol=1e-5)
test_pass = assert_isclose(probability, expected_prob, test_name="Check probability", rtol=1e-5)

write (*,*) "Top result"
write (*,*) ""
Expand Down
8 changes: 4 additions & 4 deletions examples/4_MultiIO/multiionet_infer_fortran.f90
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ program inference
use ftorch

! Import our tools module for testing utils
use ftorch_test_utils, only : assert_real_array_1d
use ftorch_test_utils, only : assert_allclose

implicit none

! Set working precision for reals
integer, parameter :: wp = sp

Expand Down Expand Up @@ -61,9 +61,9 @@ program inference

! Check output tensors match expected values
expected = [0.0, 2.0, 4.0, 6.0]
test_pass = assert_real_array_1d(out_data1, expected, test_name="MultiIO array1", rtol=1e-5)
test_pass = assert_allclose(out_data1, expected, test_name="MultiIO array1", rtol=1e-5)
expected = [0.0, -3.0, -6.0, -9.0]
test_pass = assert_real_array_1d(out_data2, expected, test_name="MultiIO array2", rtol=1e-5)
test_pass = assert_allclose(out_data2, expected, test_name="MultiIO array2", rtol=1e-5)

! Cleanup
call torch_delete(model)
Expand Down
4 changes: 2 additions & 2 deletions examples/6_Autograd/autograd.f90
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ program example
use ftorch

! Import our tools module for testing utils
use ftorch_test_utils, only : assert_real_array_2d
use ftorch_test_utils, only : assert_allclose

implicit none

Expand Down Expand Up @@ -53,7 +53,7 @@ program example

! Check output tensor matches expected value
expected(:,:) = in_data
test_pass = assert_real_array_2d(out_data, expected, test_name="torch_tensor_to_array", rtol=1e-5)
test_pass = assert_allclose(out_data, expected, test_name="torch_tensor_to_array", rtol=1e-5)

! Check that the data match
if (.not. test_pass) then
Expand Down
194 changes: 167 additions & 27 deletions src/ftorch_test_utils.f90
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,20 @@

module ftorch_test_utils

use, intrinsic :: iso_fortran_env, only: real32, real64

implicit none

interface assert_real_array
module procedure assert_real_array_1d
module procedure assert_real_array_2d
interface assert_isclose
module procedure assert_isclose_real32
module procedure assert_isclose_real64
end interface

interface assert_allclose
module procedure assert_allclose_real32_1d
module procedure assert_allclose_real32_2d
module procedure assert_allclose_real64_1d
module procedure assert_allclose_real64_2d
end interface

contains
Expand All @@ -33,21 +42,21 @@ subroutine test_print(test_name, message, test_pass)
write(*, '(A, " :: [", A, "] ", A)') report, trim(test_name), trim(message)
end subroutine test_print

!> Asserts that two real values coincide to a given relative tolerance
function assert_real(got, expect, test_name, rtol, print_result) result(test_pass)
!> Asserts that two real32 values coincide to a given relative tolerance
function assert_isclose_real32(got, expect, test_name, rtol, print_result) result(test_pass)

character(len=*), intent(in) :: test_name !! Name of the test being run
real, intent(in) :: got !! The value to be tested
real, intent(in) :: expect !! The expected value
real, intent(in), optional :: rtol !! Optional relative tolerance (defaults to 1e-5)
real(kind=real32), intent(in) :: got !! The value to be tested
real(kind=real32), intent(in) :: expect !! The expected value
real(kind=real32), intent(in), optional :: rtol !! Optional relative tolerance (defaults to 1e-5)
logical, intent(in), optional :: print_result !! Optionally print test result to screen (defaults to .true.)

logical :: test_pass !! Did the assertion pass?

character(len=80) :: message

real :: relative_error
real :: rtol_value
real(kind=real32) :: relative_error
real(kind=real32) :: rtol_value
logical :: print_result_value

if (.not. present(rtol)) then
Expand All @@ -69,23 +78,154 @@ function assert_real(got, expect, test_name, rtol, print_result) result(test_pas
call test_print(test_name, message, test_pass)
end if

end function assert_real
end function assert_isclose_real32

!> Asserts that two real64 values coincide to a given relative tolerance
function assert_isclose_real64(got, expect, test_name, rtol, print_result) result(test_pass)

character(len=*), intent(in) :: test_name !! Name of the test being run
real(kind=real64), intent(in) :: got !! The value to be tested
real(kind=real64), intent(in) :: expect !! The expected value
real(kind=real64), intent(in), optional :: rtol !! Optional relative tolerance (defaults to 1e-5)
logical, intent(in), optional :: print_result !! Optionally print test result to screen (defaults to .true.)

logical :: test_pass !! Did the assertion pass?

character(len=80) :: message

real(kind=real64) :: relative_error
real(kind=real64) :: rtol_value
logical :: print_result_value

if (.not. present(rtol)) then
rtol_value = 1e-5
else
rtol_value = rtol
end if

if (.not. present(print_result)) then
print_result_value = .true.
else
print_result_value = print_result
end if

test_pass = (abs(got - expect) <= rtol_value * abs(expect))

if (print_result_value) then
write(message,'("relative tolerance = ", E11.4)') rtol_value
call test_print(test_name, message, test_pass)
end if

end function assert_isclose_real64


!> Asserts that two real32-valued 1D arrays coincide to a given relative tolerance
function assert_allclose_real32_1d(got, expect, test_name, rtol, print_result) result(test_pass)

character(len=*), intent(in) :: test_name !! Name of the test being run
real(kind=real32), intent(in), dimension(:) :: got !! The array of values to be tested
real(kind=real32), intent(in), dimension(:) :: expect !! The array of expected values
real(kind=real32), intent(in), optional :: rtol !! Optional relative tolerance (defaults to 1e-5)
logical, intent(in), optional :: print_result !! Optionally print test result to screen (defaults to .true.)

logical :: test_pass !! Did the assertion pass?

character(len=80) :: message

real(kind=real32) :: relative_error
real(kind=real32) :: rtol_value
integer :: shape_error
logical :: print_result_value

if (.not. present(rtol)) then
rtol_value = 1.0e-5
else
rtol_value = rtol
end if

if (.not. present(print_result)) then
print_result_value = .true.
else
print_result_value = print_result
end if

! Check the shapes of the arrays match
shape_error = maxval(abs(shape(got) - shape(expect)))
test_pass = (shape_error == 0)

if (test_pass) then
test_pass = all(abs(got - expect) <= rtol_value * abs(expect))
if (print_result_value) then
write(message,'("relative tolerance = ", E11.4)') rtol_value
call test_print(test_name, message, test_pass)
end if
else if (print_result_value) then
call test_print(test_name, "Arrays have mismatching shapes.", test_pass)
endif

end function assert_allclose_real32_1d

!> Asserts that two real32-valued 2D arrays coincide to a given relative tolerance
function assert_allclose_real32_2d(got, expect, test_name, rtol, print_result) result(test_pass)

character(len=*), intent(in) :: test_name !! Name of the test being run
real(kind=real32), intent(in), dimension(:,:) :: got !! The array of values to be tested
real(kind=real32), intent(in), dimension(:,:) :: expect !! The array of expected values
real(kind=real32), intent(in), optional :: rtol !! Optional relative tolerance (defaults to 1e-5)
logical, intent(in), optional :: print_result !! Optionally print test result to screen (defaults to .true.)

logical :: test_pass !! Did the assertion pass?

character(len=80) :: message

real(kind=real32) :: relative_error
real(kind=real32) :: rtol_value
integer :: shape_error
logical :: print_result_value

if (.not. present(rtol)) then
rtol_value = 1.0e-5
else
rtol_value = rtol
end if

if (.not. present(print_result)) then
print_result_value = .true.
else
print_result_value = print_result
end if

! Check the shapes of the arrays match
shape_error = maxval(abs(shape(got) - shape(expect)))
test_pass = (shape_error == 0)

if (test_pass) then
test_pass = all(abs(got - expect) <= rtol_value * abs(expect))
if (print_result_value) then
write(message,'("relative tolerance = ", E11.4)') rtol_value
call test_print(test_name, message, test_pass)
end if
else if (print_result_value) then
call test_print(test_name, "Arrays have mismatching shapes.", test_pass)
endif

end function assert_allclose_real32_2d

!> Asserts that two real-valued 1D arrays coincide to a given relative tolerance
function assert_real_array_1d(got, expect, test_name, rtol, print_result) result(test_pass)
!> Asserts that two real64-valued 1D arrays coincide to a given relative tolerance
function assert_allclose_real64_1d(got, expect, test_name, rtol, print_result) result(test_pass)

character(len=*), intent(in) :: test_name !! Name of the test being run
real, intent(in), dimension(:) :: got !! The array of values to be tested
real, intent(in), dimension(:) :: expect !! The array of expected values
real, intent(in), optional :: rtol !! Optional relative tolerance (defaults to 1e-5)
real(kind=real64), intent(in), dimension(:) :: got !! The array of values to be tested
real(kind=real64), intent(in), dimension(:) :: expect !! The array of expected values
real(kind=real64), intent(in), optional :: rtol !! Optional relative tolerance (defaults to 1e-5)
logical, intent(in), optional :: print_result !! Optionally print test result to screen (defaults to .true.)

logical :: test_pass !! Did the assertion pass?

character(len=80) :: message

real :: relative_error
real :: rtol_value
real(kind=real64) :: relative_error
real(kind=real64) :: rtol_value
integer :: shape_error
logical :: print_result_value

Expand Down Expand Up @@ -115,23 +255,23 @@ function assert_real_array_1d(got, expect, test_name, rtol, print_result) result
call test_print(test_name, "Arrays have mismatching shapes.", test_pass)
endif

end function assert_real_array_1d
end function assert_allclose_real64_1d

!> Asserts that two real-valued 2D arrays coincide to a given relative tolerance
function assert_real_array_2d(got, expect, test_name, rtol, print_result) result(test_pass)
!> Asserts that two real64-valued 2D arrays coincide to a given relative tolerance
function assert_allclose_real64_2d(got, expect, test_name, rtol, print_result) result(test_pass)

character(len=*), intent(in) :: test_name !! Name of the test being run
real, intent(in), dimension(:,:) :: got !! The array of values to be tested
real, intent(in), dimension(:,:) :: expect !! The array of expected values
real, intent(in), optional :: rtol !! Optional relative tolerance (defaults to 1e-5)
real(kind=real64), intent(in), dimension(:,:) :: got !! The array of values to be tested
real(kind=real64), intent(in), dimension(:,:) :: expect !! The array of expected values
real(kind=real64), intent(in), optional :: rtol !! Optional relative tolerance (defaults to 1e-5)
logical, intent(in), optional :: print_result !! Optionally print test result to screen (defaults to .true.)

logical :: test_pass !! Did the assertion pass?

character(len=80) :: message

real :: relative_error
real :: rtol_value
real(kind=real64) :: relative_error
real(kind=real64) :: rtol_value
integer :: shape_error
logical :: print_result_value

Expand Down Expand Up @@ -161,7 +301,7 @@ function assert_real_array_2d(got, expect, test_name, rtol, print_result) result
call test_print(test_name, "Arrays have mismatching shapes.", test_pass)
endif

end function assert_real_array_2d
end function assert_allclose_real64_2d


end module ftorch_test_utils
Loading

0 comments on commit ff77d68

Please sign in to comment.