diff --git a/src/pyjuice/model/tensorcircuit.py b/src/pyjuice/model/tensorcircuit.py index 4f4d134..7e10b5e 100644 --- a/src/pyjuice/model/tensorcircuit.py +++ b/src/pyjuice/model/tensorcircuit.py @@ -735,7 +735,7 @@ def _init_buffer(self, name: str, shape: Tuple, set_value: Optional[float] = Non if not flag and tensor.size(i) != d: flag = True - if not flag and check_device and tensor.device != self.device: + if not flag and check_device and self.device.index is not None and tensor.device != self.device: flag = True if flag: