Skip to content

Commit

Permalink
update for mps
Browse files Browse the repository at this point in the history
  • Loading branch information
dsarrut committed Nov 2, 2023
1 parent 4c24757 commit cef82d8
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions gaga_phsp/gaga_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,9 +541,12 @@ def generate_samples2(
m = 0
z_dim = params["z_dim"]
x_dim = params["x_dim"]
rfake = np.empty((0, x_dim - ncond))
device = params["current_gpu_device"]
current_gpu_mode = params["current_gpu_mode"]
rfake_dtype = np.float64
if current_gpu_mode == 'mps':
rfake_dtype = np.float32
rfake = np.empty((0, x_dim - ncond), dtype=rfake_dtype)
while m < n:
if not silence:
print(f"Batch {m}/{n}")
Expand Down Expand Up @@ -579,7 +582,7 @@ def generate_samples2(
fake = G(z)
# put back to cpu to allow concatenation
fake = fake.cpu().data.numpy()
rfake = np.concatenate((rfake, fake), axis=0)
rfake = np.concatenate((rfake, fake), axis=0, dtype=rfake_dtype)

m = m + current_gpu_batch_size

Expand Down

0 comments on commit cef82d8

Please sign in to comment.