diff --git a/test/quadpool_op_test.cc b/test/quadpool_op_test.cc new file mode 100644 index 0000000..c43a00b --- /dev/null +++ b/test/quadpool_op_test.cc @@ -0,0 +1,135 @@ +#define BOOST_TEST_DYN_LINK +#define BOOST_TEST_MODULE quadpool + +#include +#include +#include + +#include + +#include + + +using namespace torch_geopooling; + + +template +std::function +exception_contains_text(const std::string error_message) +{ + return [&](const Exception& error) -> bool { + return std::string(error.what()).find(error_message) != std::string::npos; + }; +} + + +BOOST_AUTO_TEST_SUITE(TestQuadPoolOperation) + + +BOOST_AUTO_TEST_CASE(quadpool_op_tiles_errors) +{ + auto op = quadpool_op("test_op", {0.0, 0.0, 1.0, 1.0}, quadtree_options(), /*training=*/true); + + auto tiles = torch::empty({0, 3, 5}, torch::TensorOptions().dtype(torch::kInt64)); + auto weight = torch::rand({0, 1}, torch::TensorOptions().dtype(torch::kFloat64)); + auto input = torch::rand({100, 2}, torch::TensorOptions().dtype(torch::kFloat64)); + + BOOST_CHECK_EXCEPTION( + op.forward(tiles, weight, input), c10::Error, + exception_contains_text("operation only supports 2D tiles") + ); + + tiles = torch::empty({0, 4}, torch::TensorOptions().dtype(torch::kInt64)); + BOOST_CHECK_EXCEPTION( + op.forward(tiles, weight, input), c10::Error, + exception_contains_text("tiles must be three-element tuples") + ); + + tiles = torch::empty({0, 3}, torch::TensorOptions().dtype(torch::kFloat64)); + BOOST_CHECK_EXCEPTION( + op.forward(tiles, weight, input), c10::Error, + exception_contains_text("operation only supports Int64 tiles") + ); +} + + +BOOST_AUTO_TEST_CASE(quadpool_op_weight_errors) +{ + auto op = quadpool_op("test_op", {-.1, -1., 2., 2.}, quadtree_options(), /*training=*/true); + + auto tiles = torch::ones({10, 3}, torch::TensorOptions().dtype(torch::kInt64)); + auto weight = torch::rand({5, 3}, torch::TensorOptions().dtype(torch::kFloat64)); + auto input = torch::rand({10, 2}, torch::TensorOptions().dtype(torch::kFloat64)); + + BOOST_CHECK_EXCEPTION( + op.forward(tiles, weight, input), c10::Error, + exception_contains_text("number of tiles should be the same as weights") + ); + + weight = torch::rand({5, 5, 1}, torch::TensorOptions().dtype(torch::kFloat64)); + BOOST_CHECK_EXCEPTION( + op.forward(tiles, weight, input), c10::Error, + exception_contains_text("operation only supports 2D weight") + ); +} + + +BOOST_AUTO_TEST_CASE(quadpool_op_inputs) +{ + auto op = quadpool_op("test_op", {0.0, 0.0, 20.0, 20.0}, quadtree_options(), /*training=*/true); + + auto tiles = torch::empty({0, 3}, torch::TensorOptions().dtype(torch::kInt64)); + auto weight = torch::rand({0, 1}, torch::TensorOptions().dtype(torch::kFloat64)); + auto input = torch::rand({100, 2}, torch::TensorOptions().dtype(torch::kFloat64)); + + BOOST_CHECK_EXCEPTION( + op.forward(tiles, weight, input - 300.0), value_error, + exception_contains_text("is outside of exterior geometry") + ); + + input = torch::rand({10, 2, 5, 1}, torch::TensorOptions().dtype(torch::kFloat64)); + BOOST_CHECK_EXCEPTION( + op.forward(tiles, weight, input), c10::Error, + exception_contains_text("operation only supports 2D input") + ); + + input = torch::rand({10, 7}, torch::TensorOptions().dtype(torch::kFloat64)); + BOOST_CHECK_EXCEPTION( + op.forward(tiles, weight, input), c10::Error, + exception_contains_text("input must be two-element tuples") + ); + + input = torch::empty({10, 2}, torch::TensorOptions().dtype(torch::kInt32)); + BOOST_CHECK_EXCEPTION( + op.forward(tiles, weight, input), c10::Error, + exception_contains_text("operation only supports Float64 input") + ); +} + + +BOOST_AUTO_TEST_CASE(quadpool_op_parentless_tiles) +{ + auto op = quadpool_op("test_op", {0.0, 0.0, 20.0, 20.0}, quadtree_options(), /*training=*/true); + + auto tiles = torch::tensor( + { + {1, 1, 1}, + {9, 0, 0}, + {1, 0, 0}, + {8, 0, 0}, + {4, 0, 0}, + }, + torch::TensorOptions().dtype(torch::kInt64) + ); + + auto weight = torch::rand({tiles.size(0), 1}, torch::TensorOptions().dtype(torch::kFloat64)); + auto input = torch::rand({10, 2}, torch::TensorOptions().dtype(torch::kFloat64)); + + BOOST_CHECK_EXCEPTION( + op.forward(tiles, weight, input), value_error, + exception_contains_text("does not have a parent") + ); +} + + +BOOST_AUTO_TEST_SUITE_END() diff --git a/test/tile_test.cc b/test/tile_test.cc new file mode 100644 index 0000000..f254f09 --- /dev/null +++ b/test/tile_test.cc @@ -0,0 +1,37 @@ +#define BOOST_TEST_DYN_LINK +#define BOOST_TEST_MODULE quadtree + +#include + +#include + + +using namespace torch_geopooling; + + +BOOST_AUTO_TEST_SUITE(TestTile) + + +BOOST_AUTO_TEST_CASE(tile_constructor) +{ + BOOST_CHECK_THROW(Tile(1024, 0, 0), value_error); + BOOST_CHECK_THROW(Tile(2, 1024, 0), value_error); + BOOST_CHECK_THROW(Tile(4, 0, 1024), value_error); +} + + +BOOST_AUTO_TEST_CASE(tile_children) +{ + auto parent = Tile(62, 10, 10); + auto children = parent.children(); + + BOOST_REQUIRE_EQUAL(children.size(), 4); + + BOOST_CHECK_EQUAL(children[0], Tile(63, 20, 20)); + BOOST_CHECK_EQUAL(children[1], Tile(63, 20, 21)); + BOOST_CHECK_EQUAL(children[2], Tile(63, 21, 20)); + BOOST_CHECK_EQUAL(children[3], Tile(63, 21, 21)); +} + + +BOOST_AUTO_TEST_SUITE_END()