Skip to content

Commit

Permalink
BUG: Add pickle support for itk.Matrix and itk.ImageRegion
Browse files Browse the repository at this point in the history
For dask. Issue #4267.
  • Loading branch information
thewtex committed Nov 10, 2023
1 parent 7f3f616 commit 4dec791
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
20 changes: 20 additions & 0 deletions Wrapping/Generators/Python/PyBase/pyBase.i
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,16 @@ str = str
msg << "swig_name(" << self->GetIndex() << ", " << self->GetSize() << ")";
return msg.str();
}
%pythoncode %{
def __getstate__(self):
"""Get object state, necessary for serialization with pickle."""
state = (tuple(self.GetIndex()), tuple(self.GetSize()))
return state

def __setstate__(self, state):
"""Set object state, necessary for serialization with pickle."""
self.__init__(*state)
%}
}

%enddef
Expand Down Expand Up @@ -798,6 +808,16 @@ str = str
array = itk.array_from_matrix(self)
return np.asarray(array, dtype=dtype)
def __getstate__(self):
"""Get object state, necessary for serialization with pickle."""
import itk
state = itk.array_from_matrix(self)
return state
def __setstate__(self, state):
"""Set object state, necessary for serialization with pickle."""
matrix = itk.matrix_from_array(state)
self.__init__(matrix)
%}
}
Expand Down
9 changes: 9 additions & 0 deletions Wrapping/Generators/Python/Tests/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,15 @@ def custom_callback(name, progress):
)
assert np.sum(comparison) == 0.0

matrix = itk.Matrix[itk.D, 2, 2]()
matrix.SetIdentity()
serialize_deserialize = pickle.loads(pickle.dumps(matrix))
assert np.array_equal(np.asarray(matrix), np.asarray(serialize_deserialize))

region = itk.ImageRegion[2]([7,8], [2,3])
serialize_deserialize = pickle.loads(pickle.dumps(region))
assert region == serialize_deserialize

# Make sure we can read unsigned short, unsigned int, and cast
image = itk.imread(filename, itk.UI)
assert type(image) == itk.Image[itk.UI, 2]
Expand Down

0 comments on commit 4dec791

Please sign in to comment.