diff --git a/molexpress/datasets/encoders.py b/molexpress/datasets/encoders.py index fa9cfa2..2850a5b 100644 --- a/molexpress/datasets/encoders.py +++ b/molexpress/datasets/encoders.py @@ -103,7 +103,10 @@ def __call__(self, molecule: types.Molecule) -> np.ndarray: return {"edge_src": edge_src, "edge_dst": edge_dst} if molecule.GetNumBonds() == 0: - edge_state = np.zeros(shape=(0, self.output_dim), dtype=self.output_dtype) + edge_state = np.zeros( + shape=(0, self.output_dim + int(self.self_loops)), + dtype=self.output_dtype + ) return { "edge_src": edge_src, "edge_dst": edge_dst,