diff --git a/src/koopmans/utils/_xml.py b/src/koopmans/utils/_xml.py index f62a92b5b..e13401636 100644 --- a/src/koopmans/utils/_xml.py +++ b/src/koopmans/utils/_xml.py @@ -35,9 +35,20 @@ def read_xml_nr(xml_file: Path, string: str = 'EFFECTIVE-POTENTIAL') -> List[int return _get_nr(branch) -def read_xml_array(xml_file: Path, norm_const: float, string: str = 'EFFECTIVE-POTENTIAL') -> np.ndarray: +def read_xml_array( + xml_file: Path, norm_const: float, string: str = 'EFFECTIVE-POTENTIAL', retain_final_element: bool=False + ) -> np.ndarray: """ - Loads an array from an xml file + Loads an array from an xml file. + + :param xml_file: The xml file to read from + :param norm_const: The normalization constant to multiply the array with (in our case 1/((Bohr radii)^3) + :param string: The name of the field in the xml file that contains the array, in our case either + 'EFFECTIVE-POTENTIAL' or 'CHARGE-DENSITY' + :param retain_final_element: If True, the array is returned in with periodic boundary conditions, i.e. the last + element in each dimension is equal to the first element in each dimension. This is required for the xsf format. + + :return: The array """ # Load the branch of the xml tree @@ -60,4 +71,8 @@ def read_xml_array(xml_file: Path, norm_const: float, string: str = 'EFFECTIVE-P array_xml[k, j, i] = rho_tmp[(j % (nr_xml[1]-1))*(nr_xml[0]-1)+(i % (nr_xml[0]-1))] array_xml *= norm_const - return array_xml[:-1, :-1, :-1] + if retain_final_element: + # the xsf format requires an array where the last element is equal to the first element in each dimension + return array_xml + else: + return array_xml[:-1, :-1, :-1] diff --git a/src/koopmans/utils/_xsf.py b/src/koopmans/utils/_xsf.py index c166ac8d8..2024e2bc4 100644 --- a/src/koopmans/utils/_xsf.py +++ b/src/koopmans/utils/_xsf.py @@ -28,7 +28,7 @@ def write_xsf(filename: Path, atoms: Atoms, arrays: List[np.ndarray], nr_xml: Tu out.write('CRYSTAL\n\n') out.write('PRIMVEC\n\n') for vec in cell_parameters: - out.write("\t" + " ".join([f"{x:13.10f}" for x in vec])) + out.write("\t" + " ".join([f"{x:13.10f}" for x in vec])+ " \n") out.write('PRIMCOORD\n') out.write(f"\t{len(symbols)}\t1\n") for symbol, pos in zip(symbols, positions):