Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mem problems fix #36

Merged
merged 4 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 66 additions & 55 deletions cpp_test/TestLsd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@ using namespace ldpc::lsd;
using namespace ldpc::sparse_matrix_util;


TEST(LsdCluster, init1) {
TEST(LsdCluster, init1){

auto pcm = ldpc::gf2codes::ring_code<ldpc::bp::BpEntry>(10);
auto gbm = new ldpc::lsd::LsdCluster *[pcm.n](); //global bit dictionary
auto gcm = new ldpc::lsd::LsdCluster *[pcm.m](); //global check dictionary

// auto gbm = new ldpc::lsd::LsdCluster *[pcm.n](); //global bit dictionary
// auto gcm = new ldpc::lsd::LsdCluster *[pcm.m](); //global check dictionary
auto gbm = std::make_shared<std::vector<LsdCluster*>>(std::vector<LsdCluster*>(pcm.n));
auto gcm = std::make_shared<std::vector<LsdCluster*>>(std::vector<LsdCluster*>(pcm.m));
auto syndrome_index = 0;
auto cl = ldpc::lsd::LsdCluster(pcm, syndrome_index, gcm, gbm);

ASSERT_TRUE(cl.active);
ASSERT_FALSE(cl.valid);

Expand All @@ -35,25 +36,26 @@ TEST(LsdCluster, init1) {
ASSERT_EQ(expected_check_nodes, cl.check_nodes);
ASSERT_EQ(expected_boundary_check_nodes, cl.boundary_check_nodes);
ASSERT_EQ(expected_enclosed_syndromes, cl.enclosed_syndromes);
ASSERT_EQ(gcm[syndrome_index], &cl);
ASSERT_EQ(gcm->at(syndrome_index), &cl);
ASSERT_EQ(expected_cluster_check_idx_to_pcm_check_idx, cl.cluster_check_idx_to_pcm_check_idx);
ASSERT_EQ(expected_pcm_check_idx_to_cluster_check_idx, cl.pcm_check_idx_to_cluster_check_idx);

delete gbm;
delete gcm;
// delete gbm;
// delete gcm;

}


TEST(LsdCluster, add_bitANDadd_check_add) {
TEST(LsdCluster, add_bitANDadd_check_add){

auto pcm = ldpc::gf2codes::ring_code<ldpc::bp::BpEntry>(10);
auto gbm = new ldpc::lsd::LsdCluster *[pcm.n]; //global bit dictionary
auto gcm = new ldpc::lsd::LsdCluster *[pcm.m]; //global check dictionary

// auto gbm = new ldpc::lsd::LsdCluster *[pcm.n]; //global bit dictionary
// auto gcm = new ldpc::lsd::LsdCluster *[pcm.m]; //global check dictionary
auto gbm = std::make_shared<std::vector<LsdCluster*>>(std::vector<LsdCluster*>(pcm.n));
auto gcm = std::make_shared<std::vector<LsdCluster*>>(std::vector<LsdCluster*>(pcm.m));
auto syndrome_index = 1;
auto cl = ldpc::lsd::LsdCluster(pcm, syndrome_index, gcm, gbm);

cl.compute_growth_candidate_bit_nodes();
auto expected_candidate_bit_nodes = tsl::robin_set<int>{1, 2};
ASSERT_EQ(expected_candidate_bit_nodes, cl.candidate_bit_nodes);
Expand All @@ -70,27 +72,27 @@ TEST(LsdCluster, add_bitANDadd_check_add) {
{2, 1}};

ASSERT_EQ(expected_bit_nodes, cl.bit_nodes);
ASSERT_EQ(cl.global_bit_membership[2], &cl);
ASSERT_EQ(cl.global_bit_membership.get()->at(2), &cl);
ASSERT_EQ(cl.cluster_bit_idx_to_pcm_bit_idx[0], 2);
ASSERT_EQ(expected_check_nodes, cl.check_nodes);
ASSERT_EQ(expected_cluster_check_idx_to_pcm_check_idx, cl.cluster_check_idx_to_pcm_check_idx);
ASSERT_EQ(expected_pcm_check_idx_to_cluster_check_idx, cl.pcm_check_idx_to_cluster_check_idx);
ASSERT_EQ(cl.global_check_membership[1], &cl);
ASSERT_EQ(cl.global_check_membership[2], &cl);
ASSERT_EQ(cl.global_check_membership.get()->at(1), &cl);
ASSERT_EQ(cl.global_check_membership.get()->at(2), &cl);


// Test adding existing checks and bits
cl.add_bit(*(++expected_candidate_bit_nodes.begin()));
cl.add_check(2, true);

ASSERT_EQ(expected_bit_nodes, cl.bit_nodes);
ASSERT_EQ(cl.global_bit_membership[2], &cl);
ASSERT_EQ(cl.global_bit_membership.get()->at(2), &cl);
ASSERT_EQ(cl.cluster_bit_idx_to_pcm_bit_idx[0], 2);
ASSERT_EQ(expected_check_nodes, cl.check_nodes);
ASSERT_EQ(expected_cluster_check_idx_to_pcm_check_idx, cl.cluster_check_idx_to_pcm_check_idx);
ASSERT_EQ(expected_pcm_check_idx_to_cluster_check_idx, cl.pcm_check_idx_to_cluster_check_idx);
ASSERT_EQ(cl.global_check_membership[1], &cl);
ASSERT_EQ(cl.global_check_membership[2], &cl);
ASSERT_EQ(cl.global_check_membership.get()->at(1), &cl);
ASSERT_EQ(cl.global_check_membership.get()->at(2), &cl);

//check that bit is remove from boundary check node is removed from boundary check nodes
cl.compute_growth_candidate_bit_nodes();
Expand All @@ -103,18 +105,19 @@ TEST(LsdCluster, add_bitANDadd_check_add) {
auto expected_boundary_check_nodes = tsl::robin_set<int>{1};
ASSERT_EQ(expected_boundary_check_nodes, cl.boundary_check_nodes);

delete gbm;
delete gcm;
// delete gbm;
// delete gcm;

}

TEST(LsdCluster, add_bit_node_to_cluster) {
TEST(LsdCluster, add_bit_node_to_cluster){


auto pcm = ldpc::gf2codes::ring_code<ldpc::bp::BpEntry>(10);
auto gbm = new ldpc::lsd::LsdCluster *[pcm.n](); //global bit dictionary
auto gcm = new ldpc::lsd::LsdCluster *[pcm.m](); //global check dictionary

// auto gbm = new ldpc::lsd::LsdCluster *[pcm.n](); //global bit dictionary
// auto gcm = new ldpc::lsd::LsdCluster *[pcm.m](); //global check dictionary
auto gbm = std::make_shared<std::vector<LsdCluster*>>(std::vector<LsdCluster*>(pcm.n));
auto gcm = std::make_shared<std::vector<LsdCluster*>>(std::vector<LsdCluster*>(pcm.m));
auto syndrome_index = 1;
auto cl = ldpc::lsd::LsdCluster(pcm, syndrome_index, gcm, gbm);

Expand All @@ -123,7 +126,7 @@ TEST(LsdCluster, add_bit_node_to_cluster) {
ASSERT_EQ(expected_candidate_bit_nodes, cl.candidate_bit_nodes);


auto bit_membership = cl.global_bit_membership[0];
auto bit_membership = cl.global_bit_membership.get()->at(0);

// add bit 2 to the cluster
cl.add_bit_node_to_cluster(2);
Expand All @@ -135,16 +138,16 @@ TEST(LsdCluster, add_bit_node_to_cluster) {
{2, 1}};

ASSERT_EQ(expected_bit_nodes, cl.bit_nodes);
ASSERT_EQ(cl.global_bit_membership[2], &cl);
ASSERT_EQ(cl.global_bit_membership.get()->at(2), &cl);
ASSERT_EQ(cl.cluster_bit_idx_to_pcm_bit_idx[0], 2);
ASSERT_EQ(expected_check_nodes, cl.check_nodes);
ASSERT_EQ(expected_cluster_check_idx_to_pcm_check_idx, cl.cluster_check_idx_to_pcm_check_idx);
ASSERT_EQ(expected_pcm_check_idx_to_cluster_check_idx, cl.pcm_check_idx_to_cluster_check_idx);
ASSERT_EQ(cl.global_check_membership[1], &cl);
ASSERT_EQ(cl.global_check_membership[2], &cl);
ASSERT_EQ(cl.global_check_membership.get()->at(1), &cl);
ASSERT_EQ(cl.global_check_membership.get()->at(2), &cl);

cl.compute_growth_candidate_bit_nodes();
auto expected_boundary_check_nodes = tsl::robin_set<int>{1, 2};
auto expected_boundary_check_nodes = tsl::robin_set<int>{1,2};
expected_candidate_bit_nodes = tsl::robin_set<int>{1, 3};

ASSERT_EQ(expected_boundary_check_nodes, cl.boundary_check_nodes);
Expand Down Expand Up @@ -174,13 +177,13 @@ TEST(LsdCluster, add_bit_node_to_cluster) {
{0, 2}};

ASSERT_EQ(expected_bit_nodes, cl.bit_nodes);
ASSERT_EQ(cl.global_bit_membership[1], &cl);
ASSERT_EQ(cl.global_bit_membership.get()->at(1), &cl);
ASSERT_EQ(cl.cluster_bit_idx_to_pcm_bit_idx[0], 2);
ASSERT_EQ(expected_check_nodes, cl.check_nodes);
ASSERT_EQ(expected_cluster_check_idx_to_pcm_check_idx, cl.cluster_check_idx_to_pcm_check_idx);
ASSERT_EQ(expected_pcm_check_idx_to_cluster_check_idx, cl.pcm_check_idx_to_cluster_check_idx);
ASSERT_EQ(cl.global_check_membership[0], &cl);
ASSERT_EQ(cl.global_check_membership[2], &cl);
ASSERT_EQ(cl.global_check_membership.get()->at(0), &cl);
ASSERT_EQ(cl.global_check_membership.get()->at(2), &cl);

//check the cluster pcm
expected_column = std::vector<int>{0, 1};
Expand All @@ -190,8 +193,8 @@ TEST(LsdCluster, add_bit_node_to_cluster) {
expected_column = std::vector<int>{2, 0};
ASSERT_EQ(expected_column, cl.cluster_pcm[1]);

delete gbm;
delete gcm;
// delete gbm;
// delete gcm;

}

Expand All @@ -200,16 +203,17 @@ TEST(LsdCluster, grow_cluster) {


auto pcm = ldpc::gf2codes::ring_code<ldpc::bp::BpEntry>(10);
auto gbm = new ldpc::lsd::LsdCluster *[pcm.n](); //global bit dictionary
auto gcm = new ldpc::lsd::LsdCluster *[pcm.m](); //global check dictionary

// auto gbm = new ldpc::lsd::LsdCluster *[pcm.n](); //global bit dictionary
// auto gcm = new ldpc::lsd::LsdCluster *[pcm.m](); //global check dictionary
auto gbm = std::make_shared<std::vector<LsdCluster*>>(std::vector<LsdCluster*>(pcm.n));
auto gcm = std::make_shared<std::vector<LsdCluster*>>(std::vector<LsdCluster*>(pcm.m));
auto syndrome_index = 5;
auto cl = ldpc::lsd::LsdCluster(pcm, syndrome_index, gcm, gbm);

cl.compute_growth_candidate_bit_nodes();
auto expected_candidate_bit_nodes = tsl::robin_set<int>{5, 6};
ASSERT_EQ(expected_candidate_bit_nodes, cl.candidate_bit_nodes);
auto bit_membership = cl.global_bit_membership[5];
auto bit_membership = cl.global_bit_membership.get()->at(5);
ASSERT_EQ(bit_membership, nullptr);

cl.grow_cluster();
Expand All @@ -222,14 +226,14 @@ TEST(LsdCluster, grow_cluster) {
{6, 2}};

ASSERT_EQ(expected_bit_nodes, cl.bit_nodes);
ASSERT_EQ(cl.global_bit_membership[5], &cl);
ASSERT_EQ(cl.global_bit_membership.get()->at(5), &cl);
ASSERT_EQ(cl.cluster_bit_idx_to_pcm_bit_idx[0], 5);
ASSERT_EQ(expected_check_nodes, cl.check_nodes);
ASSERT_EQ(expected_cluster_check_idx_to_pcm_check_idx, cl.cluster_check_idx_to_pcm_check_idx);
ASSERT_EQ(expected_pcm_check_idx_to_cluster_check_idx, cl.pcm_check_idx_to_cluster_check_idx);
ASSERT_EQ(cl.global_check_membership[4], &cl);
ASSERT_EQ(cl.global_check_membership[5], &cl);
ASSERT_EQ(cl.global_check_membership[6], &cl);
ASSERT_EQ(cl.global_check_membership.get()->at(4), &cl);
ASSERT_EQ(cl.global_check_membership.get()->at(5), &cl);
ASSERT_EQ(cl.global_check_membership.get()->at(6), &cl);

cl.compute_growth_candidate_bit_nodes();
auto expected_boundary_check_nodes = tsl::robin_set<int>{4, 6};
Expand All @@ -246,8 +250,8 @@ TEST(LsdCluster, grow_cluster) {
expected_column = std::vector<int>{0, 2};
ASSERT_EQ(expected_column, cl.cluster_pcm[1]);

delete gbm;
delete gcm;
// delete gbm;
// delete gcm;

}

Expand All @@ -256,10 +260,11 @@ TEST(LsdCluster, merge_clusters_test) {


auto pcm = ldpc::gf2codes::rep_code<ldpc::bp::BpEntry>(5);
auto gbm = new ldpc::lsd::LsdCluster *[pcm.n](); //global bit dictionary
auto gcm = new ldpc::lsd::LsdCluster *[pcm.m](); //global check dictionary

// auto gbm = new ldpc::lsd::LsdCluster *[pcm.n](); //global bit dictionary
// auto gcm = new ldpc::lsd::LsdCluster *[pcm.m](); //global check dictionary
// auto syndrome_index = 0;
auto gbm = std::make_shared<std::vector<LsdCluster*>>(std::vector<LsdCluster*>(pcm.n));
auto gcm = std::make_shared<std::vector<LsdCluster*>>(std::vector<LsdCluster*>(pcm.m));
auto cl1 = ldpc::lsd::LsdCluster(pcm, 0, gcm, gbm);
auto cl2 = ldpc::lsd::LsdCluster(pcm, 3, gcm, gbm);

Expand All @@ -281,15 +286,17 @@ TEST(LsdCluster, merge_clusters_test) {

ASSERT_TRUE(cl2.valid);

delete gbm;
delete gcm;
// delete gbm;
// delete gcm;

}

TEST(LsdCluster, merge_clusters_otf_test) {
auto pcm = ldpc::gf2codes::rep_code<ldpc::bp::BpEntry>(5);
auto gbm = new ldpc::lsd::LsdCluster *[pcm.n](); //global bit dictionary
auto gcm = new ldpc::lsd::LsdCluster *[pcm.m](); //global check dictionary
// auto gbm = new ldpc::lsd::LsdCluster *[pcm.n](); //global bit dictionary
// auto gcm = new ldpc::lsd::LsdCluster *[pcm.m](); //global check dictionary
auto gbm = std::make_shared<std::vector<LsdCluster*>>(std::vector<LsdCluster*>(pcm.n));
auto gcm = std::make_shared<std::vector<LsdCluster*>>(std::vector<LsdCluster*>(pcm.m));
auto cl1 = ldpc::lsd::LsdCluster(pcm, 0, gcm, gbm);
auto cl2 = ldpc::lsd::LsdCluster(pcm, 2, gcm, gbm);

Expand Down Expand Up @@ -324,12 +331,13 @@ TEST(LsdCluster, merge_clusters_otf_test) {

ASSERT_EQ(decoding_syndrome, expected_syndrome);

delete gbm;
delete gcm;
// delete gbm;
// delete gcm;

}



TEST(LsdDecoder, otf_ring_code) {

for (auto length = 2; length < 12; length++) {
Expand Down Expand Up @@ -375,6 +383,7 @@ TEST(LsdDecoder, lsdw_decode) {
auto bp = ldpc::bp::BpDecoder(pcm, std::vector<double>(pcm.n, 0.1));
bp.maximum_iterations = 2;
auto lsd = LsdDecoder(pcm, ldpc::osd::OsdMethod::COMBINATION_SWEEP, 3);
lsd.lsd_order = 3;
for (int i = 0; i < std::pow(2, hamming_code_rank); i++) {
// std::cout << i << std::endl;
auto syndrome = ldpc::util::decimal_to_binary(i, hamming_code_rank);
Expand All @@ -394,7 +403,7 @@ TEST(LsdDecoder, lsdw_decode_ring_code) {
auto bp = ldpc::bp::BpDecoder(pcm, std::vector<double>(pcm.n, 0.1));
bp.maximum_iterations = 3;
auto lsd = LsdDecoder(pcm, ldpc::osd::OsdMethod::COMBINATION_SWEEP, 5);

lsd.lsd_order = 5;
for (int i = 0; i < std::pow(2, length); i++) {
auto error = ldpc::util::decimal_to_binary(i, length);
auto syndrome = pcm.mulvec(error);
Expand Down Expand Up @@ -433,6 +442,7 @@ TEST(LsdDecoder, test_fail_case) {
//setup the BP decoder with only 2 iterations
auto bp = ldpc::bp::BpDecoder(pcm, channel_probabilities, 100, ldpc::bp::MINIMUM_SUM, ldpc::bp::PARALLEL, 0.625);
auto lsd = LsdDecoder(pcm, ldpc::osd::OsdMethod::COMBINATION_SWEEP, 5);
lsd.lsd_order = 5;
bp.decode(syndrome);
auto decoding = lsd.lsd_decode(syndrome, bp.log_prob_ratios, 1, true);
auto decoding_syndrome = pcm.mulvec(decoding);
Expand All @@ -452,6 +462,7 @@ TEST(LsdDecoder, test_cluster_stats) {
lsd.statistics.syndrome = std::vector<uint8_t>(pcm.m, 1);
lsd.statistics.compare_recover = std::vector<uint8_t>(pcm.n, 0);
auto decoding = lsd.lsd_decode(syndrome, bp.log_prob_ratios, 1, true);
lsd.setLsdMethod(ldpc::osd::OsdMethod::EXHAUSTIVE);

auto stats = lsd.statistics;
std::cout << stats.toString() << std::endl;
Expand Down
Loading
Loading