Skip to content

Commit

Permalink
* add reset py method,
Browse files Browse the repository at this point in the history
* fix resetting at each decoding run for bplsd, also if BP converges
  • Loading branch information
lucasberent committed May 20, 2024
1 parent 3077149 commit aaefb0e
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 4 deletions.
20 changes: 17 additions & 3 deletions cpp_test/TestLsd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -448,10 +448,10 @@ TEST(LsdDecoder, test_cluster_stats) {
auto lsd = LsdDecoder(pcm, ldpc::osd::OsdMethod::EXHAUSTIVE, 0);
lsd.set_do_stats(true);
auto syndrome = std::vector<uint8_t>({1, 1, 0, 0, 0});
lsd.statistics.error = std::vector<uint8_t>(pcm.n, 1);
lsd.statistics.syndrome = std::vector<uint8_t>(pcm.m, 1);
lsd.statistics.compare_recover = std::vector<uint8_t>(pcm.n, 0);
auto decoding = lsd.lsd_decode(syndrome, bp.log_prob_ratios, 1, true);
lsd.statistics.compare_recover = std::vector<uint8_t>(pcm.n, 0);
lsd.statistics.error = std::vector<uint8_t>(pcm.n, 1);

auto stats = lsd.statistics;
std::cout << stats.toString() << std::endl;
Expand All @@ -464,7 +464,7 @@ TEST(LsdDecoder, test_cluster_stats) {
ASSERT_TRUE(stats.global_timestep_bit_history[0].size() == 2);
ASSERT_TRUE(stats.global_timestep_bit_history[0][0].size() == 1);
ASSERT_TRUE(stats.global_timestep_bit_history[0][1].size() == 2);
ASSERT_TRUE(stats.global_timestep_bit_history[1].size() == 0);
ASSERT_TRUE(stats.global_timestep_bit_history[1].empty());
ASSERT_TRUE(stats.elapsed_time > 0.0);
ASSERT_TRUE(stats.individual_cluster_stats[0].active == false);
ASSERT_TRUE(stats.individual_cluster_stats[0].got_inactive_in_timestep == 0);
Expand All @@ -476,6 +476,20 @@ TEST(LsdDecoder, test_cluster_stats) {
ASSERT_TRUE(stats.error.size() == pcm.n);
ASSERT_TRUE(stats.syndrome.size() == pcm.n);
ASSERT_TRUE(stats.compare_recover.size() == pcm.n);

// now reset
lsd.reset_cluster_stats();
stats = lsd.statistics;
ASSERT_TRUE(lsd.get_do_stats());
ASSERT_TRUE(stats.lsd_method = ldpc::osd::OsdMethod::COMBINATION_SWEEP);
ASSERT_TRUE(stats.lsd_order == 0);
ASSERT_TRUE(stats.individual_cluster_stats.empty());
ASSERT_TRUE(stats.elapsed_time == 0.0);
ASSERT_TRUE(stats.global_timestep_bit_history.empty());
ASSERT_TRUE(stats.bit_llrs.empty());
ASSERT_TRUE(stats.error.empty());
ASSERT_TRUE(stats.syndrome.empty());
ASSERT_TRUE(stats.compare_recover.empty());
}

TEST(LsdDecoder, test_reshuffle_same_wt_indices) {
Expand Down
12 changes: 12 additions & 0 deletions src_cpp/lsd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,13 @@ namespace ldpc::lsd {
void clear() {
this->individual_cluster_stats.clear();
this->global_timestep_bit_history.clear();
this->elapsed_time = 0.0;
this->lsd_method = osd::OsdMethod::COMBINATION_SWEEP;
this->lsd_order = 0;
this->bit_llrs = {};
this->error = {};
this->syndrome= {};
this->compare_recover = {};
}

[[nodiscard]] std::string toString() const {
Expand Down Expand Up @@ -555,6 +562,10 @@ namespace ldpc::lsd {
osd::OsdMethod lsd_method;
int lsd_order;

void reset_cluster_stats(){
this->statistics.clear();
}

explicit LsdDecoder(ldpc::bp::BpSparse &parity_check_matrix,
osd::OsdMethod lsdMethod = osd::OsdMethod::COMBINATION_SWEEP,
int lsd_order = 0) : pcm(parity_check_matrix),
Expand Down Expand Up @@ -618,6 +629,7 @@ namespace ldpc::lsd {
const bool is_on_the_fly = true) {
auto start_time = std::chrono::high_resolution_clock::now();
this->statistics.clear();
this->statistics.syndrome = syndrome;

fill(this->decoding.begin(), this->decoding.end(), 0);

Expand Down
11 changes: 11 additions & 0 deletions src_python/ldpc/bplsd_decoder/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,14 @@ class BpLsdDecoder(BpDecoderBase):
"""


def set_additional_stat_fields(self, error, syndrome, compare_recover) -> None:
"""
Sets additional fields to be collected in the statistics.
Parameters
----------
fields : List[str]
A list of strings representing the additional fields to be collected in the statistics.
"""

1 change: 1 addition & 0 deletions src_python/ldpc/bplsd_decoder/_bplsd_decoder.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ cdef extern from "lsd.hpp" namespace "ldpc::lsd":
bool get_do_stats()
void set_do_stats(bool do_stats)
void set_additional_stat_fields(vector[int] error, vector[int] syndrome, vector[int] compare_recover)
void reset_cluster_stats()

cdef class BpLsdDecoder(BpDecoderBase):
cdef LsdDecoderCpp* lsd
Expand Down
12 changes: 11 additions & 1 deletion src_python/ldpc/bplsd_decoder/_bplsd_decoder.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,10 @@ cdef class BpLsdDecoder(BpDecoderBase):
self.bpd.decoding = self.bpd.decode(self._syndrome)
out = np.zeros(self.n,dtype=DTYPE)
if self.bpd.converge:
for i in range(self.n): out[i] = self.bpd.decoding[i]
for i in range(self.n):
out[i] = self.bpd.decoding[i]
self.lsd.reset_cluster_stats()


if not self.bpd.converge:
self.lsd.decoding = self.lsd.lsd_decode(self._syndrome, self.bpd.log_prob_ratios,self.bits_per_step, True)
Expand Down Expand Up @@ -291,3 +294,10 @@ cdef class BpLsdDecoder(BpDecoderBase):
self.lsd.statistics.error = error
self.lsd.statistics.syndrome = syndrome
self.lsd.statistics.compare_recover = compare_recover

def reset_cluster_stats() -> None:
"""
Resets cluster statistics of the decoder.
Note that this also resets the additional stat fields, such as the error, and compare_recovery vectors
"""
self.lsd.reset_cluster_stats()

0 comments on commit aaefb0e

Please sign in to comment.