diff --git a/sionna/nr/carrier_config.py b/sionna/nr/carrier_config.py index 1fb37647..441fdbf8 100644 --- a/sionna/nr/carrier_config.py +++ b/sionna/nr/carrier_config.py @@ -244,15 +244,27 @@ def kappa(self): @property def cyclic_prefix_length(self): r""" - float, read-only : Cyclic prefix length + float, read-only : Cyclic prefix length of all symbols except for the + first symbol in each half subframe :math:`N_{\text{CP},l}^{\mu} \cdot T_{\text{c}}` [s] """ if self.cyclic_prefix=="extended": - cp = 512*self.kappa*2**(-self.mu) + cp = 512*self.kappa*2**(-self.mu) else: cp = 144*self.kappa*2**(-self.mu) - if self.slot_number in [0, 7*2**self.mu]: - cp += 16*self.kappa + return cp*self.t_c + + @property + def cyclic_prefix_length_first_symbol(self): + r""" + float, read-only : Cyclic prefix length of first symbol in each + half subframe + :math:`N_{\text{CP},l}^{\mu} \cdot T_{\text{c}}` [s] + """ + if self.cyclic_prefix=="extended": + cp = 512*self.kappa*2**(-self.mu) + else: + cp = 144*self.kappa*2**(-self.mu) + 16*self.kappa return cp*self.t_c #-------------------# diff --git a/sionna/nr/pusch_config.py b/sionna/nr/pusch_config.py index c413dbf6..b40d2077 100644 --- a/sionna/nr/pusch_config.py +++ b/sionna/nr/pusch_config.py @@ -1050,8 +1050,11 @@ def check_pusch_configs(pusch_configs): "num_cdm_groups_without_data" : pc.dmrs.num_cdm_groups_without_data } params["bandwidth"] = params["num_subcarriers"]*params["subcarrier_spacing"] - params["cyclic_prefix_length"] = np.ceil(carrier.cyclic_prefix_length * - params["bandwidth"]) + params["cyclic_prefix_length"] = int(np.ceil(carrier.cyclic_prefix_length * + params["bandwidth"])) + params["cyclic_prefix_length_first_symbol"] =\ + int(np.ceil(carrier.cyclic_prefix_length_first_symbol + * params["bandwidth"])) for pusch_config in pusch_configs: if params["precoding"]=="codebook": diff --git a/sionna/nr/pusch_receiver.py b/sionna/nr/pusch_receiver.py index 996e1382..e0a1c99d 100644 --- a/sionna/nr/pusch_receiver.py +++ b/sionna/nr/pusch_receiver.py @@ -156,10 +156,16 @@ def __init__(self, assert l_min is not None, \ "l_min must be provided for input_domain==time" self._l_min = l_min + symbols_per_block = ( + pusch_transmitter._carrier_config.num_slots_per_subframe * + pusch_transmitter._carrier_config.num_symbols_per_slot // 2) self._ofdm_demodulator = OFDMDemodulator( fft_size=pusch_transmitter._num_subcarriers, l_min=self._l_min, - cyclic_prefix_length=pusch_transmitter._cyclic_prefix_length) + cyclic_prefix_length=pusch_transmitter._cyclic_prefix_length, + cyclic_prefix_length_first_symbol= + pusch_transmitter._cyclic_prefix_length_first_symbol, + symbols_per_block=symbols_per_block) # Use or create default ChannelEstimator self._perfect_csi = False diff --git a/sionna/nr/pusch_transmitter.py b/sionna/nr/pusch_transmitter.py index 7a4d62b8..e9b01c9e 100644 --- a/sionna/nr/pusch_transmitter.py +++ b/sionna/nr/pusch_transmitter.py @@ -183,7 +183,13 @@ def __init__(self, # (Optionally) Create OFDMModulator if self._output_domain=="time": - self._ofdm_modulator = OFDMModulator(self._cyclic_prefix_length) + symbols_per_block = (self._carrier_config.num_slots_per_subframe * + self._carrier_config.num_symbols_per_slot // 2) + self._ofdm_modulator = OFDMModulator( + cyclic_prefix_length=self._cyclic_prefix_length, + cyclic_prefix_length_first_symbol= + self._cyclic_prefix_length_first_symbol, + symbols_per_block=symbols_per_block) ######################################### # Public methods and properties diff --git a/sionna/ofdm/demodulator.py b/sionna/ofdm/demodulator.py index 925074e3..267946ef 100644 --- a/sionna/ofdm/demodulator.py +++ b/sionna/ofdm/demodulator.py @@ -15,11 +15,20 @@ class OFDMDemodulator(Layer): # pylint: disable=line-too-long r""" - OFDMDemodulator(fft_size, l_min, cyclic_prefix_length, **kwargs) + OFDMDemodulator(fft_size, l_min, cyclic_prefix_length=0, cyclic_prefix_length_first_symbol=None, symbols_per_block=1, **kwargs) Computes the frequency-domain representation of an OFDM waveform with cyclic prefix removal. + When only `cyclic_prefix_length` is given then a cyclic prefix of length + `cyclic_prefix_length` is removed from each symbol. When additionally + `cyclic_prefix_length_first_symbol` and `symbols_per_block` are given then + the length of the cyclic prefix is `cyclic_prefix_length_first_symbol` for + the first symbol of each block and `cyclic_prefix_length` for the + remaining symbols. For LTE one block corresponds to one slot (i.e., 7 + symbols). For 5G NR one block corresponds to one half subframe and the + number of symbols depends on the numerology. + The demodulator assumes that the input sequence is generated by the :class:`~sionna.channel.TimeChannel`. For a single pair of antennas, the received signal sequence is given as: @@ -49,7 +58,7 @@ class OFDMDemodulator(Layer): each subcarrier by :math:`e^{\frac{-j2\pi k L_\text{min}}{N}}`. This is a very important step to enable channel estimation with sparse pilot patterns that needs to interpolate the channel frequency - response accross subcarriers. It also ensures that the + response across subcarriers. It also ensures that the channel frequency response `seen` by the time-domain channel is close to the :class:`~sionna.channel.OFDMChannel`. @@ -64,8 +73,16 @@ class OFDMDemodulator(Layer): `cir_to_time_channel` function. cyclic_prefix_length : int - Integer indicating the length of the cyclic prefix that - is prepended to each OFDM symbol. + Integer indicating the length of the cyclic prefix that it prepended + to each OFDM symbol (except for the first symbol of each block if + `cyclic_prefix_length_first_symbol` and `symbols per block` is given). + + cyclic_prefix_length_first_symbol : int + Integer indicating the length of the cyclic prefix that it prepended + to the first OFDM symbol of each block. + + symbols_per_block : int + Integer indicating the number of symbols per block. Input ----- @@ -80,11 +97,16 @@ class OFDMDemodulator(Layer): two dimension. """ - def __init__(self, fft_size, l_min, cyclic_prefix_length=0, **kwargs): + def __init__(self, fft_size, l_min, cyclic_prefix_length=0, + cyclic_prefix_length_first_symbol=None, symbols_per_block=1, + **kwargs): super().__init__(**kwargs) self.fft_size = fft_size self.l_min = l_min self.cyclic_prefix_length = cyclic_prefix_length + self.cyclic_prefix_length_first_symbol =( + cyclic_prefix_length_first_symbol) + self.symbols_per_block = symbols_per_block @property def fft_size(self): @@ -92,7 +114,8 @@ def fft_size(self): @fft_size.setter def fft_size(self, value): - assert value>0, "`fft_size` must be positive." + assert isinstance(value, int) and value>0,\ + "`fft_size` must be a positive integer." self._fft_size = int(value) @property @@ -110,23 +133,61 @@ def cyclic_prefix_length(self): @cyclic_prefix_length.setter def cyclic_prefix_length(self, value): - assert value >=0, "`cyclic_prefix_length` must be nonnegative." + assert isinstance(value, int) and value >=0,\ + "`cyclic_prefix_length` must be a nonnegative integer." self._cyclic_prefix_length = int(value) - def build(self, input_shape): # pylint: disable=unused-argument + @property + def cyclic_prefix_length_first_symbol(self): + if self._cyclic_prefix_length_first_symbol is None: + return self._cyclic_prefix_length + else: + return self._cyclic_prefix_length_first_symbol + + @cyclic_prefix_length_first_symbol.setter + def cyclic_prefix_length_first_symbol(self, value): + assert (value is None or isinstance(value, int) and + value >= self._cyclic_prefix_length),\ + ("`cyclic_prefix_length_first_symbol` must be integer and " + + "larger or equal to `cyclic_prefix_length`.") + self._cyclic_prefix_length_first_symbol = value + + @property + def symbols_per_block(self): + return self._symbols_per_block + + @symbols_per_block.setter + def symbols_per_block(self, value): + assert isinstance(value, int) and value >= 1,\ + "`symbols_per_block` must be a positive integer." + self._symbols_per_block = value + + def build(self, input_shape): + num_samples = input_shape[-1] + tmp = -2 * PI * tf.cast(self.l_min, tf.float32) \ / tf.cast(self.fft_size, tf.float32) \ * tf.range(self.fft_size, dtype=tf.float32) self._phase_compensation = tf.exp(tf.complex(0., tmp)) - # Compute number of elements that will be truncated - self._rest = np.mod(input_shape[-1], - self.fft_size + self.cyclic_prefix_length) - - # Compute number of full OFDM symbols to be demodulated - self._num_ofdm_symbols = np.floor_divide( - input_shape[-1]-self._rest, - self.fft_size + self.cyclic_prefix_length) + self._samples_per_block = (self.cyclic_prefix_length_first_symbol + + (self.symbols_per_block - 1) * self.cyclic_prefix_length + + self.symbols_per_block * self.fft_size) + + # Compute number of elements that will be truncated and number of + # symbols for padding + self._rest = num_samples % self._samples_per_block + samples_first_symbol = (self.cyclic_prefix_length_first_symbol + + self.fft_size) + samples_other_symbols = (self.cyclic_prefix_length + self.fft_size) + if self._rest > samples_first_symbol: + self._rest -= samples_first_symbol + excess_symbols = self._rest // samples_other_symbols + self._rest -= excess_symbols * samples_other_symbols + excess_symbols += 1 # Because of first symbol in block + self._num_pad_symbols = self.symbols_per_block - excess_symbols + else: + self._num_pad_symbols = 0 def call(self, inputs): """Demodulate OFDM waveform onto a resource grid. @@ -139,14 +200,37 @@ def call(self, inputs): `tf.complex64` : The demodulated inputs of shape `[...,num_ofdm_symbols, fft_size]`. """ + batch_dims = tf.shape(inputs)[:-1] # Cut last samples that do not fit into an OFDM symbol - inputs = inputs if self._rest==0 else inputs[...,:-self._rest] + x = inputs if self._rest == 0 else inputs[..., :-self._rest] + + if self._num_pad_symbols > 0: + pad_samples = self._num_pad_symbols * (self.fft_size + + self.cyclic_prefix_length) + padding_shape = tf.concat([batch_dims, [pad_samples]], axis=0) + padding = tf.zeros(padding_shape, dtype=x.dtype) + x = tf.concat([x, padding], axis=-1) + + # Reshape input to blocks + num_blocks = tf.shape(x)[-1] // self._samples_per_block + new_shape = tf.concat([batch_dims, + [num_blocks, self._samples_per_block]], 0) + x = tf.reshape(x, new_shape) + + # Remove extra cyclic prefix from first symbol + x = x[...,(self.cyclic_prefix_length_first_symbol - + self.cyclic_prefix_length):] # Reshape input to separate OFDM symbols - new_shape = tf.concat([tf.shape(inputs)[:-1], [self._num_ofdm_symbols], + new_shape = tf.concat([batch_dims, + [num_blocks * self.symbols_per_block], [self.fft_size + self.cyclic_prefix_length]], 0) - x = tf.reshape(inputs, new_shape) + x = tf.reshape(x, new_shape) + + # Remove padding + if self._num_pad_symbols > 0: + x = x[..., :-self._num_pad_symbols, :] # Remove cyclic prefix x = x[...,self.cyclic_prefix_length:] diff --git a/sionna/ofdm/modulator.py b/sionna/ofdm/modulator.py index 07d069f3..69d35f66 100644 --- a/sionna/ofdm/modulator.py +++ b/sionna/ofdm/modulator.py @@ -12,18 +12,37 @@ class OFDMModulator(Layer): + # pylint: disable=line-too-long """ - OFDMModulator(cyclic_prefix_length, **kwargs) + OFDMModulator(cyclic_prefix_length=0, cyclic_prefix_length_first_symbol=None, symbols_per_block=1, **kwargs) Computes the time-domain representation of an OFDM resource grid with (optional) cyclic prefix. + When only `cyclic_prefix_length` is given then a cyclic prefix of length + `cyclic_prefix_length` is prepended for each symbol. When additionally + `cyclic_prefix_length_first_symbol` and `symbols_per_block` are given then + the length of the cyclic prefix is `cyclic_prefix_length_first_symbol` for + the first symbol of each block and `cyclic_prefix_length` for the + remaining symbols. For LTE one block corresponds to one slot (i.e., 7 + symbols). For 5G NR one block corresponds to one half subframe and the + number of symbols depends on the numerology. + Parameters ---------- cyclic_prefix_length : int - Integer indicating the length of the - cyclic prefix that it prepended to each OFDM symbol. It cannot - be longer than the FFT size. + Integer indicating the length of the cyclic prefix that it prepended + to each OFDM symbol (except for the first symbol of each block if + `cyclic_prefix_length_first_symbol` and `symbols per block` is given). + It cannot be longer than the FFT size. + + cyclic_prefix_length_first_symbol : int + Integer indicating the length of the cyclic prefix that it prepended + to the first OFDM symbol of each block. It cannot be longer than the + FFT size. + + symbols_per_block : int + Integer indicating the number of symbols per block. Input ----- @@ -36,9 +55,14 @@ class OFDMModulator(Layer): Time-domain OFDM signal. """ - def __init__(self, cyclic_prefix_length=0, **kwargs): + def __init__(self, cyclic_prefix_length=0, + cyclic_prefix_length_first_symbol=None, symbols_per_block=1, + **kwargs): super().__init__(**kwargs) self.cyclic_prefix_length = cyclic_prefix_length + self.cyclic_prefix_length_first_symbol = ( + cyclic_prefix_length_first_symbol) + self.symbols_per_block = symbols_per_block @property def cyclic_prefix_length(self): @@ -46,29 +70,94 @@ def cyclic_prefix_length(self): @cyclic_prefix_length.setter def cyclic_prefix_length(self, value): - assert value >=0, "`cyclic_prefix_length` must be nonnegative." + assert isinstance(value, int) and value >=0,\ + "`cyclic_prefix_length` must be a nonnegative integer." self._cyclic_prefix_length = value + @property + def cyclic_prefix_length_first_symbol(self): + if self._cyclic_prefix_length_first_symbol is None: + return self._cyclic_prefix_length + else: + return self._cyclic_prefix_length_first_symbol + + @cyclic_prefix_length_first_symbol.setter + def cyclic_prefix_length_first_symbol(self, value): + assert (value is None or isinstance(value, int) and + value >= self._cyclic_prefix_length),\ + ("`cyclic_prefix_length_first_symbol` must be larger or equal " + + "to `cyclic_prefix_length`.") + self._cyclic_prefix_length_first_symbol = value + + @property + def symbols_per_block(self): + return self._symbols_per_block + + @symbols_per_block.setter + def symbols_per_block(self, value): + assert isinstance(value, int) and value>=1,\ + "`symbols_per_block` must be a positive integer." + self._symbols_per_block = value + def build(self, input_shape): - # Verify that cyclic prefix is not longer than the FFT size. fft_size = input_shape[-1] + num_ofdm_symbols = input_shape[-2] + + # Verify that cyclic prefix is not longer than the FFT size. assert self.cyclic_prefix_length<=fft_size, \ "shape(inputs)[-1] must not be smaller than `cylic_prefix_length`" + assert self.cyclic_prefix_length_first_symbol <= fft_size, \ + ("shape(inputs)[-1] must not be smaller than " + + " `cylic_prefix_length_first_symbol`") + + # Compute padding size to fill the last block + self._num_pad_symbols = -num_ofdm_symbols % self.symbols_per_block def call(self, inputs): + fft_size = tf.shape(inputs)[-1] + num_ofdm_symbols = tf.shape(inputs)[-2] + batch_dims = tf.shape(inputs)[:-2] + # Shift DC subcarrier to first position inputs = ifftshift(inputs, axes=-1) # Compute IFFT along the last dimension x = ifft(inputs) + # Add padding to fill up last block + if self._num_pad_symbols != 0: + padding_shape = tf.concat([batch_dims, + [self._num_pad_symbols, fft_size]], axis=0) + padding = tf.zeros(padding_shape, dtype=x.dtype) + x = tf.concat([x, padding], axis=-2) + # Obtain cyclic prefix - cp = x[...,tf.shape(inputs)[-1]-self._cyclic_prefix_length:] + cp = x[...,fft_size-self.cyclic_prefix_length:] # Prepend cyclic prefix x = tf.concat([cp, x], -1) + # Reshape to blocks + num_blocks = tf.math.ceil(num_ofdm_symbols / self.symbols_per_block) + samples_per_block = (self.symbols_per_block * + (self.cyclic_prefix_length + fft_size)) + shape = tf.concat([batch_dims, + [num_blocks, samples_per_block]], axis=0) + x = tf.reshape(x, shape) + + # Obtain additional cyclic prefix for first symbol in block + cp = x[...,fft_size+self.cyclic_prefix_length- + self.cyclic_prefix_length_first_symbol:fft_size] + + # Prepend additional cyclic prefix + x = tf.concat([cp, x], -1) + # Serialize last two dimensions x = flatten_last_dims(x, 2) + # Remove padding + if self._num_pad_symbols != 0: + x = x[..., :-self._num_pad_symbols * + (self.cyclic_prefix_length + fft_size)] + return x diff --git a/test/unit/ofdm/test_ofdm.py b/test/unit/ofdm/test_ofdm.py index 0e3f7807..9d050fbc 100644 --- a/test/unit/ofdm/test_ofdm.py +++ b/test/unit/ofdm/test_ofdm.py @@ -50,6 +50,35 @@ def test_cyclic_prefixes(self): with self.assertRaises(AssertionError): x_time = modulator(x) + def test_nonuniform_cyclic_prefix(self): + batch_size = 64 + fft_size = 128 + num_ofdm_symbols = 14 + cp_length_first_symbol = 24 + cp_length_other_symbols = 16 + + for symbols_per_block in [7, 10, 20]: + modulator = OFDMModulator(cp_length_other_symbols, cp_length_first_symbol, + symbols_per_block) + qam_source = QAMSource(4) + x = qam_source([batch_size, num_ofdm_symbols, fft_size]) + x_time = modulator(x) + + x_np = x.numpy() + x_np = np.fft.ifft(np.fft.ifftshift(x_np, axes=-1), norm="ortho") + x_time_np = np.empty(shape=(batch_size, 0)) + for i in range(num_ofdm_symbols): + if i % symbols_per_block == 0: + x_time_np = np.concatenate([x_time_np, + x_np[..., i, -cp_length_first_symbol:], + x_np[..., i, :]], axis=-1) + else: + x_time_np = np.concatenate([x_time_np, + x_np[..., i, -cp_length_other_symbols:], + x_np[..., i, :]], axis=-1) + + np.testing.assert_array_almost_equal(x_time, x_time_np) + def test_higher_dimensions(self): batch_size = [64, 12, 6] fft_size = 72 @@ -76,6 +105,25 @@ def test_cyclic_prefixes(self): x_hat = demodulator(x_time) self.assertTrue(np.max(np.abs(x-x_hat))<1e-6) + def test_nonuniform_cyclic_prefix(self): + batch_size = 64 + fft_size = 128 + num_ofdm_symbols = 14 + cp_length_first_symbol = 24 + cp_length_other_symbols = 16 + + for symbols_per_block in [7, 10, 20]: + modulator = OFDMModulator(cp_length_other_symbols, + cp_length_first_symbol, symbols_per_block) + demodulator = OFDMDemodulator(fft_size, 0, cp_length_other_symbols, + cp_length_first_symbol, symbols_per_block) + qam_source = QAMSource(4) + x = qam_source([batch_size, num_ofdm_symbols, fft_size]) + x_time = modulator(x) + x_hat = demodulator(x_time) + + np.testing.assert_array_almost_equal(x, x_hat) + def test_higher_dimensions(self): batch_size = [64, 12, 6] fft_size = 72 @@ -105,7 +153,7 @@ def test_overlapping_input(self): class TestOFDMModDemod(unittest.TestCase): def test_end_to_end(self): - """E2E test verying that all shapes can be properly inferred (see Issue #7)""" + """E2E test verifying that all shapes can be properly inferred (see Issue #7)""" class E2ESystem(Model): def __init__(self, cp_length, padding): super().__init__() @@ -115,12 +163,12 @@ def __init__(self, cp_length, padding): self.num_ofdm_symbols = 14 self.qam = QAMSource(4) self.mod = OFDMModulator(self.cp_length) - self.demod = OFDMDemodulator(self.fft_size, 0, self.cp_length) + self.demod = OFDMDemodulator(self.fft_size, 0, self.cp_length) @tf.function(jit_compile=True) def call(self, batch_size): x_rg = self.qam([batch_size, 1, 1, self.num_ofdm_symbols, self.fft_size]) - x_time = self.mod(x_rg) + x_time = self.mod(x_rg) pad = tf.zeros_like(x_time)[...,:self.padding] x_time = tf.concat([x_time, pad], axis=-1) x_f = self.demod(x_time) @@ -184,7 +232,7 @@ def func(cp_length, num_tx, num_streams_per_tx): x_rg = rg_mapper(x) x_time = modulator(x_rg) y = demodulator(x_time) - # Stack inputs to ResourceGridDemppaer to simulate the data_dim dimension + # Stack inputs to ResourceGridDemapper to simulate the data_dim dimension y = tf.stack([y,y,y], axis=-1) x_hat = rg_demapper(y) x = tf.stack([x,x,x], axis=-1)