Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

End of Over Data #35

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
6fb88ba
wip insert some QPSK symbols in EOO frame
drowe67 Nov 22, 2024
91a77f0
bottleneck 3 for data symbols
drowe67 Nov 23, 2024
14fd1c4
wip EOO data basic rx
drowe67 Nov 25, 2024
46569da
inferernce.py wasn't adding phase and freq offsets to EOO!
drowe67 Nov 25, 2024
e364a63
based EOO data proof of concept (ML side) - getting data out with a f…
drowe67 Nov 25, 2024
d851f8f
wip fix failing ctests
drowe67 Nov 25, 2024
0ab0061
Merge branch 'main' into dr-eoo-data
drowe67 Nov 25, 2024
52b412b
used a custom RNG for EOO data bits, which stopped some test fails. …
drowe67 Nov 25, 2024
0bb2a12
comment out #weights prints, to reduce noise when running tests
drowe67 Nov 25, 2024
a006b07
allow for EOO not triggering at low SNR awgn test
drowe67 Nov 25, 2024
2dad457
prototype EOO rade_api.h
drowe67 Nov 26, 2024
8ccd3cb
Merge branch 'dr-eoo-data' of github.com:drowe67/radae into dr-eoo-data
drowe67 Nov 26, 2024
238facf
n_eoo_features out is number of floats rather than symbols
drowe67 Nov 26, 2024
a106dfc
wip EOO ouput data through C API
drowe67 Nov 28, 2024
39bf054
wip EOO data - about to refactor tests
drowe67 Nov 28, 2024
048ecc7
1st pass EOO data unit test framework
drowe67 Nov 28, 2024
3f8a23b
initial attempt at passing in tx EOO bits via function call (Python)
drowe67 Nov 28, 2024
bcb032f
numpy interafce for setting eoo bits
drowe67 Nov 28, 2024
09e9547
first pass of EOO data through C API
drowe67 Nov 28, 2024
1ca79a6
multipath channel test
drowe67 Nov 29, 2024
5c02edf
rade_rx() API change as per Mooneer's suggestion
drowe67 Nov 29, 2024
d294bcb
rm resource_est.py (moved to papers repo)
drowe67 Dec 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ add_test(NAME radae_tx_basic
./inference.sh model19_check3/checkpoints/checkpoint_epoch_100.pth wav/brian_g8sez.wav /dev/null \
--rate_Fs --pilots --pilot_eq --eq_ls --cp 0.004 --bottleneck 3 --auxdata --write_rx rx.f32 --correct_freq_offset; \
cat features_in.f32 | python3 radae_txe.py --model model19_check3/checkpoints/checkpoint_epoch_100.pth --txbpf > rx.f32
cat rx.f32 | python3 radae_rxe.py --model model19_check3/checkpoints/checkpoint_epoch_100.pth -v 1 > features_txs_out.f32; \
cat rx.f32 | python3 radae_rxe.py --model model19_check3/checkpoints/checkpoint_epoch_100.pth -v 2 > features_txs_out.f32; \
python3 loss.py features_in.f32 features_txs_out.f32 --loss_test 0.15 --acq_time_test 0.5 --clip_start 5")
set_tests_properties(radae_tx_basic PROPERTIES PASS_REGULAR_EXPRESSION "PASS")

Expand Down Expand Up @@ -291,10 +291,10 @@ add_test(NAME rx_streaming
# basic test of streaming rx, run rx in vanilla and streaming, compare
add_test(NAME radae_rx_basic
COMMAND sh -c "cd ${CMAKE_SOURCE_DIR}; \
./inference.sh model17/checkpoints/checkpoint_epoch_100.pth wav/brian_g8sez.wav /dev/null \
--EbNodB 10 --freq_offset 11 \
./inference.sh model19_check3/checkpoints/checkpoint_epoch_100.pth wav/brian_g8sez.wav /dev/null \
--EbNodB 10 --freq_offset 11 --prepend_noise 1 --append_noise 1 --end_of_over --auxdata \
--rate_Fs --pilots --pilot_eq --eq_ls --cp 0.004 --bottleneck 3 --write_rx rx.f32 --correct_freq_offset; \
cat rx.f32 | PYTHONPATH='../' python3 radae_rxe.py --model model17/checkpoints/checkpoint_epoch_100.pth -v 1 --noauxdata > features_rxs_out.f32; \
cat rx.f32 | PYTHONPATH='../' python3 radae_rxe.py -v 2 --eoo_data_test > features_rxs_out.f32; \
python3 loss.py features_in.f32 features_rxs_out.f32 --loss_test 0.15 --acq_time_test 0.5")
set_tests_properties(radae_rx_basic PROPERTIES PASS_REGULAR_EXPRESSION "PASS")

Expand All @@ -309,7 +309,7 @@ add_test(NAME radae_rx_awgn
--rate_Fs --pilots --pilot_eq --eq_ls --cp 0.004 --bottleneck 3 --time_offset -16 --write_rx rx.f32 \
--prepend_noise 1 --append_noise 3 --end_of_over --auxdata --correct_freq_offset; \
cat rx.f32 | python3 radae_rxe.py --model model19_check3/checkpoints/checkpoint_epoch_100.pth -v 2 > features_rx_out.f32; \
python3 loss.py features_in.f32 features_rx_out.f32 --loss 0.3 --acq_time_test 1.0 --clip_end 100")
python3 loss.py features_in.f32 features_rx_out.f32 --loss 0.3 --acq_time_test 1.0 --clip_end 300")
set_tests_properties(radae_rx_awgn PROPERTIES PASS_REGULAR_EXPRESSION "PASS")

# SNR=0dB MPP
Expand Down
12 changes: 11 additions & 1 deletion inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
parser.add_argument('--rx_gain', type=float, default=1.0, help='gain to apply to --write_rx samples (default 1.0)')
parser.add_argument('--write_tx', type=str, default="", help='path to output file of rate Fs tx samples in ..IQIQ...f32 format')
parser.add_argument('--phase_offset', type=float, default=0, help='phase offset in rads')
parser.add_argument('--freq_offset', type=float, help='freq offset in Hz')
parser.add_argument('--freq_offset', type=float, default=0, help='freq offset in Hz')
parser.add_argument('--time_offset', type=int, default=0, help='sampling time offset in samples')
parser.add_argument('--df_dt', type=float, default=0, help='rate of change of freq offset in Hz/s')
parser.add_argument('--gain', type=float, default=1.0, help='rx gain (defaul 1.0)')
Expand Down Expand Up @@ -264,6 +264,16 @@
# appends a frame containing a final pilot so the last RADAE frame
# has a good phase reference, and two "end of over" symbols
eoo = model.eoo

# this is messy! - continue phase, freq and dF/dt track from inside forward()
freq = torch.zeros_like(eoo)
freq[:,] = model.freq_offset*torch.ones_like(eoo) + model.df_dt*torch.arange(eoo.shape[1])/model.Fs
omega = freq*2*torch.pi/model.Fs
lin_phase = torch.cumsum(omega,dim=1)
lin_phase = torch.exp(1j*lin_phase)
eoo = eoo*lin_phase*model.final_phase
#print(model.final_phase)

eoo = eoo + sigma*torch.randn_like(eoo)
rx = torch.concatenate([rx,eoo],dim=1)
if args.prepend_noise > 0.0:
Expand Down
44 changes: 30 additions & 14 deletions radae/dsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def transmitter_one(self, z, num_timesteps_at_rate_Rs):

# Single modem frame streaming receiver. TODO: is there a better way to pass a bunch of constants around?
class receiver_one():
def __init__(self,latent_dim,Fs,M,Ncp,Wfwd,Nc,Ns,w,P,bottleneck,pilot_gain,time_offset,coarse_mag):
def __init__(self,latent_dim,Fs,M,Ncp,Wfwd,Nc,Ns,w,P,Pend,bottleneck,pilot_gain,time_offset,coarse_mag):
self.latent_dim = latent_dim
self.Fs = Fs
self.M = M
Expand All @@ -354,6 +354,7 @@ def __init__(self,latent_dim,Fs,M,Ncp,Wfwd,Nc,Ns,w,P,bottleneck,pilot_gain,time_
self.Ns = Ns
self.w = w
self.P = P
self.Pend = Pend
self.bottleneck = bottleneck
self.pilot_gain = pilot_gain
self.time_offset = time_offset
Expand Down Expand Up @@ -415,7 +416,7 @@ def do_pilot_eq_one(self, num_modem_frames, rx_sym_pilots):
return rx_sym_pilots

# One frame version of rate Fs receiver for streaming implementation
def receiver_one(self, rx):
def receiver_one(self, rx, endofover):
Ns = self.Ns + 1

# we expect: Pilots - data symbols - Pilots
Expand All @@ -431,21 +432,36 @@ def receiver_one(self, rx):
# DFT to transform M time domain samples to Nc carriers
rx_sym = torch.matmul(rx_dash, self.Wfwd)

# Pilot based EQ
rx_sym_pilots = torch.reshape(rx_sym,(1, num_modem_frames, num_timesteps_at_rate_Rs, self.Nc))
rx_sym_pilots = self.do_pilot_eq_one(num_modem_frames,rx_sym_pilots)
rx_sym = torch.ones(1, num_modem_frames, self.Ns, self.Nc, dtype=torch.complex64)
rx_sym = rx_sym_pilots[:,:,1:self.Ns+1,:]

# demap QPSK symbols
rx_sym = torch.reshape(rx_sym, (1, -1, self.latent_dim//2))
z_hat = torch.zeros(1,rx_sym.shape[1], self.latent_dim)

z_hat[:,:,::2] = rx_sym.real
z_hat[:,:,1::2] = rx_sym.imag

if not endofover:
# Pilot based least squares EQ
rx_sym_pilots = self.do_pilot_eq_one(num_modem_frames,rx_sym_pilots)
rx_sym = rx_sym_pilots[:,:,1:self.Ns+1,:]
rx_sym = torch.reshape(rx_sym, (1, -1, self.latent_dim//2))
z_hat = torch.zeros(1,rx_sym.shape[1], self.latent_dim)

z_hat[:,:,::2] = rx_sym.real
z_hat[:,:,1::2] = rx_sym.imag
else:
#print(Ns,rx_sym_pilots.shape, file=sys.stderr)
# Simpler EQ as average of pilots, as LS set up for PDDDDP, rather than out PEDDDE
for c in range(self.Nc):
phase_offset = torch.angle(rx_sym_pilots[0,0,0,c]/self.P[c] +
rx_sym_pilots[0,0,1,c]/self.Pend[c] +
rx_sym_pilots[0,0,Ns,c]/self.Pend[c])
#print(phase_offset.shape, file=sys.stderr)
rx_sym_pilots[:,:,:Ns+1,c] *= torch.exp(-1j*phase_offset)
rx_sym = torch.reshape(rx_sym_pilots[:,:,2:Ns,:],(1,(Ns-2)*self.Nc))
#quit()
z_hat = torch.zeros(1,(Ns-2)*self.Nc*2)
#print(rx_sym.shape, z_hat.shape, file=sys.stderr)

z_hat[:,::2] = rx_sym.real
z_hat[:,1::2] = rx_sym.imag

return z_hat


# Generate root raised cosine (Root Nyquist) filter coefficients
# thanks http://www.dsplog.com/db-install/wp-content/uploads/2008/05/raised_cosine_filter.m

Expand Down
23 changes: 22 additions & 1 deletion radae/radae.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,12 @@ def __init__(self,
self.pilot_gain = pilot_backoff*self.M/(Nc**0.5)

self.d_samples = int(self.multipath_delay * self.Fs) # multipath delay in samples
self.Ncp = int(cyclic_prefix*self.Fs)

# set up End Of Over sequence
# Normal frame ...PDDDDP...
# EOO frame ...PE000E...
# Key: P = self.p_cp, D = data symbols, E = self.pend_cp, 0 = zeros

if self.Ncp:
M = self.M
Ncp = self.Ncp
Expand All @@ -217,7 +217,26 @@ def __init__(self,
if self.bottleneck == 3:
eoo = torch.tanh(torch.abs(eoo)) * torch.exp(1j*torch.angle(eoo))
self.eoo = eoo

# experimental EOO data symbols (quick and dirty supplimentary txt channel)
self.Nseoo = (Ns-1)*Nc # number of EOO data symbols
# use a customer RNG to avoid upsetting some otehr rather delicate ctests (TODO fix this sensitvity later)
self.g = torch.Generator().manual_seed(1)
eoo_bits = torch.sign(torch.rand(self.Nseoo*bps,generator=self.g)-0.5)
self.eoo_bits = eoo_bits
eoo_syms = eoo_bits[::2] + 1j*eoo_bits[1::2]
eoo_syms = torch.reshape(eoo_syms,(1,Ns-1,Nc))

eoo_tx = torch.matmul(eoo_syms,self.Winv)
if self.Ncp:
eoo_tx_cp = torch.zeros((1,Ns-1,self.M+Ncp),dtype=torch.complex64)
eoo_tx_cp[:,:,Ncp:] = eoo_tx
eoo_tx_cp[:,:,:Ncp] = eoo_tx_cp[:,:,-Ncp:]
eoo_tx = torch.reshape(eoo_tx_cp,(1,(Ns-1)*(self.M+Ncp)))*self.pilot_gain
if self.bottleneck == 3:
eoo_tx = torch.tanh(torch.abs(eoo_tx)) * torch.exp(1j*torch.angle(eoo_tx))
self.eoo[0,2*(M+Ncp):Nmf] = eoo_tx

print(f"Rs: {Rs:5.2f} Rs': {Rs_dash:5.2f} Ts': {Ts_dash:5.3f} Nsmf: {Nsmf:3d} Ns: {Ns:3d} Nc: {Nc:3d} M: {self.M:d} Ncp: {self.Ncp:d}", file=sys.stderr)

self.Tmf = Tmf
Expand Down Expand Up @@ -482,6 +501,7 @@ def forward(self, features, H, G=None):

tx_before_channel = None
rx = None
self.final_phase = 0
if self.rate_Fs:
num_timesteps_at_rate_Fs = num_timesteps_at_rate_Rs*self.M

Expand Down Expand Up @@ -530,6 +550,7 @@ def forward(self, features, H, G=None):
lin_phase = torch.cumsum(omega,dim=1)
lin_phase = torch.exp(1j*lin_phase)
tx = tx*lin_phase
self.final_phase = lin_phase[:,-1]

# insert per sequence random phase and freq offset (training time)
if self.freq_rand:
Expand Down
8 changes: 4 additions & 4 deletions radae/radae_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def __init__(self, feature_dim, output_dim, bottleneck = 1):
self.z_dense = nn.Linear(864, self.output_dim)

nb_params = sum(p.numel() for p in self.parameters())
print(f"encoder: {nb_params} weights", file=sys.stderr)
#print(f"encoder: {nb_params} weights", file=sys.stderr)

# initialize weights
self.apply(init_weights)
Expand Down Expand Up @@ -251,7 +251,7 @@ def __init__(self, feature_dim, output_dim, bottleneck = 1):
self.z_dense = nn.Linear(864, self.output_dim)

nb_params = sum(p.numel() for p in self.parameters())
print(f"encoder: {nb_params} weights", file=sys.stderr)
#print(f"encoder: {nb_params} weights", file=sys.stderr)

# initialize weights
self.apply(init_weights)
Expand Down Expand Up @@ -326,7 +326,7 @@ def __init__(self, input_dim, output_dim):
self.glu5 = GLU(96)

nb_params = sum(p.numel() for p in self.parameters())
print(f"decoder: {nb_params} weights", file=sys.stderr)
#print(f"decoder: {nb_params} weights", file=sys.stderr)
# initialize weights
self.apply(init_weights)

Expand Down Expand Up @@ -393,7 +393,7 @@ def __init__(self, input_dim, output_dim):
self.glu5 = GLU(96)

nb_params = sum(p.numel() for p in self.parameters())
print(f"decoder: {nb_params} weights", file=sys.stderr)
#print(f"decoder: {nb_params} weights", file=sys.stderr)
# initialize weights
self.apply(init_weights)

Expand Down
48 changes: 31 additions & 17 deletions radae_rxe.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@
# P(accept|false) = binocdf(8,24,0.5) = 3.2E-3

class radae_rx:
def __init__(self, model_name, latent_dim = 80, auxdata = True, bottleneck = 3, bpf_en=True, v=2, disable_unsync=False, foff_err=0, bypass_dec=False):
def __init__(self, model_name, latent_dim = 80, auxdata = True, bottleneck = 3, bpf_en=True, v=2,
disable_unsync=False, foff_err=0, bypass_dec=False, eoo_data_test=False):

self.latent_dim = latent_dim
self.auxdata = auxdata
Expand All @@ -64,6 +65,7 @@ def __init__(self, model_name, latent_dim = 80, auxdata = True, bottleneck = 3,
self.disable_unsync = disable_unsync
self.foff_err = foff_err
self.bypass_dec = bypass_dec
self.eoo_data_test = eoo_data_test

print(f"bypass_dec: {bypass_dec} foff_err: {foff_err:f}", file=sys.stderr)

Expand All @@ -85,7 +87,7 @@ def __init__(self, model_name, latent_dim = 80, auxdata = True, bottleneck = 3,
assert self.model.coarse_mag

self.receiver = receiver_one(model.latent_dim,model.Fs,model.M,model.Ncp,model.Wfwd,model.Nc,
model.Ns,model.w,model.P,model.bottleneck,model.pilot_gain,
model.Ns,model.w,model.P,model.Pend,model.bottleneck,model.pilot_gain,
model.time_offset,model.coarse_mag)

M = model.M
Expand Down Expand Up @@ -115,7 +117,7 @@ def __init__(self, model_name, latent_dim = 80, auxdata = True, bottleneck = 3,
# Stateful decoder wasn't present during training, so we need to load weights from existing decoder
model.core_decoder_statefull_load_state_dict()

# number of input floats per processing frame
# number of output floats per processing frame
if not self.bypass_dec:
self.n_floats_out = model.Nzmf*model.enc_stride*nb_total_features
else:
Expand All @@ -142,6 +144,9 @@ def __init__(self, model_name, latent_dim = 80, auxdata = True, bottleneck = 3,
def get_n_features_out(self):
return self.model.Nzmf*self.model.dec_stride*nb_total_features

def get_n_eoo_features_out(self):
return self.model.Nseoo

def get_n_floats_out(self):
return self.n_floats_out

Expand Down Expand Up @@ -212,21 +217,28 @@ def do_radae_rx(self, buffer_complex, floats_out):
uw_fail = True
self.uw_errors = 0

# correct frequency offset, note we preserve state of phase (TODO: I don't think we need to)
w = 2*np.pi*self.fmax/Fs
rx_phase_vec = np.zeros(Nmf+M+Ncp,np.csingle)
for n in range(Nmf+M+Ncp):
self.rx_phase = self.rx_phase*np.exp(-1j*w)
rx_phase_vec[n] = self.rx_phase
rx1 = rx_buf[self.tmax-Ncp:self.tmax-Ncp+Nmf+M+Ncp]
rx = torch.tensor(rx1*rx_phase_vec, dtype=torch.complex64)

# run through RADAE receiver DSP
z_hat = receiver.receiver_one(rx, endofover)
if not endofover:
# correct frequency offset, note we preserve state of phase
# TODO do we need preserve state of phase? We're passing entire vector and there isn't any memory (I think)
w = 2*np.pi*self.fmax/Fs
rx_phase_vec = np.zeros(Nmf+M+Ncp,np.csingle)
for n in range(Nmf+M+Ncp):
self.rx_phase = self.rx_phase*np.exp(-1j*w)
rx_phase_vec[n] = self.rx_phase
rx1 = rx_buf[self.tmax-Ncp:self.tmax-Ncp+Nmf+M+Ncp]
rx = torch.tensor(rx1*rx_phase_vec, dtype=torch.complex64)

# run through RADAE receiver DSP
z_hat = receiver.receiver_one(rx)
valid_output = True

else:
if self.eoo_data_test:
n_bits = torch.numel(z_hat)
assert n_bits == model.Nseoo*model.bps
n_errors = sum(z_hat[0,:]*model.eoo_bits < 0)
print(f"EOO data n_bits: {n_bits} n_errors: {n_errors}", file=sys.stderr)
z_hat = z_hat.cpu().detach().numpy().flatten()
z_hat.tofile("z_hat_eoo.f32")

if v == 2 or (v == 1 and (self.state == "search" or self.state == "candidate" or prev_state == "candidate")):
print(f"{self.mf:3d} state: {self.state:10s} valid: {candidate:d} {endofover:d} {self.valid_count:2d} Dthresh: {acq.Dthresh:8.2f} ", end='', file=sys.stderr)
print(f"Dtmax12: {acq.Dtmax12:8.2f} {acq.Dtmax12_eoo:8.2f} tmax: {self.tmax:4d} fmax: {self.fmax:6.2f}", end='', file=sys.stderr)
Expand Down Expand Up @@ -316,11 +328,13 @@ def do_radae_rx(self, buffer_complex, floats_out):
parser.add_argument('--no_stdout', action='store_false', dest='use_stdout', help='disable the use of stdout (e.g. with python3 -m cProfile)')
parser.add_argument('--foff_err', type=float, default=0.0, help='Artifical freq offset error after first sync to test false sync (default 0.0)')
parser.add_argument('--bypass_dec', action='store_true', help='Bypass core decoder, write z_hat to stdout')
parser.add_argument('--eoo_data_test', action='store_true', help='experimental EOO data test - count bit errors')
parser.set_defaults(auxdata=True)
parser.set_defaults(use_stdout=True)
args = parser.parse_args()

rx = radae_rx(args.model_name,auxdata=args.auxdata,v=args.v,disable_unsync=args.disable_unsync,foff_err=args.foff_err, bypass_dec=args.bypass_dec)
rx = radae_rx(args.model_name,auxdata=args.auxdata,v=args.v,disable_unsync=args.disable_unsync,foff_err=args.foff_err,
bypass_dec=args.bypass_dec,eoo_data_test=args.eoo_data_test)

# allocate storage for output features
floats_out = np.zeros(rx.get_n_floats_out(),dtype=np.float32)
Expand Down
10 changes: 8 additions & 2 deletions src/radae_rx.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ int main(int argc, char *argv[])
int n_rx_in = rade_nin_max(r);
RADE_COMP rx_in[n_rx_in];
int nin = rade_nin(r);

int n_eoo_features_out = rade_n_eoo_features_out(r);
FILE *feoo = fopen("eoo.f32","wb"); assert(feoo != NULL);

#ifdef _WIN32
// Note: freopen() returns NULL if filename is NULL, so
// we have to use setmode() to make it a binary stream instead.
Expand All @@ -37,14 +39,18 @@ int main(int argc, char *argv[])

while((size_t)nin == fread(rx_in, sizeof(RADE_COMP), nin, stdin)) {
int n_out = rade_rx(r,features_out,rx_in);
if (n_out) {
if (n_out == n_features_out) {
fwrite(features_out, sizeof(float), n_features_out, stdout);
fflush(stdout);
}
if (n_out == n_eoo_features_out) {
fwrite(features_out, sizeof(float), n_eoo_features_out, feoo);
}
nin = rade_nin(r);
}

rade_close(r);
rade_finalize();
fclose(feoo);
return 0;
}
Loading