GPU device placement issues #610

Merged · 2 commits · Feb 1, 2022
5 changes: 3 additions & 2 deletions sbi/inference/base.py
@@ -24,7 +24,7 @@
     warn_on_invalid_x_for_snpec_leakage,
 )
 from sbi.utils.sbiutils import get_simulations_since_round
-from sbi.utils.torchutils import process_device
+from sbi.utils.torchutils import process_device, check_if_prior_on_device
 from sbi.utils.user_input_checks import prepare_for_sbi

@@ -117,8 +117,9 @@ def __init__(
             0.14.0 is more mature, we will remove this argument.
         """

-        self._prior = prior
         self._device = process_device(device)
+        check_if_prior_on_device(self._device, prior)
+        self._prior = prior

         if unused_args:
             warn(
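To make the `base.py` change concrete, here is a minimal sketch of the intended usage, assuming a CUDA-capable machine; the `SNPE` entry point is just one example of an inference class built on this `__init__`:

```python
import torch
from torch.distributions import MultivariateNormal

from sbi.inference import SNPE

# A prior whose samples live on the GPU: after this PR, its device must
# match the training device resolved by `process_device`.
prior = MultivariateNormal(
    loc=torch.zeros(2, device="cuda"),
    covariance_matrix=torch.eye(2, device="cuda"),
)

# `process_device` normalizes "cuda" to an explicit index such as "cuda:0",
# and `check_if_prior_on_device` asserts the prior samples on that device.
inference = SNPE(prior=prior, device="cuda")
```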
13 changes: 13 additions & 0 deletions sbi/neural_nets/classifier.py
@@ -8,6 +8,7 @@
 from torch import Tensor, nn, relu

 from sbi.utils.sbiutils import standardizing_net, z_score_parser
+from sbi.utils.user_input_checks import check_embedding_net_device, check_data_device


 class StandardizeInputs(nn.Module):
@@ -114,6 +115,10 @@ def build_linear_classifier(
         Neural network.
     """

+    check_data_device(batch_x, batch_y)
+    check_embedding_net_device(embedding_net=embedding_net_x, datum=batch_x)
+    check_embedding_net_device(embedding_net=embedding_net_y, datum=batch_y)
+
     # Infer the output dimensionalities of the embedding_net by making a forward pass.
     x_numel = embedding_net_x(batch_x[:1]).numel()
     y_numel = embedding_net_y(batch_y[:1]).numel()
@@ -161,6 +166,10 @@ def build_mlp_classifier(
         Neural network.
     """

+    check_data_device(batch_x, batch_y)
+    check_embedding_net_device(embedding_net=embedding_net_x, datum=batch_x)
+    check_embedding_net_device(embedding_net=embedding_net_y, datum=batch_y)
+
     # Infer the output dimensionalities of the embedding_net by making a forward pass.
     x_numel = embedding_net_x(batch_x[:1]).numel()
     y_numel = embedding_net_y(batch_y[:1]).numel()
@@ -216,6 +225,10 @@ def build_resnet_classifier(
         Neural network.
     """

+    check_data_device(batch_x, batch_y)
+    check_embedding_net_device(embedding_net=embedding_net_x, datum=batch_x)
+    check_embedding_net_device(embedding_net=embedding_net_y, datum=batch_y)
+
     # Infer the output dimensionalities of the embedding_net by making a forward pass.
     x_numel = embedding_net_x(batch_x[:1]).numel()
     y_numel = embedding_net_y(batch_y[:1]).numel()
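A short sketch of the failure mode these three calls catch, assuming a CUDA device is available (`check_data_device` is the new helper added in `user_input_checks.py` below):

```python
import torch

from sbi.utils.user_input_checks import check_data_device

batch_x = torch.zeros(10, 2, device="cuda")
batch_y = torch.zeros(10, 1)  # still on the CPU

# Raises an AssertionError: the two batches sit on different devices, so a
# forward pass through the classifier would fail later with a less
# informative error.
check_data_device(batch_x, batch_y)
```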
7 changes: 7 additions & 0 deletions sbi/neural_nets/flow.py
@@ -16,6 +16,7 @@
     z_score_parser,
 )
 from sbi.utils.torchutils import create_alternating_binary_mask
+from sbi.utils.user_input_checks import check_embedding_net_device, check_data_device


 def build_made(
@@ -52,6 +53,8 @@ def build_made(
     """
     x_numel = batch_x[0].numel()
     # Infer the output dimensionality of the embedding_net by making a forward pass.
+    check_data_device(batch_x, batch_y)
+    check_embedding_net_device(embedding_net=embedding_net, datum=batch_y)
     y_numel = embedding_net(batch_y[:1]).numel()

     if x_numel == 1:
@@ -123,6 +126,8 @@ def build_maf(
     """
     x_numel = batch_x[0].numel()
     # Infer the output dimensionality of the embedding_net by making a forward pass.
+    check_data_device(batch_x, batch_y)
+    check_embedding_net_device(embedding_net=embedding_net, datum=batch_y)
     y_numel = embedding_net(batch_y[:1]).numel()

     if x_numel == 1:
@@ -205,6 +210,8 @@ def build_nsf(
     """
     x_numel = batch_x[0].numel()
     # Infer the output dimensionality of the embedding_net by making a forward pass.
+    check_data_device(batch_x, batch_y)
+    check_embedding_net_device(embedding_net=embedding_net, datum=batch_y)
     y_numel = embedding_net(batch_y[:1]).numel()

     # If x is just a scalar then use a dummy mask and learn spline parameters using the
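And a sketch of the auto-moving behavior of `check_embedding_net_device` (defined below in `user_input_checks.py`), again assuming CUDA is available:

```python
import torch
from torch import nn

from sbi.utils.user_input_checks import check_embedding_net_device

embedding_net = nn.Linear(2, 2)              # parameters on the CPU
batch_y = torch.zeros(10, 2, device="cuda")  # data on the GPU

# Warns about the mismatch and moves the net to `batch_y`'s device, so the
# subsequent `embedding_net(batch_y[:1])` forward pass succeeds.
check_embedding_net_device(embedding_net=embedding_net, datum=batch_y)
assert next(embedding_net.parameters()).device == batch_y.device
```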
4 changes: 4 additions & 0 deletions sbi/neural_nets/mdn.py
@@ -8,6 +8,8 @@

 import sbi.utils as utils

+from sbi.utils.user_input_checks import check_embedding_net_device, check_data_device
+

 def build_mdn(
     batch_x: Tensor = None,
@@ -43,6 +45,8 @@ def build_mdn(
     """
     x_numel = batch_x[0].numel()
     # Infer the output dimensionality of the embedding_net by making a forward pass.
+    check_data_device(batch_x, batch_y)
+    check_embedding_net_device(embedding_net=embedding_net, datum=batch_y)
     y_numel = embedding_net(batch_y[:1]).numel()

     transform = transforms.IdentityTransform()
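For reference, the `y_numel` line that these checks protect infers the embedding dimensionality with a one-datum forward pass; a standalone sketch with an arbitrary toy architecture:

```python
import torch
from torch import nn

embedding_net = nn.Sequential(nn.Linear(5, 32), nn.ReLU(), nn.Linear(32, 8))
batch_y = torch.zeros(100, 5)

# Push a single datum through the net and count the output entries; this
# works for any architecture without it having to declare an output shape,
# but it requires net and data to share a device; hence the new checks.
y_numel = embedding_net(batch_y[:1]).numel()
assert y_numel == 8
```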
57 changes: 35 additions & 22 deletions sbi/utils/torchutils.py
@@ -21,34 +21,47 @@ def process_device(device: str) -> str:
     Throws an AssertionError if the prior does not match the training device.
     """

-    if not device == "cpu":
-        assert device.startswith("cuda"), f"Invalid device string: {device}."
-        try:
-            torch.zeros(1).to(device)
-            warnings.warn(
-                """GPU was selected as a device for training the neural network. Note
-                that we expect **no** significant speed ups in training for the
-                default architectures we provide. Using the GPU will be effective
-                only for large neural networks with operations that are fast on the
-                GPU, e.g., for a CNN or RNN `embedding_net`."""
-            )
-        except (RuntimeError, AssertionError):
-            warnings.warn(f"Device {device} not available, falling back to CPU.")
-            device = "cpu"
-    return device
+    if device == "cpu":
+        return "cpu"
+    else:
+        warnings.warn(
+            "GPU was selected as a device for training the neural network. "
+            "Note that we expect **no** significant speed ups in training for the "
+            "default architectures we provide. Using the GPU will be effective "
+            "only for large neural networks with operations that are fast on the "
+            "GPU, e.g., for a CNN or RNN `embedding_net`."
+        )
+        current_gpu_index = torch.cuda.current_device()
+        if device == "cuda":
+            return f"cuda:{current_gpu_index}"
+        else:
+            assert device == f"cuda:{current_gpu_index}", (
+                f"Unrecognized device {device}, "
+                "should be one of [`cpu`, `cuda`, f`cuda:{index}`]"
+            )
+            return device


-def check_if_prior_on_device(device, prior: Optional[Any] = None):
-    if prior is not None:
+def check_if_prior_on_device(device: torch.device, prior: Optional[Any] = None) -> None:
+    """Try to sample from the prior and check that the returned data is on the
+    correct training device. If the prior is `None`, simply pass.
+
+    Args:
+        device: target torch training device
+        prior: any prior whose `sample` returns a torch `Tensor`
+    """
+    if prior is None:
+        pass
+    else:
         prior_device = prior.sample((1,)).device
         training_device = torch.zeros(1, device=device).device
-        assert (
-            prior_device == training_device
-        ), f"""Prior ({prior_device}) device must match training device (
-            {training_device}). When training on GPU make sure to pass a prior
-            initialized on the GPU as well, e.g., `prior = torch.distributions.Normal
-            (torch.zeros(2, device='cuda'), scale=1.0)`."""
+        assert prior_device == training_device, (
+            f"Prior device '{prior_device}' must match training device "
+            f"'{training_device}'. When training on GPU make sure to "
+            "pass a prior initialized on the GPU as well, e.g., "
+            "`prior = torch.distributions.Normal"
+            "(torch.zeros(2, device='cuda'), scale=1.0)`."
+        )


 def tile(x, n):
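Taken together, the two `torchutils.py` functions behave as in this sketch (assuming a machine where `cuda:0` exists):

```python
import torch
from torch.distributions import Normal

from sbi.utils.torchutils import check_if_prior_on_device, process_device

device = process_device("cuda")  # normalized to an explicit index, e.g. "cuda:0"
assert device == f"cuda:{torch.cuda.current_device()}"

# Passes silently: the prior samples on the training device.
gpu_prior = Normal(torch.zeros(2, device=device), torch.ones(2, device=device))
check_if_prior_on_device(device, gpu_prior)

# Raises an AssertionError: this prior samples on the CPU.
cpu_prior = Normal(torch.zeros(2), torch.ones(2))
check_if_prior_on_device(device, cpu_prior)
```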
68 changes: 60 additions & 8 deletions sbi/utils/user_input_checks.py
@@ -213,7 +213,7 @@ def check_prior_batch_dims(prior) -> None:
     using `torch.distributions.Independent` to reinterpret batch dimensions as
     event dimensions, or use the `MultipleIndependent` distribution we provide.

-    To use `sbi.utils.MultipleIndependent`, just pass a list of priors, e.g. to
+    To use `sbi.utils.MultipleIndependent`, just pass a list of priors, e.g. to
     specify a uniform prior over two parameters, pass as prior:
         prior = [
             Uniform(torch.zeros(1), torch.ones(1)),
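As a concrete illustration of the `torch.distributions.Independent` route mentioned in this docstring (a sketch, not part of the diff):

```python
import torch
from torch.distributions import Independent, Normal

# A Normal with batch_shape=(2,): `sbi` would reject this as a prior.
base = Normal(torch.zeros(2), torch.ones(2))

# Reinterpret the batch dimension as an event dimension: event_shape=(2,).
prior = Independent(base, reinterpreted_batch_ndims=1)
assert prior.sample((1,)).shape == (1, 2)
assert prior.log_prob(prior.sample((1,))).shape == (1,)
```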
@@ -269,7 +269,7 @@ def check_for_possibly_batched_x_shape(x_shape):
     of sbi might not provide stable support for this and result in
     shape mismatches.

-    NOTE: below we use list notation to reduce clutter, but `x` should be of
+    NOTE: below we use list notation to reduce clutter, but `x` should be of
     type torch.Tensor or ndarray.

     For example:
@@ -327,7 +327,7 @@ def check_prior_attributes(prior) -> None:
     except:  # Catch any other error.
         raise ValueError(
             f"""Something went wrong when sampling a batch of parameters
-            from the prior as `prior.sample(({num_samples}, ))`. Consider using a
+            from the prior as `prior.sample(({num_samples}, ))`. Consider using a
             PyTorch distribution."""
         )
     try:
@@ -395,6 +395,51 @@ def check_prior_support(prior):
         )


+def check_embedding_net_device(embedding_net: nn.Module, datum: torch.Tensor) -> None:
+    """Check whether the device of the `embedding_net`'s weights matches the device
+    of the fed `datum`. In case of a mismatch, warn the user and move the
+    `embedding_net` to the `datum`'s device.
+
+    Args:
+        embedding_net: torch `Module` embedding the data
+        datum: torch `Tensor` from the training device
+    """
+    datum_device = datum.device
+    embedding_net_devices = [p.device for p in embedding_net.parameters()]
+    if len(embedding_net_devices) > 0:
+        embedding_net_device = embedding_net_devices[0]
+        if embedding_net_device != datum_device:
+            warnings.warn(
+                "Mismatch between the device of the data fed "
+                "to the embedding_net and the device of the "
+                "embedding_net's weights. Fed data has device "
+                f"'{datum_device}' vs embedding_net weights have "
+                f"device '{embedding_net_device}'. "
+                "Automatically switching the embedding_net's device to "
+                f"'{datum_device}', which could otherwise be done manually "
+                f"using the line `embedding_net.to('{datum_device}')`."
+            )
+            embedding_net.to(datum_device)
+    else:
+        pass
+
+
+def check_data_device(datum_1: torch.Tensor, datum_2: torch.Tensor) -> None:
+    """Check whether two tensors are on the same device. Fail if there is a device
+    discrepancy.
+
+    Args:
+        datum_1: torch `Tensor`
+        datum_2: torch `Tensor`
+    """
+    assert datum_1.device == datum_2.device, (
+        "Mismatch in fed data's device: "
+        f"datum_1 has device '{datum_1.device}' whereas "
+        f"datum_2 has device '{datum_2.device}'. Please "
+        "use data from a common device."
+    )
+
+
 def process_simulator(
     user_simulator: Callable,
     prior,
@@ -629,13 +674,20 @@ def validate_theta_and_x(
     assert theta.dtype == float32, "Type of parameters must be float32."
     assert x.dtype == float32, "Type of simulator outputs must be float32."

-    simulations_device = f"{x.device.type}:{x.device.index}"
-    if "cpu" not in simulations_device and "cpu" in training_device:
-        logging.warning(
-            f"""Simulations are on {simulations_device} but training device is
-            set to {training_device}, moving data to device to {training_device}."""
-        )
-        x = x.to(training_device)
+    if str(x.device) != training_device:
+        warnings.warn(
+            f"Data x has device '{x.device}' "
+            f"different from the training_device '{training_device}', "
+            f"moving x to the training_device '{training_device}'."
+        )
+        x = x.to(training_device)
+
+    if str(theta.device) != training_device:
+        warnings.warn(
+            f"Parameters theta has device '{theta.device}' "
+            f"different from the training_device '{training_device}', "
+            f"moving theta to the training_device '{training_device}'."
+        )
+        theta = theta.to(training_device)

     return theta, x
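A final sketch of the new `validate_theta_and_x` behavior (assuming CUDA and the `training_device` keyword as in the hunk above): data is now moved, with a warning, whenever its device differs from the training device in either direction, not only for GPU simulations with CPU training:

```python
import torch

from sbi.utils.user_input_checks import validate_theta_and_x

theta = torch.zeros(10, 2)             # parameters on the CPU
x = torch.zeros(10, 3, device="cuda")  # simulations already on the GPU

# Warns for `theta` and moves it to "cuda:0"; `x` already matches.
theta, x = validate_theta_and_x(theta, x, training_device="cuda:0")
assert theta.device == x.device
```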