Skip to content

Commit

Permalink
Fix bug making TaskBasedCpuContractor different from TensorNetwork (#6)
Browse files Browse the repository at this point in the history
* Fix bug

* Update tests for TN fixed ordering

* Fix casting issue

* Transpose matrix A in vector-matrix GEMV binding

Co-authored-by: Trevor Vincent <[email protected]>
Co-authored-by: Lee J. O'Riordan <[email protected]>
Co-authored-by: Lee James O'Riordan <[email protected]>
Co-authored-by: Mikhail Andrenkov <[email protected]>
  • Loading branch information
5 people authored May 17, 2021
1 parent 10af598 commit 5949562
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 42 deletions.
11 changes: 6 additions & 5 deletions include/jet/TensorHelpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,18 +72,19 @@ gemmBinding(size_t m, size_t n, size_t k, ComplexPrecision alpha,
* @param A_data Complex data matrix A
* @param B_data Complex data vector B
* @param C_data Output data vector
* @param transpose Transpose flag for matrix A
*/
template <typename ComplexPrecision>
constexpr void
gemvBinding(size_t m, size_t k, ComplexPrecision alpha, ComplexPrecision beta,
const ComplexPrecision *A_data, const ComplexPrecision *B_data,
ComplexPrecision *C_data)
ComplexPrecision *C_data, const CBLAS_TRANSPOSE &transpose)
{
if constexpr (std::is_same_v<ComplexPrecision, std::complex<float>>)
cblas_cgemv(CblasRowMajor, CblasNoTrans, m, k, (&alpha), (A_data),
cblas_cgemv(CblasRowMajor, transpose, m, k, (&alpha), (A_data),
std::max(1ul, k), (B_data), 1, (&beta), (C_data), 1);
else if constexpr (std::is_same_v<ComplexPrecision, std::complex<double>>)
cblas_zgemv(CblasRowMajor, CblasNoTrans, m, k, (&alpha), (A_data),
cblas_zgemv(CblasRowMajor, transpose, m, k, (&alpha), (A_data),
std::max(1ul, k), (B_data), 1, (&beta), (C_data), 1);
};

Expand Down Expand Up @@ -153,12 +154,12 @@ inline void MultiplyTensorData(const std::vector<ComplexPrecision> &A,
else if (left_indices.size() > 0 && right_indices.size() == 0) {
size_t m = left_dim;
size_t k = common_dim;
gemvBinding(m, k, alpha, beta, A_data, B_data, C_data);
gemvBinding(m, k, alpha, beta, A_data, B_data, C_data, CblasNoTrans);
}
else if (left_indices.size() == 0 && right_indices.size() > 0) {
size_t n = right_dim;
size_t k = common_dim;
gemvBinding(k, n, alpha, beta, B_data, A_data, C_data);
gemvBinding(k, n, alpha, beta, B_data, A_data, C_data, CblasTrans);
}
else if (left_indices.size() == 0 && right_indices.size() == 0) {
size_t k = common_dim;
Expand Down
7 changes: 0 additions & 7 deletions include/jet/TensorNetwork.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,13 +388,6 @@ template <class Tensor> class TensorNetwork {
*/
size_t ContractNodes_(size_t node_id_1, size_t node_id_2) noexcept
{
// Make sure node 1 has at least as many indices as node 2.
const size_t node_1_size = nodes_[node_id_1].tensor.GetIndices().size();
const size_t node_2_size = nodes_[node_id_2].tensor.GetIndices().size();
if (node_1_size <= node_2_size) {
std::swap(node_id_1, node_id_2);
}

auto &node_1 = nodes_[node_id_1];
auto &node_2 = nodes_[node_id_2];
const auto tensor_3 = ContractTensors(node_1.tensor, node_2.tensor);
Expand Down
40 changes: 32 additions & 8 deletions test/Test_Tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -460,22 +460,22 @@ TEMPLATE_TEST_CASE("ContractTensors", "[Tensor]", c_fp32, c_fp64)
Approx(expected_rji_si.GetData()[1].imag()));

CHECK(con_si_rij.GetData()[0].real() ==
Approx(expected_rji_si.GetData()[0].real()));
Approx(expected_rij_si.GetData()[0].real()));
CHECK(con_si_rij.GetData()[0].imag() ==
Approx(expected_rji_si.GetData()[0].imag()));
Approx(expected_rij_si.GetData()[0].imag()));
CHECK(con_si_rij.GetData()[1].real() ==
Approx(expected_rji_si.GetData()[1].real()));
Approx(expected_rij_si.GetData()[1].real()));
CHECK(con_si_rij.GetData()[1].imag() ==
Approx(expected_rji_si.GetData()[1].imag()));
Approx(expected_rij_si.GetData()[1].imag()));

CHECK(con_si_rji.GetData()[0].real() ==
Approx(expected_rij_si.GetData()[0].real()));
Approx(expected_rji_si.GetData()[0].real()));
CHECK(con_si_rji.GetData()[0].imag() ==
Approx(expected_rij_si.GetData()[0].imag()));
Approx(expected_rji_si.GetData()[0].imag()));
CHECK(con_si_rji.GetData()[1].real() ==
Approx(expected_rij_si.GetData()[1].real()));
Approx(expected_rji_si.GetData()[1].real()));
CHECK(con_si_rji.GetData()[1].imag() ==
Approx(expected_rij_si.GetData()[1].imag()));
Approx(expected_rji_si.GetData()[1].imag()));
}

SECTION("Contract T0(a,b) and T1(b) -> T2(a)")
Expand All @@ -502,6 +502,30 @@ TEMPLATE_TEST_CASE("ContractTensors", "[Tensor]", c_fp32, c_fp64)
CHECK(tensor3 == tensor4);
}

SECTION("Contract T0(a) and T1(a,b) -> T2(b)")
{
std::vector<std::size_t> t_shape1{2};
std::vector<std::size_t> t_shape2{2, 2};
std::vector<std::size_t> t_shape3{2};

std::vector<std::string> t_indices1{"a"};
std::vector<std::string> t_indices2{"a", "b"};

std::vector<TestType> t_data1{TestType(0.0, 0.0), TestType(1.0, 0.0)};
std::vector<TestType> t_data2{TestType(0.0, 0.0), TestType(1.0, 0.0),
TestType(2.0, 0.0), TestType(3.0, 0.0)};
std::vector<TestType> t_data_expect{TestType(2.0, 0.0),
TestType(3.0, 0.0)};

Tensor<TestType> tensor1(t_indices1, t_shape1, t_data1);
Tensor<TestType> tensor2(t_indices2, t_shape2, t_data2);

Tensor<TestType> tensor3 = ContractTensors(tensor1, tensor2);
Tensor<TestType> tensor4({"b"}, {2}, t_data_expect);

CHECK(tensor3 == tensor4);
}

SECTION("Contract T0(a,b,c) and T1(b,c,d) -> T2(a,d)")
{
std::vector<std::size_t> t_shape1{2, 3, 4};
Expand Down
52 changes: 30 additions & 22 deletions test/Test_TensorNetwork.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ tensor_t MakeTensor(const indices_t &indices, const shape_t &shape)
if (!shape.empty()) {
for (size_t i = 0; i < tensor.GetSize(); i++) {
const auto index = Jet::Utilities::UnravelIndex(i, shape);
tensor.SetValue(index, i);
tensor.SetValue(index, complex_t{static_cast<float>(i),
static_cast<float>(2 * i)});
}
}
return tensor;
Expand Down Expand Up @@ -208,9 +209,11 @@ TEST_CASE("TensorNetwork::SliceIndices", "[TensorNetwork]")
CHECK(have_tensor_shape == want_tensor_shape);

const data_t have_tensor_data = node.tensor.GetData();
const data_t want_tensor_data = {0, 1, 2, 3, 4, 5, 6, 7,
8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23};
const data_t want_tensor_data = {
{0, 0}, {1, 2}, {2, 4}, {3, 6}, {4, 8}, {5, 10},
{6, 12}, {7, 14}, {8, 16}, {9, 18}, {10, 20}, {11, 22},
{12, 24}, {13, 26}, {14, 28}, {15, 30}, {16, 32}, {17, 34},
{18, 36}, {19, 38}, {20, 40}, {21, 42}, {22, 44}, {23, 46}};
CHECK(have_tensor_data == want_tensor_data);
}

Expand All @@ -236,7 +239,9 @@ TEST_CASE("TensorNetwork::SliceIndices", "[TensorNetwork]")
CHECK(have_tensor_shape == want_tensor_shape);

const data_t have_tensor_data = node.tensor.GetData();
const data_t want_tensor_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
const data_t want_tensor_data = {{0, 0}, {1, 2}, {2, 4}, {3, 6},
{4, 8}, {5, 10}, {6, 12}, {7, 14},
{8, 16}, {9, 18}, {10, 20}, {11, 22}};
CHECK(have_tensor_data == want_tensor_data);
}

Expand All @@ -262,8 +267,9 @@ TEST_CASE("TensorNetwork::SliceIndices", "[TensorNetwork]")
CHECK(have_tensor_shape == want_tensor_shape);

const data_t have_tensor_data = node.tensor.GetData();
const data_t want_tensor_data = {12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23};
const data_t want_tensor_data = {
{12, 24}, {13, 26}, {14, 28}, {15, 30}, {16, 32}, {17, 34},
{18, 36}, {19, 38}, {20, 40}, {21, 42}, {22, 44}, {23, 46}};
CHECK(have_tensor_data == want_tensor_data);
}

Expand All @@ -289,7 +295,7 @@ TEST_CASE("TensorNetwork::SliceIndices", "[TensorNetwork]")
CHECK(have_tensor_shape == want_tensor_shape);

const data_t have_tensor_data = node.tensor.GetData();
const data_t want_tensor_data = {14, 18, 22};
const data_t want_tensor_data = {{14, 28}, {18, 36}, {22, 44}};
CHECK(have_tensor_data == want_tensor_data);
}

Expand All @@ -315,7 +321,7 @@ TEST_CASE("TensorNetwork::SliceIndices", "[TensorNetwork]")
CHECK(have_tensor_shape == want_tensor_shape);

const data_t have_tensor_data = node.tensor.GetData();
const data_t want_tensor_data = {23};
const data_t want_tensor_data = {{23, 46}};
CHECK(have_tensor_data == want_tensor_data);
}
}
Expand All @@ -336,7 +342,7 @@ TEST_CASE("TensorNetwork::Contract", "[TensorNetwork]")
CHECK(have_path == want_path);

const data_t have_tensor_data = result.GetData();
const data_t want_tensor_data = {0, 1};
const data_t want_tensor_data = {{0, 0}, {1, 2}};
CHECK(have_tensor_data == want_tensor_data);

const auto &nodes = tn.GetNodes();
Expand Down Expand Up @@ -365,7 +371,7 @@ TEST_CASE("TensorNetwork::Contract", "[TensorNetwork]")
CHECK(have_path == want_path);

const data_t have_tensor_data = {result.GetValue({})};
const data_t want_tensor_data = {5};
const data_t want_tensor_data = {{-15, 20}};
CHECK(have_tensor_data == want_tensor_data);

const auto &nodes = tn.GetNodes();
Expand Down Expand Up @@ -402,7 +408,7 @@ TEST_CASE("TensorNetwork::Contract", "[TensorNetwork]")
CHECK(have_path == want_path);

const data_t have_tensor_data = result.GetData();
const data_t want_tensor_data = {10, 13};
const data_t want_tensor_data = {{-30, 40}, {-39, 52}};
CHECK(have_tensor_data == want_tensor_data);

const auto &nodes = tn.GetNodes();
Expand Down Expand Up @@ -439,7 +445,7 @@ TEST_CASE("TensorNetwork::Contract", "[TensorNetwork]")
CHECK(have_path == want_path);

const data_t have_tensor_data = {result.GetValue({})};
const data_t want_tensor_data = {55};
const data_t want_tensor_data = {{-165, 220}};
CHECK(have_tensor_data == want_tensor_data);

const auto &nodes = tn.GetNodes();
Expand Down Expand Up @@ -476,11 +482,12 @@ TEST_CASE("TensorNetwork::Contract", "[TensorNetwork]")
CHECK(have_path == want_path);

const shape_t have_tensor_shape = result.GetShape();
const shape_t want_tensor_shape = {3, 2};
const shape_t want_tensor_shape = {2, 3};
REQUIRE(have_tensor_shape == want_tensor_shape);

const data_t have_tensor_data = result.GetData();
const data_t want_tensor_data = {5, 14, 14, 50, 23, 86};
const data_t want_tensor_data = {{-15, 20}, {-42, 56}, {-69, 92},
{-42, 56}, {-150, 200}, {-258, 344}};
CHECK(have_tensor_data == want_tensor_data);

const auto &nodes = tn.GetNodes();
Expand All @@ -492,8 +499,8 @@ TEST_CASE("TensorNetwork::Contract", "[TensorNetwork]")
{
const auto &node = nodes[2];
CHECK(node.id == 2);
CHECK(node.name == "C2A0");
CHECK(node.indices == indices_t{"C2", "A0"});
CHECK(node.name == "A0C2");
CHECK(node.indices == indices_t{"A0", "C2"});
CHECK(node.contracted == false);
}

Expand Down Expand Up @@ -524,7 +531,8 @@ TEST_CASE("TensorNetwork::Contract", "[TensorNetwork]")
REQUIRE(have_tensor_shape == want_tensor_shape);

const data_t have_tensor_data = result.GetData();
const data_t want_tensor_data = {28, 100, 47, 164};
const data_t want_tensor_data = {
{-308, -56}, {-517, -94}, {-1100, -200}, {-1804, -328}};
CHECK(have_tensor_data == want_tensor_data);

const auto &nodes = tn.GetNodes();
Expand All @@ -537,15 +545,15 @@ TEST_CASE("TensorNetwork::Contract", "[TensorNetwork]")
{
const auto &node = nodes[3];
CHECK(node.id == 3);
CHECK(node.name == "D3B1");
CHECK(node.indices == indices_t{"D3", "B1"});
CHECK(node.name == "B1D3");
CHECK(node.indices == indices_t{"B1", "D3"});
CHECK(node.contracted == true);
}
{
const auto &node = nodes[4];
CHECK(node.id == 4);
CHECK(node.name == "D3A0");
CHECK(node.indices == indices_t{"D3", "A0"});
CHECK(node.name == "A0D3");
CHECK(node.indices == indices_t{"A0", "D3"});
CHECK(node.contracted == false);
}

Expand Down

0 comments on commit 5949562

Please sign in to comment.