From 8d3c6de1d4cbc06c2d9c89ceb9ebf9ba325263d1 Mon Sep 17 00:00:00 2001 From: Johnnie Gray Date: Wed, 10 Jul 2024 15:23:10 -0700 Subject: [PATCH] TN: fix unintended mutation when updating parameters --- quimb/tensor/tensor_core.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/quimb/tensor/tensor_core.py b/quimb/tensor/tensor_core.py index 6e12dbea..75694d22 100644 --- a/quimb/tensor/tensor_core.py +++ b/quimb/tensor/tensor_core.py @@ -1408,12 +1408,19 @@ def set_params(self, params): data array. This is mainly for providing an interface for 'structured' arrays e.g. with block sparsity to interact with optimization. """ - if hasattr(self.data, "set_params"): - self.data.set_params(params) - elif hasattr(self.data, "params"): - self.data.params = params + data = self.data + if hasattr(data, "set_params"): + # Tensor don't modify their data inplace + data = data.copy() + data.set_params(params) + elif hasattr(data, "params"): + # Tensor don't modify their data inplace + data = data.copy() + data.params = params else: - self._set_data(params) + data = params + + self._set_data(data) def copy(self, deep=False, virtual=False): """Copy this tensor.