-
Notifications
You must be signed in to change notification settings - Fork 67
/
Copy pathgenerations.py
34 lines (29 loc) · 1.24 KB
/
generations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from Global_parameters import *
# Channel realisation sets, read once at import time with np.load from the
# relative paths below (so they resolve against the current working directory).
# NOTE(review): np is not imported in this file; it is presumably provided by
# the star import from Global_parameters — confirm.
channel_train = np.load('channel_train.npy')
# Number of training channel realisations (length of axis 0); training_gen
# indexes channel_train along axis 0 with indices drawn from range(train_size).
train_size = channel_train.shape[0]
channel_test = np.load('channel_test.npy')
# Number of held-out channel realisations (length of axis 0 of channel_test).
test_size = channel_test.shape[0]
def training_gen(bs, SNRdb=20, label_bits=16):
    """Yield an endless stream of simulated OFDM training batches.

    Every batch picks ``bs`` channel realisations uniformly at random, with
    replacement, from ``channel_train``. For each picked channel a fresh
    vector of ``payloadBits_per_OFDM`` equiprobable random bits is drawn and
    passed to ``ofdm_simulate`` together with the channel and ``SNRdb``.

    Args:
        bs: Number of samples per batch.
        SNRdb: Forwarded unchanged to ``ofdm_simulate``; presumably the SNR
            in dB (confirm against Global_parameters).
        label_bits: Number of leading payload bits used as each sample's
            label. Defaults to 16, the width previously hard-coded here.

    Yields:
        A ``(samples, labels)`` tuple of NumPy arrays. ``samples`` stacks the
        first value returned by ``ofdm_simulate`` along a new leading batch
        axis (assumes every call returns an array of the same shape).
        ``labels`` has shape ``(bs, label_bits)`` provided
        ``payloadBits_per_OFDM >= label_bits``.
    """
    while True:
        # An int population behaves exactly like np.arange(train_size) for
        # np.random.choice (same index stream), without allocating an index
        # array on every batch.
        index = np.random.choice(train_size, size=bs)
        H_total = channel_train[index]
        input_samples = []
        input_labels = []
        for H in H_total:
            bits = np.random.binomial(n=1, p=0.5, size=(payloadBits_per_OFDM,))
            # The second value returned by ofdm_simulate is not needed here.
            signal_output, _ = ofdm_simulate(bits, H, SNRdb)
            input_labels.append(bits[:label_bits])
            input_samples.append(signal_output)
        yield (np.asarray(input_samples), np.asarray(input_labels))
def validation_gen(bs, SNRdb=20, label_bits=16):
    """Yield an endless stream of simulated OFDM validation batches.

    Identical in construction to ``training_gen``, except that channel
    realisations are drawn (uniformly, with replacement) from the held-out
    ``channel_test`` set. Validation therefore measures performance on
    channels the model was not trained on. For each picked channel a fresh
    vector of ``payloadBits_per_OFDM`` equiprobable random bits is drawn and
    passed to ``ofdm_simulate`` together with the channel and ``SNRdb``.

    Args:
        bs: Number of samples per batch.
        SNRdb: Forwarded unchanged to ``ofdm_simulate``; presumably the SNR
            in dB (confirm against Global_parameters).
        label_bits: Number of leading payload bits used as each sample's
            label. Defaults to 16, the width previously hard-coded here.

    Yields:
        A ``(samples, labels)`` tuple of NumPy arrays. ``samples`` stacks the
        first value returned by ``ofdm_simulate`` along a new leading batch
        axis (assumes every call returns an array of the same shape).
        ``labels`` has shape ``(bs, label_bits)`` provided
        ``payloadBits_per_OFDM >= label_bits``.
    """
    while True:
        # Sample from channel_test (not channel_train) so validation batches
        # use channel realisations disjoint from the training set.
        index = np.random.choice(test_size, size=bs)
        H_total = channel_test[index]
        input_samples = []
        input_labels = []
        for H in H_total:
            bits = np.random.binomial(n=1, p=0.5, size=(payloadBits_per_OFDM,))
            # The second value returned by ofdm_simulate is not needed here.
            signal_output, _ = ofdm_simulate(bits, H, SNRdb)
            input_labels.append(bits[:label_bits])
            input_samples.append(signal_output)
        yield (np.asarray(input_samples), np.asarray(input_labels))