Skip to content

Commit

Permalink
tests: Fix compiling and running tests by using AUnit's Test_Case
Browse files Browse the repository at this point in the history
  • Loading branch information
onox committed Apr 13, 2024
1 parent b64a874 commit c13af68
Show file tree
Hide file tree
Showing 17 changed files with 343 additions and 351 deletions.
2 changes: 1 addition & 1 deletion tests/gnat.adc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pragma Restrictions (No_Reentrancy);

pragma Restrictions (No_Obsolescent_Features);
pragma Restrictions (No_Implementation_Aspect_Specifications);
pragma Restrictions (No_Implementation_Attributes);
-- pragma Restrictions (No_Implementation_Attributes);
pragma Restrictions (No_Implementation_Identifiers);
-- pragma Restrictions (No_Implementation_Pragmas);
pragma Restrictions (No_Implementation_Units);
Expand Down
148 changes: 60 additions & 88 deletions tests/src/generic_test_tensors_matrices.adb
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,15 @@ with Orka.Resources.Locations.Directories;

package body Generic_Test_Tensors_Matrices is

overriding
function Name (Object : Test_Case) return AUnit.Test_String is (AUnit.Format ("(Tensors - " & Suite_Name & " - Matrices)"));

use Tensors;
use type Tensors.Element;

subtype CPU_Tensor is Tensor_Type;
subtype CPU_QR_Factorization is QR_Factorization_Type;
subtype Test is AUnit.Test_Cases.Test_Case'Class;

use AUnit.Assertions;

Expand Down Expand Up @@ -850,104 +854,72 @@ package body Generic_Test_Tensors_Matrices is

----------------------------------------------------------------------------

package Caller is new AUnit.Test_Caller (Test);

Test_Suite : aliased AUnit.Test_Suites.Test_Suite;

function Suite return AUnit.Test_Suites.Access_Test_Suite is
Name : constant String := "(Tensors - " & Suite_Name & " - Matrices) ";
begin
Test_Suite.Add_Test (Caller.Create
(Name & "Test function Flatten", Test_Flatten'Access));
Test_Suite.Add_Test (Caller.Create
(Name & "Test function Identity (square)", Test_Identity_Square'Access));
Test_Suite.Add_Test (Caller.Create
(Name & "Test function Identity (not square)", Test_Identity_Not_Square'Access));
Test_Suite.Add_Test (Caller.Create
(Name & "Test function Reshape", Test_Reshape'Access));
Test_Suite.Add_Test (Caller.Create
(Name & "Test function Concatenate", Test_Concatenate'Access));

Test_Suite.Add_Test (Caller.Create
(Name & "Test function Main_Diagonal", Test_Main_Diagonal'Access));
Test_Suite.Add_Test (Caller.Create
(Name & "Test function Diagonal", Test_Diagonal'Access));
Test_Suite.Add_Test (Caller.Create
(Name & "Test function Trace", Test_Trace'Access));

Test_Suite.Add_Test (Caller.Create
(Name & "Test indexing row using index", Test_Constant_Indexing_Index_Row'Access));
Test_Suite.Add_Test (Caller.Create
(Name & "Test indexing value using index", Test_Constant_Indexing_Index_Value'Access));
Test_Suite.Add_Test (Caller.Create
(Name & "Test indexing value using index (boolean)",
Test_Constant_Indexing_Index_Boolean'Access));
Test_Suite.Add_Test (Caller.Create
(Name & "Test indexing using range", Test_Constant_Indexing_Range'Access));
Test_Suite.Add_Test (Caller.Create
(Name & "Test indexing using tensor", Test_Constant_Indexing_Tensor'Access));

Test_Suite.Add_Test (Caller.Create
(Name & "Test set row using index", Test_Set_Value_Index_Row'Access));
Test_Suite.Add_Test (Caller.Create
(Name & "Test set row using range", Test_Set_Value_Index_Range'Access));

Test_Suite.Add_Test (Caller.Create
(Name & "Test '*' operator (inner product)", Test_Operator_Multiply_Inner'Access));
Test_Suite.Add_Test (Caller.Create
(Name & "Test '**'", Test_Operator_Power'Access));
Test_Suite.Add_Test (Caller.Create
(Name & "Test function Outer (outer product)", Test_Outer'Access));
Test_Suite.Add_Test (Caller.Create
(Name & "Test function Inverse (invertible)", Test_Inverse_Invertible'Access));
Test_Suite.Add_Test (Caller.Create
(Name & "Test function Inverse (singular)", Test_Inverse_Singular'Access));
Test_Suite.Add_Test (Caller.Create
(Name & "Test function Transpose", Test_Transpose'Access));
Test_Suite.Add_Test (Caller.Create
(Name & "Test function Solve", Test_Solve'Access));
Test_Suite.Add_Test (Caller.Create
(Name & "Test function Solve (triangular)", Test_Solve_Triangular'Access));
Test_Suite.Add_Test (Caller.Create
(Name & "Test function Divide_By", Test_Divide_By'Access));
Test_Suite.Add_Test (Caller.Create
(Name & "Test function Upper_Triangular", Test_Upper_Triangular'Access));

Test_Suite.Add_Test (Caller.Create
(Name & "Test function QR", Test_QR'Access));
Test_Suite.Add_Test (Caller.Create
(Name & "Test function Cholesky", Test_Cholesky'Access));
Test_Suite.Add_Test (Caller.Create
(Name & "Test function Cholesky_Update (downdate)", Test_Cholesky_Downdate'Access));
Test_Suite.Add_Test (Caller.Create
(Name & "Test function Least_Squares (shapes)", Test_Shapes_Least_Squares'Access));
Test_Suite.Add_Test (Caller.Create
(Name & "Test function Least_Squares (values)", Test_Values_Least_Squares'Access));
Test_Suite.Add_Test (Caller.Create
(Name & "Test function Constrained_Least_Squares", Test_Constrained_Least_Squares'Access));
overriding
procedure Register_Tests (Object : in out Test_Case) is
procedure Register_Routine (Name : String; Pointer : AUnit.Test_Cases.Test_Routine) is
begin
AUnit.Test_Cases.Registration.Register_Routine (Object, Pointer, Name);
end Register_Routine;
begin
Register_Routine ("Test function Flatten", Test_Flatten'Unrestricted_Access);
Register_Routine ("Test function Identity (square)", Test_Identity_Square'Unrestricted_Access);
Register_Routine ("Test function Identity (not square)", Test_Identity_Not_Square'Unrestricted_Access);
Register_Routine ("Test function Reshape", Test_Reshape'Unrestricted_Access);
Register_Routine ("Test function Concatenate", Test_Concatenate'Unrestricted_Access);

Register_Routine ("Test function Main_Diagonal", Test_Main_Diagonal'Unrestricted_Access);
Register_Routine ("Test function Diagonal", Test_Diagonal'Unrestricted_Access);
Register_Routine ("Test function Trace", Test_Trace'Unrestricted_Access);

Register_Routine ("Test indexing row using index", Test_Constant_Indexing_Index_Row'Unrestricted_Access);
Register_Routine ("Test indexing value using index", Test_Constant_Indexing_Index_Value'Unrestricted_Access);
Register_Routine ("Test indexing value using index (boolean)",
Test_Constant_Indexing_Index_Boolean'Unrestricted_Access);
Register_Routine ("Test indexing using range", Test_Constant_Indexing_Range'Unrestricted_Access);
Register_Routine ("Test indexing using tensor", Test_Constant_Indexing_Tensor'Unrestricted_Access);

Register_Routine ("Test set row using index", Test_Set_Value_Index_Row'Unrestricted_Access);
Register_Routine ("Test set row using range", Test_Set_Value_Index_Range'Unrestricted_Access);

Register_Routine ("Test '*' operator (inner product)", Test_Operator_Multiply_Inner'Unrestricted_Access);
Register_Routine ("Test '**'", Test_Operator_Power'Unrestricted_Access);
Register_Routine ("Test function Outer (outer product)", Test_Outer'Unrestricted_Access);
Register_Routine ("Test function Inverse (invertible)", Test_Inverse_Invertible'Unrestricted_Access);
Register_Routine ("Test function Inverse (singular)", Test_Inverse_Singular'Unrestricted_Access);
Register_Routine ("Test function Transpose", Test_Transpose'Unrestricted_Access);
Register_Routine ("Test function Solve", Test_Solve'Unrestricted_Access);
Register_Routine ("Test function Solve (triangular)", Test_Solve_Triangular'Unrestricted_Access);
Register_Routine ("Test function Divide_By", Test_Divide_By'Unrestricted_Access);
Register_Routine ("Test function Upper_Triangular", Test_Upper_Triangular'Unrestricted_Access);

Register_Routine ("Test function QR", Test_QR'Unrestricted_Access);
Register_Routine ("Test function Cholesky", Test_Cholesky'Unrestricted_Access);
Register_Routine ("Test function Cholesky_Update (downdate)", Test_Cholesky_Downdate'Unrestricted_Access);
Register_Routine ("Test function Least_Squares (shapes)", Test_Shapes_Least_Squares'Unrestricted_Access);
Register_Routine ("Test function Least_Squares (values)", Test_Values_Least_Squares'Unrestricted_Access);
Register_Routine ("Test function Constrained_Least_Squares", Test_Constrained_Least_Squares'Unrestricted_Access);

-- TODO Statistics: Min, Max, Quantile, Median, Mean, Variance (with Axis parameter)

Test_Suite.Add_Test (Caller.Create
(Name & "Test function Any_True", Test_Any_True'Access));
Test_Suite.Add_Test (Caller.Create
(Name & "Test function All_True", Test_All_True'Access));
Register_Routine ("Test function Any_True", Test_Any_True'Unrestricted_Access);
Register_Routine ("Test function All_True", Test_All_True'Unrestricted_Access);

-- Expressions
Test_Suite.Add_Test (Caller.Create
(Name & "Test reduction binary operator", Test_Reduction_Binary_Operator'Access));
Test_Suite.Add_Test (Caller.Create
(Name & "Test reduction number", Test_Reduction_Number'Access));
Register_Routine ("Test reduction binary operator", Test_Reduction_Binary_Operator'Unrestricted_Access);
Register_Routine ("Test reduction number", Test_Reduction_Number'Unrestricted_Access);

-- TODO Cumulative, Reduce (with Axis parameter)
end Register_Tests;

Test_Suite : aliased AUnit.Test_Suites.Test_Suite;

Test_1 : aliased Test_Case;

function Suite return AUnit.Test_Suites.Access_Test_Suite is
begin
Test_Suite.Add_Test (Test_1'Access);
return Test_Suite'Access;
end Suite;

use Orka.Resources.Locations.Directories;
begin
Initialize_Shaders
(Prefix_Sum => Create_Location ("../orka/data/shaders"),
Tensors_GPU => Create_Location ("../orka_tensors_gpu/data/shaders"));
Orka.Logging.Set_Logger (Orka.Loggers.Terminal.Create_Logger (Level => Orka.Loggers.Info));
end Generic_Test_Tensors_Matrices;
18 changes: 12 additions & 6 deletions tests/src/generic_test_tensors_matrices.ads
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,37 @@
-- See the License for the specific language governing permissions and
-- limitations under the License.

with AUnit.Test_Fixtures;
with AUnit.Test_Cases;
with AUnit.Test_Suites;

with Orka.Numerics.Tensors;
with Orka.Resources.Locations;

generic
Suite_Name : String;

type Test is new AUnit.Test_Fixtures.Test_Fixture with private;
type Abstract_Test_Case is abstract new AUnit.Test_Cases.Test_Case with private;

with package Tensors is new Orka.Numerics.Tensors (<>);

type Tensor_Type (<>) is new Tensors.Tensor with private;

type QR_Factorization_Type (<>) is new Tensors.QR_Factorization with private;

with procedure Initialize_Shaders
(Prefix_Sum, Tensors_GPU : Orka.Resources.Locations.Location_Ptr) is null;

with function Q (Object : QR_Factorization_Type'Class) return Tensor_Type is <>;
with function R (Object : QR_Factorization_Type'Class) return Tensor_Type is <>;

package Generic_Test_Tensors_Matrices is

function Suite return AUnit.Test_Suites.Access_Test_Suite;

private

type Test_Case is new Abstract_Test_Case with null record;

overriding
procedure Register_Tests (Object : in out Test_Case);

overriding
function Name (Object : Test_Case) return AUnit.Test_String;

end Generic_Test_Tensors_Matrices;
Loading

0 comments on commit c13af68

Please sign in to comment.