Skip to content

Commit

Permalink
Fix cyclic prefix length of first symbol for 5g NR PUSCH
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Schäufele <[email protected]>
  • Loading branch information
danielschaeufele committed Jun 7, 2024
1 parent 2cb12fd commit 52d2b41
Show file tree
Hide file tree
Showing 7 changed files with 287 additions and 39 deletions.
20 changes: 16 additions & 4 deletions sionna/nr/carrier_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,15 +244,27 @@ def kappa(self):
@property
def cyclic_prefix_length(self):
r"""
float, read-only : Cyclic prefix length
float, read-only : Cyclic prefix length of all symbols except for the
first symbol in each half subframe
:math:`N_{\text{CP},l}^{\mu} \cdot T_{\text{c}}` [s]
"""
if self.cyclic_prefix=="extended":
cp = 512*self.kappa*2**(-self.mu)
cp = 512*self.kappa*2**(-self.mu)
else:
cp = 144*self.kappa*2**(-self.mu)
if self.slot_number in [0, 7*2**self.mu]:
cp += 16*self.kappa
return cp*self.t_c

@property
def cyclic_prefix_length_first_symbol(self):
r"""
float, read-only : Cyclic prefix length of first symbol in each
half subframe
:math:`N_{\text{CP},l}^{\mu} \cdot T_{\text{c}}` [s]
"""
if self.cyclic_prefix=="extended":
cp = 512*self.kappa*2**(-self.mu)
else:
cp = 144*self.kappa*2**(-self.mu) + 16*self.kappa
return cp*self.t_c

#-------------------#
Expand Down
7 changes: 5 additions & 2 deletions sionna/nr/pusch_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,8 +1050,11 @@ def check_pusch_configs(pusch_configs):
"num_cdm_groups_without_data" : pc.dmrs.num_cdm_groups_without_data
}
params["bandwidth"] = params["num_subcarriers"]*params["subcarrier_spacing"]
params["cyclic_prefix_length"] = np.ceil(carrier.cyclic_prefix_length *
params["bandwidth"])
params["cyclic_prefix_length"] = int(np.ceil(carrier.cyclic_prefix_length *
params["bandwidth"]))
params["cyclic_prefix_length_first_symbol"] =\
int(np.ceil(carrier.cyclic_prefix_length_first_symbol
* params["bandwidth"]))

for pusch_config in pusch_configs:
if params["precoding"]=="codebook":
Expand Down
8 changes: 7 additions & 1 deletion sionna/nr/pusch_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,16 @@ def __init__(self,
assert l_min is not None, \
"l_min must be provided for input_domain==time"
self._l_min = l_min
symbols_per_block = (
pusch_transmitter._carrier_config.num_slots_per_subframe *
pusch_transmitter._carrier_config.num_symbols_per_slot // 2)
self._ofdm_demodulator = OFDMDemodulator(
fft_size=pusch_transmitter._num_subcarriers,
l_min=self._l_min,
cyclic_prefix_length=pusch_transmitter._cyclic_prefix_length)
cyclic_prefix_length=pusch_transmitter._cyclic_prefix_length,
cyclic_prefix_length_first_symbol=
pusch_transmitter._cyclic_prefix_length_first_symbol,
symbols_per_block=symbols_per_block)

# Use or create default ChannelEstimator
self._perfect_csi = False
Expand Down
8 changes: 7 additions & 1 deletion sionna/nr/pusch_transmitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,13 @@ def __init__(self,

# (Optionally) Create OFDMModulator
if self._output_domain=="time":
self._ofdm_modulator = OFDMModulator(self._cyclic_prefix_length)
symbols_per_block = (self._carrier_config.num_slots_per_subframe *
self._carrier_config.num_symbols_per_slot // 2)
self._ofdm_modulator = OFDMModulator(
cyclic_prefix_length=self._cyclic_prefix_length,
cyclic_prefix_length_first_symbol=
self._cyclic_prefix_length_first_symbol,
symbols_per_block=symbols_per_block)

#########################################
# Public methods and properties
Expand Down
122 changes: 103 additions & 19 deletions sionna/ofdm/demodulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,20 @@
class OFDMDemodulator(Layer):
# pylint: disable=line-too-long
r"""
OFDMDemodulator(fft_size, l_min, cyclic_prefix_length, **kwargs)
OFDMDemodulator(fft_size, l_min, cyclic_prefix_length=0, cyclic_prefix_length_first_symbol=None, symbols_per_block=1, **kwargs)
Computes the frequency-domain representation of an OFDM waveform
with cyclic prefix removal.
When only `cyclic_prefix_length` is given then a cyclic prefix of length
`cyclic_prefix_length` is removed from each symbol. When additionally
`cyclic_prefix_length_first_symbol` and `symbols_per_block` are given then
the length of the cyclic prefix is `cyclic_prefix_length_first_symbol` for
the first symbol of each block and `cyclic_prefix_length` for the
remaining symbols. For LTE one block corresponds to one slot (i.e., 7
symbols). For 5G NR one block corresponds to one half subframe and the
number of symbols depends on the numerology.
The demodulator assumes that the input sequence is generated by the
:class:`~sionna.channel.TimeChannel`. For a single pair of antennas,
the received signal sequence is given as:
Expand Down Expand Up @@ -49,7 +58,7 @@ class OFDMDemodulator(Layer):
each subcarrier by :math:`e^{\frac{-j2\pi k L_\text{min}}{N}}`.
This is a very important step to enable channel estimation with
sparse pilot patterns that needs to interpolate the channel frequency
response accross subcarriers. It also ensures that the
response across subcarriers. It also ensures that the
channel frequency response `seen` by the time-domain channel
is close to the :class:`~sionna.channel.OFDMChannel`.
Expand All @@ -64,8 +73,16 @@ class OFDMDemodulator(Layer):
`cir_to_time_channel` function.
cyclic_prefix_length : int
Integer indicating the length of the cyclic prefix that
is prepended to each OFDM symbol.
Integer indicating the length of the cyclic prefix that it prepended
to each OFDM symbol (except for the first symbol of each block if
`cyclic_prefix_length_first_symbol` and `symbols per block` is given).
cyclic_prefix_length_first_symbol : int
Integer indicating the length of the cyclic prefix that it prepended
to the first OFDM symbol of each block.
symbols_per_block : int
Integer indicating the number of symbols per block.
Input
-----
Expand All @@ -80,19 +97,25 @@ class OFDMDemodulator(Layer):
two dimension.
"""

def __init__(self, fft_size, l_min, cyclic_prefix_length=0, **kwargs):
def __init__(self, fft_size, l_min, cyclic_prefix_length=0,
cyclic_prefix_length_first_symbol=None, symbols_per_block=1,
**kwargs):
super().__init__(**kwargs)
self.fft_size = fft_size
self.l_min = l_min
self.cyclic_prefix_length = cyclic_prefix_length
self.cyclic_prefix_length_first_symbol =(
cyclic_prefix_length_first_symbol)
self.symbols_per_block = symbols_per_block

@property
def fft_size(self):
return self._fft_size

@fft_size.setter
def fft_size(self, value):
assert value>0, "`fft_size` must be positive."
assert isinstance(value, int) and value>0,\
"`fft_size` must be a positive integer."
self._fft_size = int(value)

@property
Expand All @@ -110,23 +133,61 @@ def cyclic_prefix_length(self):

@cyclic_prefix_length.setter
def cyclic_prefix_length(self, value):
assert value >=0, "`cyclic_prefix_length` must be nonnegative."
assert isinstance(value, int) and value >=0,\
"`cyclic_prefix_length` must be a nonnegative integer."
self._cyclic_prefix_length = int(value)

def build(self, input_shape): # pylint: disable=unused-argument
@property
def cyclic_prefix_length_first_symbol(self):
if self._cyclic_prefix_length_first_symbol is None:
return self._cyclic_prefix_length
else:
return self._cyclic_prefix_length_first_symbol

@cyclic_prefix_length_first_symbol.setter
def cyclic_prefix_length_first_symbol(self, value):
assert (value is None or isinstance(value, int) and
value >= self._cyclic_prefix_length),\
("`cyclic_prefix_length_first_symbol` must be integer and " +
"larger or equal to `cyclic_prefix_length`.")
self._cyclic_prefix_length_first_symbol = value

@property
def symbols_per_block(self):
return self._symbols_per_block

@symbols_per_block.setter
def symbols_per_block(self, value):
assert isinstance(value, int) and value >= 1,\
"`symbols_per_block` must be a positive integer."
self._symbols_per_block = value

def build(self, input_shape):
num_samples = input_shape[-1]

tmp = -2 * PI * tf.cast(self.l_min, tf.float32) \
/ tf.cast(self.fft_size, tf.float32) \
* tf.range(self.fft_size, dtype=tf.float32)
self._phase_compensation = tf.exp(tf.complex(0., tmp))

# Compute number of elements that will be truncated
self._rest = np.mod(input_shape[-1],
self.fft_size + self.cyclic_prefix_length)

# Compute number of full OFDM symbols to be demodulated
self._num_ofdm_symbols = np.floor_divide(
input_shape[-1]-self._rest,
self.fft_size + self.cyclic_prefix_length)
self._samples_per_block = (self.cyclic_prefix_length_first_symbol +
(self.symbols_per_block - 1) * self.cyclic_prefix_length +
self.symbols_per_block * self.fft_size)

# Compute number of elements that will be truncated and number of
# symbols for padding
self._rest = num_samples % self._samples_per_block
samples_first_symbol = (self.cyclic_prefix_length_first_symbol +
self.fft_size)
samples_other_symbols = (self.cyclic_prefix_length + self.fft_size)
if self._rest > samples_first_symbol:
self._rest -= samples_first_symbol
excess_symbols = self._rest // samples_other_symbols
self._rest -= excess_symbols * samples_other_symbols
excess_symbols += 1 # Because of first symbol in block
self._num_pad_symbols = self.symbols_per_block - excess_symbols
else:
self._num_pad_symbols = 0

def call(self, inputs):
"""Demodulate OFDM waveform onto a resource grid.
Expand All @@ -139,14 +200,37 @@ def call(self, inputs):
`tf.complex64` : The demodulated inputs of shape
`[...,num_ofdm_symbols, fft_size]`.
"""
batch_dims = tf.shape(inputs)[:-1]

# Cut last samples that do not fit into an OFDM symbol
inputs = inputs if self._rest==0 else inputs[...,:-self._rest]
x = inputs if self._rest == 0 else inputs[..., :-self._rest]

if self._num_pad_symbols > 0:
pad_samples = self._num_pad_symbols * (self.fft_size +
self.cyclic_prefix_length)
padding_shape = tf.concat([batch_dims, [pad_samples]], axis=0)
padding = tf.zeros(padding_shape, dtype=x.dtype)
x = tf.concat([x, padding], axis=-1)

# Reshape input to blocks
num_blocks = tf.shape(x)[-1] // self._samples_per_block
new_shape = tf.concat([batch_dims,
[num_blocks, self._samples_per_block]], 0)
x = tf.reshape(x, new_shape)

# Remove extra cyclic prefix from first symbol
x = x[...,(self.cyclic_prefix_length_first_symbol -
self.cyclic_prefix_length):]

# Reshape input to separate OFDM symbols
new_shape = tf.concat([tf.shape(inputs)[:-1], [self._num_ofdm_symbols],
new_shape = tf.concat([batch_dims,
[num_blocks * self.symbols_per_block],
[self.fft_size + self.cyclic_prefix_length]], 0)
x = tf.reshape(inputs, new_shape)
x = tf.reshape(x, new_shape)

# Remove padding
if self._num_pad_symbols > 0:
x = x[..., :-self._num_pad_symbols, :]

# Remove cyclic prefix
x = x[...,self.cyclic_prefix_length:]
Expand Down
Loading

0 comments on commit 52d2b41

Please sign in to comment.