Skip to content

Commit

Permalink
Get field on CPU explicitly for show/write_to_file
Browse files Browse the repository at this point in the history
  • Loading branch information
RemiLehe committed Jul 16, 2024
1 parent 22960d9 commit f3a71b9
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
7 changes: 5 additions & 2 deletions lasy/laser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import numpy as np
from .backend import xp

from axiprop.lib import PropagatorFFT2, PropagatorResampling
from scipy.constants import c

Expand Down Expand Up @@ -342,11 +344,12 @@ def show(self, **kw):
----------
**kw: additional arguments to be passed to matplotlib's imshow command
"""
temporal_field = self.grid.get_temporal_field()
# Get field on CPU
temporal_field = self.grid.get_temporal_field(to_cpu=True)
if self.dim == "rt":
# Show field in the plane y=0, above and below axis, with proper sign for each mode
E = [
xp.concatenate(
np.concatenate(
((-1.0) ** m * temporal_field[m, ::-1], temporal_field[m])
)
for m in self.grid.azimuthal_modes
Expand Down
15 changes: 8 additions & 7 deletions lasy/utils/openpmd_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def write_to_openpmd_file(
Whether the envelope is converted to normalized vector potential
before writing to file.
"""
array = grid.get_temporal_field()
# Get field on CPU
array = grid.get_temporal_field(to_cpu=True)

# Create file
series = io.Series("{}_%05T.{}".format(file_prefix, file_format), io.Access.create)
Expand All @@ -76,7 +77,7 @@ def write_to_openpmd_file(
m.axis_labels = ["t", "r"]

# Store metadata needed to reconstruct the field
m.set_attribute("angularFrequency", 2 * xp.pi * c / wavelength)
m.set_attribute("angularFrequency", 2 * np.pi * c / wavelength)
m.set_attribute("polarization", pol)
if save_as_vector_potential:
m.set_attribute("envelopeField", "normalized_vector_potential")
Expand All @@ -91,20 +92,20 @@ def write_to_openpmd_file(
}

if save_as_vector_potential:
array = field_to_vector_potential(grid, 2 * xp.pi * c / wavelength)
array = field_to_vector_potential(grid, 2 * np.pi * c / wavelength)

# Pick the correct field
if dim == "xyt":
# Switch from x,y,t (internal to lasy) to t,y,x (in openPMD file)
# This is because many PIC codes expect x to be the fastest index
data = xp.transpose(array).copy()
data = np.transpose(array).copy()
elif dim == "rt":
# The representation of modes in openPMD
# (see https://github.com/openPMD/openPMD-standard/blob/latest/STANDARD.md#required-attributes-for-each-mesh-record)
# is different than the representation of modes internal to lasy.
# Thus, there is a non-trivial conversion here
ncomp = 2 * grid.n_azimuthal_modes - 1
data = xp.zeros((ncomp, grid.npoints[0], grid.npoints[1]), dtype=array.dtype)
data = np.zeros((ncomp, grid.npoints[0], grid.npoints[1]), dtype=array.dtype)
data[0, :, :] = array[0, :, :]
for mode in range(1, grid.n_azimuthal_modes):
# cos(m*theta) part of the mode
Expand All @@ -113,12 +114,12 @@ def write_to_openpmd_file(
data[2 * mode, :, :] = -1.0j * array[mode, :, :] + 1.0j * array[-mode, :, :]
# Switch from m,r,t (internal to lasy) to m,t,r (in openPMD file)
# This is because many PIC codes expect r to be the fastest index
data = xp.transpose(data, axes=[0, 2, 1]).copy()
data = np.transpose(data, axes=[0, 2, 1]).copy()

# Define the dataset
dataset = io.Dataset(data.dtype, data.shape)
env = m[io.Mesh_Record_Component.SCALAR]
env.position = xp.zeros(len(dim), dtype=xp.float64)
env.position = np.zeros(len(dim), dtype=np.float64)
env.reset_dataset(dataset)
env.store_chunk(data)

Expand Down

0 comments on commit f3a71b9

Please sign in to comment.