Skip to content

Commit

Permalink
minor: second implementation to BlendingContinuous
Browse files Browse the repository at this point in the history
A second implementation is provided for the matvec/rmatvec
routines of BlendingContinuous, where shifting is performed
simultaneously for all sources. This implementation is more
favourable when dealing with small number of receivers.
  • Loading branch information
mrava87 committed Sep 10, 2023
1 parent 63e7c48 commit 01e2c45
Showing 1 changed file with 82 additions and 27 deletions.
109 changes: 82 additions & 27 deletions pylops/waveeqprocessing/blending.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ class BlendingContinuous(LinearOperator):
Time sampling in seconds
times : :obj:`np.ndarray`
Absolute ignition times for each source
nproc : :obj:`int`, optional
Number of processors used when applying operator
shiftall : :obj:`bool`, optional
Shift all shots together (``True``) or one at the time (``False``). Defaults to ``shiftall=False`` (original
implementation), however ``shiftall=True`` should be preferred when ``nr`` is small.
dtype : :obj:`str`, optional
Operator dtype
name : :obj:`str`, optional
Expand Down Expand Up @@ -64,6 +65,7 @@ def __init__(
ns: int,
dt: float,
times: NDArray,
shiftall: bool = False,
dtype: DTypeLike = "float64",
name: str = "B",
) -> None:
Expand All @@ -73,40 +75,85 @@ def __init__(
self.ns = ns
self.dt = dt
self.times = times
self.shiftall = shiftall
self.nttot = int(np.max(self.times) / self.dt + self.nt + 1)
self.PadOp = Pad((self.nr, self.nt), ((0, 0), (0, 1)), dtype=self.dtype)
# Define shift operators
self.shifts = []
self.ShiftOps = []
for i in range(self.ns):
shift = self.times[i]
# This is the part that fits on the grid
shift_int = int(shift // self.dt)
self.shifts.append(shift_int)
# This is the fractional part
diff = (shift / self.dt - shift_int) * self.dt
if diff == 0:
self.ShiftOps.append(None)
else:
self.ShiftOps.append(
Shift(
(self.nr, self.nt + 1),
diff,
axis=1,
sampling=self.dt,
real=False,
dtype=self.dtype,
if not self.shiftall:
# original implementation, where each source is shifted indipendently
self.PadOp = Pad((self.nr, self.nt), ((0, 0), (0, 1)), dtype=self.dtype)
# Define shift operators
self.shifts = []
self.ShiftOps = []
for i in range(self.ns):
shift = self.times[i]
# This is the part that fits on the grid
shift_int = int(shift // self.dt)
self.shifts.append(shift_int)
# This is the fractional part
diff = (shift / self.dt - shift_int) * self.dt
if diff == 0:
self.ShiftOps.append(None)
else:
self.ShiftOps.append(
Shift(
(self.nr, self.nt + 1),
diff,
axis=1,
sampling=self.dt,
real=True,
dtype=self.dtype,
)
)
)
else:
# alternative implementation, where all sources are shifted at the same time
self.PadOp = Pad(
(self.ns, self.nr, self.nt), ((0, 0), (0, 0), (0, 1)), dtype=self.dtype
)
# Define shift operator
self.shifts = (times // self.dt).astype(np.int32)
diff = (times / self.dt - self.shifts) * self.dt
diff = np.repeat(diff[:, np.newaxis], self.nr, axis=1)
self.ShiftOp = Shift(
(self.ns, self.nr, self.nt + 1),
diff,
axis=-1,
sampling=self.dt,
real=True,
dtype=self.dtype,
)
self.diff = diff

super().__init__(
dtype=np.dtype(dtype),
dims=(self.ns, self.nr, self.nt),
dimsd=(self.nr, self.nttot),
name=name,
)
self._register_multiplications()

@reshaped
def _matvec_smallrecs(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
blended_data = ncp.zeros((self.nr, self.nttot), dtype=self.dtype)
shifted_data = self.ShiftOp._matvec(self.PadOp._matvec(x.ravel())).reshape(
self.ns, self.nr, self.nt + 1
)
for i, shift_int in enumerate(self.shifts):
blended_data[:, shift_int : shift_int + self.nt + 1] += shifted_data[i]
return blended_data

@reshaped
def _rmatvec_smallrecs(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
shifted_data = ncp.zeros((self.ns, self.nr, self.nt + 1), dtype=self.dtype)
for i, shift_int in enumerate(self.shifts):
shifted_data[i, :, :] = x[:, shift_int : shift_int + self.nt + 1]
deblended_data = self.PadOp._rmatvec(
self.ShiftOp._rmatvec(shifted_data.ravel())
).reshape(self.dims)
return deblended_data

@reshaped
def _matvec(self, x: NDArray) -> NDArray:
def _matvec_largerecs(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
blended_data = ncp.zeros((self.nr, self.nttot), dtype=self.dtype)
for i, shift_int in enumerate(self.shifts):
Expand All @@ -118,7 +165,7 @@ def _matvec(self, x: NDArray) -> NDArray:
return blended_data

@reshaped
def _rmatvec(self, x: NDArray) -> NDArray:
def _rmatvec_largerecs(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
deblended_data = ncp.zeros((self.ns, self.nr, self.nt), dtype=self.dtype)
for i, shift_int in enumerate(self.shifts):
Expand All @@ -133,6 +180,14 @@ def _rmatvec(self, x: NDArray) -> NDArray:
deblended_data[i, :, :] = shifted_data
return deblended_data

def _register_multiplications(self) -> None:
if self.shiftall:
self._matvec = self._matvec_smallrecs
self._rmatvec = self._rmatvec_smallrecs
else:
self._matvec = self._matvec_largerecs
self._rmatvec = self._rmatvec_largerecs


def BlendingGroup(
nt: int,
Expand Down

0 comments on commit 01e2c45

Please sign in to comment.