diff --git a/sisl/viz/plots/fatbands.py b/sisl/viz/plots/fatbands.py index 16a6df2593..7a451cc00e 100644 --- a/sisl/viz/plots/fatbands.py +++ b/sisl/viz/plots/fatbands.py @@ -227,6 +227,7 @@ def _read_siesta_output(self, wfsx_file, bands_file, root_fdf): wfsx_sile = self.get_sile(wfsx_file) weights = [] + bands = [] for i, state in enumerate(wfsx_sile.yield_eigenstate(self.H)): # Each eigenstate represents all the states for a given k-point @@ -335,7 +336,7 @@ def _get_groups_weights(self, groups, E0, bands_range, scale): min_band, max_band = bands_range # Get the weights that matter - plot_weights = self.weights.sel(band=slice(min_band, max_band)) + plot_weights = self.weights.sel(band=slice(min_band, max_band - 1)) if groups is None: groups = () diff --git a/sisl/viz/plots/grid.py b/sisl/viz/plots/grid.py index c399381d9f..c16c7d7f6f 100644 --- a/sisl/viz/plots/grid.py +++ b/sisl/viz/plots/grid.py @@ -2,6 +2,7 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at https://mozilla.org/MPL/2.0/. from collections import defaultdict +from sisl import geometry from sisl.viz.plots.geometry import GeometryPlot import numpy as np from scipy.ndimage import affine_transform @@ -1398,6 +1399,15 @@ class WavefunctionPlot(GridPlot): """ ), + SileInput(key='wfsx_file', name='Path to WFSX file', + dtype=sisl.io.siesta.wfsxSileSiesta, + default=None, + help="""Siesta WFSX file to directly read the coefficients from. + If the root_fdf file is provided but the wfsx one isn't, we will try to find it + as SystemLabel.WFSX. + """ + ), + SislObjectInput(key='geometry', name='Geometry', default=None, dtype=sisl.Geometry, @@ -1440,6 +1450,11 @@ class WavefunctionPlot(GridPlot): 'plot_geom': True } + def __init__(self, *args, **kwargs): + self._index_offset = 0 + + super().__init__(*args, **kwargs) + @entry_point('eigenstate', 0) def _read_nosource(self, eigenstate): """ @@ -1449,8 +1464,28 @@ def _read_nosource(self, eigenstate): raise ValueError('No eigenstate was provided') self.eigenstate = eigenstate - - @entry_point('hamiltonian', 1) + + @entry_point('Siesta WFSX file', 1) + def _read_from_WFSX_file(self, wfsx_file, k, spin, root_fdf): + """Reads the wavefunction coefficients from a SIESTA WFSX file""" + # Try to read the geometry + fdf = self.get_sile(root_fdf or "root_fdf") + if fdf is None: + raise ValueError("The setting 'root_fdf' needs to point to an fdf file with a geometry") + geometry = fdf.read_geometry(output=True) + + # Get the WFSX file. If not provided, it is inferred from the fdf. + wfsx = self.get_sile(wfsx_file or "wfsx_file") + if not wfsx.file.exists(): + raise ValueError(f"File '{wfsx.file}' does not exist.") + + # Try to find the eigenstate that we need + self.eigenstate = wfsx.read_eigenstate(k=k, spin=spin[0], parent=geometry) + if self.eigenstate is None: + # We have not found it. + raise ValueError(f"A state with k={k} was not found in file {wfsx.file}.") + + @entry_point('hamiltonian', 2) def _read_from_H(self, k, spin): """ Calculates the eigenstates from a Hamiltonian and then generates the wavefunctions. @@ -1464,6 +1499,23 @@ def _after_read(self): # calling it later in _set_data pass + def _get_eigenstate(self, i): + + if "index" in self.eigenstate.info: + wf_i = np.nonzero(self.eigenstate.info["index"] == i)[0] + if len(wf_i) == 0: + raise ValueError(f"Wavefunction with index {i} is not present in the eigenstate. Available indices: {self.eigenstate.info['index']}." + f"Entry point used: {self.source._name}") + wf_i = wf_i[0] + else: + max_index = len(self.eigenstate) + if i > max_index: + raise ValueError(f"Wavefunction with index {i} is not present in the eigenstate. Available range: [0, {max_index}]." + f"Entry point used: {self.source._name}") + wf_i = i + + return self.eigenstate[wf_i] + def _set_data(self, i, geometry, grid, k, grid_prec, nsc): if geometry is not None: @@ -1498,16 +1550,18 @@ def _set_data(self, i, geometry, grid, k, grid_prec, nsc): tiled_geometry = tiled_geometry.tile(sc_i, ax) nsc[ax] = 1 + is_gamma = (np.array(k) == 0).all() if grid is None: - dtype = np.float64 if (np.array(k) == 0).all() else np.complex128 + dtype = np.float64 if is_gamma else np.complex128 self.grid = sisl.Grid(grid_prec, geometry=tiled_geometry, dtype=dtype) # GridPlot's after_read basically sets the x_range, y_range and z_range options # which need to know what the grid is, that's why we are calling it here super()._after_read() - self.eigenstate[i].wavefunction(self.grid) + state = self._get_eigenstate(i) + state.wavefunction(self.grid) - return super()._set_data(nsc=nsc) + return super()._set_data(nsc=nsc, trace_name=f"WF {i} ({state.eig[0]:.2f} eV)") GridPlot.backends.register_child(WavefunctionPlot.backends)