Skip to content

Commit

Permalink
adding WFSX support to plots (WIP)
Browse files Browse the repository at this point in the history
  • Loading branch information
pfebrer committed Nov 8, 2021
1 parent 85d73bd commit c89e5d5
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 6 deletions.
3 changes: 2 additions & 1 deletion sisl/viz/plots/fatbands.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ def _read_siesta_output(self, wfsx_file, bands_file, root_fdf):
wfsx_sile = self.get_sile(wfsx_file)

weights = []
bands = []
for i, state in enumerate(wfsx_sile.yield_eigenstate(self.H)):
# Each eigenstate represents all the states for a given k-point

Expand Down Expand Up @@ -335,7 +336,7 @@ def _get_groups_weights(self, groups, E0, bands_range, scale):
min_band, max_band = bands_range

# Get the weights that matter
plot_weights = self.weights.sel(band=slice(min_band, max_band))
plot_weights = self.weights.sel(band=slice(min_band, max_band - 1))

if groups is None:
groups = ()
Expand Down
64 changes: 59 additions & 5 deletions sisl/viz/plots/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
from collections import defaultdict
from sisl import geometry
from sisl.viz.plots.geometry import GeometryPlot
import numpy as np
from scipy.ndimage import affine_transform
Expand Down Expand Up @@ -1398,6 +1399,15 @@ class WavefunctionPlot(GridPlot):
"""
),

SileInput(key='wfsx_file', name='Path to WFSX file',
dtype=sisl.io.siesta.wfsxSileSiesta,
default=None,
help="""Siesta WFSX file to directly read the coefficients from.
If the root_fdf file is provided but the wfsx one isn't, we will try to find it
as SystemLabel.WFSX.
"""
),

SislObjectInput(key='geometry', name='Geometry',
default=None,
dtype=sisl.Geometry,
Expand Down Expand Up @@ -1440,6 +1450,11 @@ class WavefunctionPlot(GridPlot):
'plot_geom': True
}

def __init__(self, *args, **kwargs):
self._index_offset = 0

super().__init__(*args, **kwargs)

@entry_point('eigenstate', 0)
def _read_nosource(self, eigenstate):
"""
Expand All @@ -1449,8 +1464,28 @@ def _read_nosource(self, eigenstate):
raise ValueError('No eigenstate was provided')

self.eigenstate = eigenstate

@entry_point('hamiltonian', 1)

@entry_point('Siesta WFSX file', 1)
def _read_from_WFSX_file(self, wfsx_file, k, spin, root_fdf):
"""Reads the wavefunction coefficients from a SIESTA WFSX file"""
# Try to read the geometry
fdf = self.get_sile(root_fdf or "root_fdf")
if fdf is None:
raise ValueError("The setting 'root_fdf' needs to point to an fdf file with a geometry")
geometry = fdf.read_geometry(output=True)

# Get the WFSX file. If not provided, it is inferred from the fdf.
wfsx = self.get_sile(wfsx_file or "wfsx_file")
if not wfsx.file.exists():
raise ValueError(f"File '{wfsx.file}' does not exist.")

# Try to find the eigenstate that we need
self.eigenstate = wfsx.read_eigenstate(k=k, spin=spin[0], parent=geometry)
if self.eigenstate is None:
# We have not found it.
raise ValueError(f"A state with k={k} was not found in file {wfsx.file}.")

@entry_point('hamiltonian', 2)
def _read_from_H(self, k, spin):
"""
Calculates the eigenstates from a Hamiltonian and then generates the wavefunctions.
Expand All @@ -1464,6 +1499,23 @@ def _after_read(self):
# calling it later in _set_data
pass

def _get_eigenstate(self, i):

if "index" in self.eigenstate.info:
wf_i = np.nonzero(self.eigenstate.info["index"] == i)[0]
if len(wf_i) == 0:
raise ValueError(f"Wavefunction with index {i} is not present in the eigenstate. Available indices: {self.eigenstate.info['index']}."
f"Entry point used: {self.source._name}")
wf_i = wf_i[0]
else:
max_index = len(self.eigenstate)
if i > max_index:
raise ValueError(f"Wavefunction with index {i} is not present in the eigenstate. Available range: [0, {max_index}]."
f"Entry point used: {self.source._name}")
wf_i = i

return self.eigenstate[wf_i]

def _set_data(self, i, geometry, grid, k, grid_prec, nsc):

if geometry is not None:
Expand Down Expand Up @@ -1498,16 +1550,18 @@ def _set_data(self, i, geometry, grid, k, grid_prec, nsc):
tiled_geometry = tiled_geometry.tile(sc_i, ax)
nsc[ax] = 1

is_gamma = (np.array(k) == 0).all()
if grid is None:
dtype = np.float64 if (np.array(k) == 0).all() else np.complex128
dtype = np.float64 if is_gamma else np.complex128
self.grid = sisl.Grid(grid_prec, geometry=tiled_geometry, dtype=dtype)

# GridPlot's after_read basically sets the x_range, y_range and z_range options
# which need to know what the grid is, that's why we are calling it here
super()._after_read()

self.eigenstate[i].wavefunction(self.grid)
state = self._get_eigenstate(i)
state.wavefunction(self.grid)

return super()._set_data(nsc=nsc)
return super()._set_data(nsc=nsc, trace_name=f"WF {i} ({state.eig[0]:.2f} eV)")

GridPlot.backends.register_child(WavefunctionPlot.backends)

0 comments on commit c89e5d5

Please sign in to comment.