Skip to content

Commit

Permalink
Fix torch.Tensor copy construction warning
Browse files Browse the repository at this point in the history
  • Loading branch information
olemke committed Nov 14, 2024
1 parent b6be86f commit 856c579
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions typhon/retrieval/qrnn/models/pytorch/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def load_model(f, quantiles):
Returns:
The loaded pytorch model.
"""
model = torch.load(f)
model = torch.load(f, weights_only=False)
return model


Expand Down Expand Up @@ -92,8 +92,8 @@ class BatchedDataset(Dataset):

def __init__(self, training_data, batch_size):
x, y = training_data
self.x = torch.tensor(x, dtype=torch.float)
self.y = torch.tensor(y, dtype=torch.float)
self.x = x if isinstance(x, torch.Tensor) else torch.tensor(x, dtype=torch.float)
self.y = y if isinstance(y, torch.Tensor) else torch.tensor(y, dtype=torch.float)
self.batch_size = batch_size

def __len__(self):
Expand Down

0 comments on commit 856c579

Please sign in to comment.