Skip to content

Commit

Permalink
wip EOO ouput data through C API
Browse files Browse the repository at this point in the history
  • Loading branch information
drowe67 committed Nov 28, 2024
1 parent 238facf commit a106dfc
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 26 deletions.
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,8 @@ add_test(NAME radae_tx_embed_c
add_test(NAME radae_rx_embed_c
COMMAND sh -c "cd ${CMAKE_SOURCE_DIR}; \
./inference.sh model19_check3/checkpoints/checkpoint_epoch_100.pth wav/brian_g8sez.wav /dev/null \
--rate_Fs --pilots --pilot_eq --eq_ls --cp 0.004 --bottleneck 3 --auxdata --write_rx rx.f32 --correct_freq_offset; \
--EbNodB 100 --freq_offset 0 --append_noise 1 --end_of_over --auxdata \
--rate_Fs --pilots --pilot_eq --eq_ls --cp 0.004 --bottleneck 3 --write_rx rx.f32 --correct_freq_offset; \
cat rx.f32 | PYTHONPATH='.' ${CMAKE_CURRENT_BINARY_DIR}/src/radae_rx > features_out.f32;
python3 loss.py features_in.f32 features_out.f32 --loss_test 0.15 --acq_time_test 0.5")
set_tests_properties(radae_rx_embed_c PROPERTIES PASS_REGULAR_EXPRESSION "PASS")
Expand Down
4 changes: 2 additions & 2 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,8 @@
lin_phase = torch.cumsum(omega,dim=1)
lin_phase = torch.exp(1j*lin_phase)
eoo = eoo*lin_phase*model.final_phase
#print(model.final_phase)

print(model.final_phase)
print(lin_phase)
eoo = eoo + sigma*torch.randn_like(eoo)
rx = torch.concatenate([rx,eoo],dim=1)
if args.prepend_noise > 0.0:
Expand Down
6 changes: 1 addition & 5 deletions radae/dsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,18 +443,14 @@ def receiver_one(self, rx, endofover):
z_hat[:,:,::2] = rx_sym.real
z_hat[:,:,1::2] = rx_sym.imag
else:
#print(Ns,rx_sym_pilots.shape, file=sys.stderr)
# Simpler EQ as average of pilots, as LS set up for PDDDDP, rather than out PEDDDE
# Simpler (but lower performance) EQ as average of pilots, as LS set up for PDDDDP, rather than out PEDDDE
for c in range(self.Nc):
phase_offset = torch.angle(rx_sym_pilots[0,0,0,c]/self.P[c] +
rx_sym_pilots[0,0,1,c]/self.Pend[c] +
rx_sym_pilots[0,0,Ns,c]/self.Pend[c])
#print(phase_offset.shape, file=sys.stderr)
rx_sym_pilots[:,:,:Ns+1,c] *= torch.exp(-1j*phase_offset)
rx_sym = torch.reshape(rx_sym_pilots[:,:,2:Ns,:],(1,(Ns-2)*self.Nc))
#quit()
z_hat = torch.zeros(1,(Ns-2)*self.Nc*2)
#print(rx_sym.shape, z_hat.shape, file=sys.stderr)

z_hat[:,::2] = rx_sym.real
z_hat[:,1::2] = rx_sym.imag
Expand Down
2 changes: 1 addition & 1 deletion radae/radae.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ def forward(self, features, H, G=None):

tx_before_channel = None
rx = None
self.final_phase = 0
self.final_phase = torch.tensor(1,dtype=torch.complex64)
if self.rate_Fs:
num_timesteps_at_rate_Fs = num_timesteps_at_rate_Rs*self.M

Expand Down
31 changes: 18 additions & 13 deletions radae_rxe.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,16 +228,7 @@ def do_radae_rx(self, buffer_complex, floats_out):

# run through RADAE receiver DSP
z_hat = receiver.receiver_one(rx, endofover)
if not endofover:
valid_output = True
else:
if self.eoo_data_test:
n_bits = torch.numel(z_hat)
assert n_bits == model.Nseoo*model.bps
n_errors = sum(z_hat[0,:]*model.eoo_bits < 0)
print(f"EOO data n_bits: {n_bits} n_errors: {n_errors}", file=sys.stderr)
z_hat = z_hat.cpu().detach().numpy().flatten()
z_hat.tofile("z_hat_eoo.f32")
valid_output = not endofover

if v == 2 or (v == 1 and (self.state == "search" or self.state == "candidate" or prev_state == "candidate")):
print(f"{self.mf:3d} state: {self.state:10s} valid: {candidate:d} {endofover:d} {self.valid_count:2d} Dthresh: {acq.Dthresh:8.2f} ", end='', file=sys.stderr)
Expand Down Expand Up @@ -317,7 +308,21 @@ def do_radae_rx(self, buffer_complex, floats_out):
else:
np.copyto(floats_out, z_hat.cpu().detach().numpy().flatten().astype('float32'))

return valid_output
if endofover:
n_bits = torch.numel(z_hat)
assert n_bits == model.Nseoo*model.bps
if self.eoo_data_test:
n_errors = sum(z_hat[0,:]*model.eoo_bits < 0)
print(f"EOO data n_bits: {n_bits} n_errors: {n_errors}", file=sys.stderr)
z_hat = z_hat.cpu().detach().numpy().flatten()
np.copyto(floats_out,np.concatenate([z_hat,np.zeros(len(floats_out)-len(z_hat))]))

# possible return cases
# valid_output | endofover | Description
# 0 0 Nothing returned
# 1 0 valid speech output (either z_hat or features, depending on bypass_dec)
# 0 1 EOO data output
return valid_output | endofover<<1

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='RADAE streaming receiver, IQ.f32 on stdin to features.f32 on stdout')
Expand All @@ -343,6 +348,6 @@ def do_radae_rx(self, buffer_complex, floats_out):
if len(buffer) != rx.get_nin()*struct.calcsize("ff"):
break
buffer_complex = np.frombuffer(buffer,np.csingle)
valid_output = rx.do_radae_rx(buffer_complex, floats_out)
if valid_output and args.use_stdout:
ret = rx.do_radae_rx(buffer_complex, floats_out)
if (ret & 1) and args.use_stdout:
sys.stdout.buffer.write(floats_out)
14 changes: 10 additions & 4 deletions src/rade_api.c
Original file line number Diff line number Diff line change
Expand Up @@ -438,8 +438,10 @@ int rade_rx(struct rade *r, float features_out[], RADE_COMP rx_in[]) {
memcpy(r->rx_in, rx_in, sizeof(RADE_COMP)*(r->nin));
pValue = PyObject_CallObject(r->pMeth_radae_rx, r->pArgs_radae_rx);
check_error(pValue, "return value", "from do_rx_radae");
long valid_out = PyLong_AsLong(pValue);

long ret = PyLong_AsLong(pValue);
int valid_out = ret & 0x1;
int endofover = ret & 0x2;
fprintf(stderr, "%ld %d %d\n", ret, valid_out, endofover);
if (valid_out) {
if (r->flags & RADE_USE_C_DECODER) {
// sanity check: need integer number of latent vecs
Expand Down Expand Up @@ -482,6 +484,9 @@ int rade_rx(struct rade *r, float features_out[], RADE_COMP rx_in[]) {
}
}

if (endofover)
memcpy(features_out, r->floats_out, sizeof(float)*(r->n_eoo_features_out));

// sample nin so we have an updated copy
r->nin = (int)call_getter(r->pInst_radae_rx, "get_nin");

Expand All @@ -490,8 +495,9 @@ int rade_rx(struct rade *r, float features_out[], RADE_COMP rx_in[]) {

if (valid_out)
return r->n_features_out;
else
return 0;
if (endofover)
return r->n_eoo_features_out;
return 0;
}

int rade_sync(struct rade *r) {
Expand Down

0 comments on commit a106dfc

Please sign in to comment.