diff --git a/CMakeLists.txt b/CMakeLists.txt index 673ab42..13e8a64 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -294,7 +294,7 @@ add_test(NAME radae_rx_basic ./inference.sh model19_check3/checkpoints/checkpoint_epoch_100.pth wav/brian_g8sez.wav /dev/null \ --EbNodB 10 --freq_offset 11 --prepend_noise 1 --append_noise 1 --end_of_over --auxdata \ --rate_Fs --pilots --pilot_eq --eq_ls --cp 0.004 --bottleneck 3 --write_rx rx.f32 --correct_freq_offset; \ - cat rx.f32 | PYTHONPATH='../' python3 radae_rxe.py -v 2 --eoo_data_test > features_rxs_out.f32; \ + cat rx.f32 | PYTHONPATH='../' python3 radae_rxe.py -v 2 > features_rxs_out.f32; \ python3 loss.py features_in.f32 features_rxs_out.f32 --loss_test 0.15 --acq_time_test 0.5") set_tests_properties(radae_rx_basic PROPERTIES PASS_REGULAR_EXPRESSION "PASS") @@ -552,6 +552,18 @@ add_test(NAME c_decoder_aux_mpp python3 loss.py features_in.f32 features_c.f32 --loss 0.3 --clip_start 300") set_tests_properties(c_decoder_aux_mpp PROPERTIES PASS_REGULAR_EXPRESSION "PASS") +# EOO data ------------------------------------------------------------------------------------------- + +add_test(NAME radae_eoo_data_py + COMMAND sh -c "cd ${CMAKE_SOURCE_DIR}; \ + ./inference.sh model19_check3/checkpoints/checkpoint_epoch_100.pth wav/brian_g8sez.wav /dev/null \ + --EbNodB 10 --freq_offset 13 \ + --rate_Fs --pilots --pilot_eq --eq_ls --cp 0.004 --bottleneck 3 --time_offset -16 --write_rx rx.f32 \ + --prepend_noise 1 --append_noise 1 --end_of_over --auxdata --correct_freq_offset; \ + cat rx.f32 | python3 radae_rxe.py -v 2 --eoo_data_test > /dev/null") + set_tests_properties(radae_eoo_data_py PROPERTIES PASS_REGULAR_EXPRESSION "PASS") + + # BBFM ----------------------------------------------------------------------------------------------- # single carrier modem internal (inside single_carrier class) tests diff --git a/inference.py b/inference.py index 60f75c6..75f148b 100644 --- a/inference.py +++ b/inference.py @@ -240,7 +240,7 @@ n_errors = int(torch.sum(x < 0)) n_bits = int(torch.numel(x)) BER = n_errors/n_bits - print(f"loss: {loss:5.3f} BER: {BER:5.3f}") + print(f"loss: {loss:5.3f} Auxdata BER: {BER:5.3f}") else: print(f"loss: {loss:5.3f}") if args.loss_test > 0.0: diff --git a/radae_rxe.py b/radae_rxe.py index a4d251e..127d8e0 100644 --- a/radae_rxe.py +++ b/radae_rxe.py @@ -309,11 +309,6 @@ def do_radae_rx(self, buffer_complex, floats_out): np.copyto(floats_out, z_hat.cpu().detach().numpy().flatten().astype('float32')) if endofover: - n_bits = torch.numel(z_hat) - assert n_bits == model.Nseoo*model.bps - if self.eoo_data_test: - n_errors = sum(z_hat[0,:]*model.eoo_bits < 0) - print(f"EOO data n_bits: {n_bits} n_errors: {n_errors}", file=sys.stderr) z_hat = z_hat.cpu().detach().numpy().flatten() np.copyto(floats_out,np.concatenate([z_hat,np.zeros(len(floats_out)-len(z_hat))])) @@ -351,3 +346,11 @@ def do_radae_rx(self, buffer_complex, floats_out): ret = rx.do_radae_rx(buffer_complex, floats_out) if (ret & 1) and args.use_stdout: sys.stdout.buffer.write(floats_out) + if (ret & 2) and args.eoo_data_test: + n_bits = rx.model.Nseoo*rx.model.bps + tx_bits = rx.model.eoo_bits.cpu().detach().numpy().flatten() + n_errors = sum(floats_out[:n_bits]*tx_bits < 0) + ber = n_errors/n_bits + print(f"EOO data n_bits: {n_bits} n_errors: {n_errors} BER: {ber:5.2f}", file=sys.stderr) + if ber < 0.05: + print("PASS", file=sys.stderr)