Skip to content

Commit

Permalink
Merge branch 'cupy_backend' of github.com:RemiLehe/lasy into cupy_bac…
Browse files Browse the repository at this point in the history
…kend
  • Loading branch information
RemiLehe committed Jul 16, 2024
2 parents e606c25 + 98f2d9f commit e04af51
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions lasy/utils/grid.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
from lasy.backend import xp, use_cupy

from lasy.backend import use_cupy, xp

time_axis_indx = -1

Expand Down Expand Up @@ -79,8 +80,8 @@ def set_temporal_field(self, field):
assert field.shape == self.temporal_field.shape
assert field.dtype == "complex128"
if use_cupy and type(field) == np.ndarray:
field = xp.asarray(field) # Copy to GPU
self.temporal_field[:,:,:] = field
field = xp.asarray(field) # Copy to GPU
self.temporal_field[:, :, :] = field
self.temporal_field_valid = True
self.spectral_field_valid = False # Invalidates the spectral field

Expand All @@ -96,8 +97,8 @@ def set_spectral_field(self, field):
assert field.shape == self.spectral_field.shape
assert field.dtype == "complex128"
if use_cupy and type(field) == np.ndarray:
field = xp.asarray(field) # Copy to GPU
self.spectral_field[:,:,:] = field
field = xp.asarray(field) # Copy to GPU
self.spectral_field[:, :, :] = field
self.spectral_field_valid = True
self.temporal_field_valid = False # Invalidates the temporal field

Expand Down

0 comments on commit e04af51

Please sign in to comment.