Skip to content

Commit

Permalink
Explicitly copy real part in constant R2C transforms to avoid complex…
Browse files Browse the repository at this point in the history
…warning
  • Loading branch information
kburns committed Jun 6, 2020
1 parent 449f7c3 commit 2e9286b
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions dedalus/core/distributor.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,22 +269,28 @@ class Transform:
# To Do: group transforms for multiple fields

def __init__(self, layout0, layout1, axis, basis):

self.layout0 = layout0
self.layout1 = layout1
self.axis = axis
self.basis = basis
if (np.issubdtype(layout0.dtype, np.complexfloating) and
np.issubdtype(layout1.dtype, np.floating)):
self.copyto = self._copyto_real_part
else:
self.copyto = np.copyto

@staticmethod
def _copyto_real_part(output, input):
np.copyto(output, np.real(input))

@CachedMethod
def group_data(self, nfields, scales):

local_shape0 = self.layout0.local_shape(scales)
local_shape1 = self.layout1.local_shape(scales)
group_shape0 = [nfields] + list(local_shape0)
group_shape1 = [nfields] + list(local_shape1)
group_cdata = fftw.create_array(group_shape0, self.layout0.dtype)
group_gdata = fftw.create_array(group_shape1, self.layout1.dtype)

return group_cdata, group_gdata

def increment_group(self, fields):
Expand Down Expand Up @@ -331,7 +337,7 @@ def increment_single(self, field):
if cdata.size:
# Shortcut constant transforms
if field.meta[self.axis]['constant']:
gdata[:] = cdata[axslice(self.axis, 0, 1)]
self.copyto(gdata, cdata[axslice(self.axis, 0, 1)])
else:
self.basis.backward(cdata, gdata, self.axis, field.meta[self.axis], field.scales[self.axis])

Expand Down

0 comments on commit 2e9286b

Please sign in to comment.