Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallel calculation of PDOS and (fat)bands #367

Merged
merged 4 commits into from
Aug 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sisl/viz/backends/templates/_plots/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@ def draw_bands(self, filtered_bands, line, spindown_line, spin, spin_texture, ad
else:
draw_band_func = self._draw_band

if "spin" not in filtered_bands.coords:
filtered_bands = filtered_bands.expand_dims("spin")
# Now loop through all bands to draw them
for spin_bands, ispin in zip(filtered_bands.transpose('spin', 'band', 'k'), filtered_bands.spin.values):
for ispin, spin_bands in enumerate(filtered_bands.transpose('spin', 'band', 'k')):
line_style = line
if ispin == 1:
line_style.update(spindown_line)
Expand Down
6 changes: 4 additions & 2 deletions sisl/viz/backends/templates/_plots/fatbands.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,12 @@ def draw_group_weights(self, weights, metadata, name, bands, x):
bands: xarray.DataArray with indices (spin, band, k)
Contains all the eigenvalues of the band structure.
"""
if "spin" not in bands.coords:
bands = bands.expand_dims("spin")

for ispin, spin_weights in enumerate(weights):
for ispin, spin_weights in enumerate(weights.transpose("spin", "band", "k")):
for i, band_weights in enumerate(spin_weights):
band_values = bands.sel(band=band_weights.band, spin=band_weights.spin)
band_values = bands.sel(band=band_weights.band, spin=ispin)

self._draw_band_weights(
x=x, y=band_values, weights=band_weights.values,
Expand Down
6 changes: 3 additions & 3 deletions sisl/viz/input_fields/dropdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,10 @@ class SpinSelect(DropdownInput):
}

_options = {
Spin.UNPOLARIZED: [{"label": "Total", "value": 0}],
Spin.UNPOLARIZED: [],
Spin.POLARIZED: [{"label": "↑", "value": 0}, {"label": "↓", "value": 1}],
Spin.NONCOLINEAR: [{"label": val, "value": val} for val in ("sum", "x", "y", "z")],
Spin.SPINORBIT: [{"label": val, "value": val} for val in ("sum", "x", "y", "z")]
Spin.NONCOLINEAR: [{"label": val, "value": val} for val in ("total", "x", "y", "z")],
Spin.SPINORBIT: [{"label": val, "value": val} for val in ("total", "x", "y", "z")]
}

def __init__(self, *args, only_if_polarized=False, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion sisl/viz/input_fields/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def get_options(self, key, **kwargs):

# If "spin" was one of the keys, we are going to incorporate the spin options, taking into
# account the position (column index) where they are expected to be returned.
if spin_in_keys:
if spin_in_keys and len(spin_options) > 0:
options = np.concatenate([np.insert(options, spin_key_i, spin, axis=1) for spin in spin_options])

# Squeeze the options array, just in case there is only one key
Expand Down
150 changes: 84 additions & 66 deletions sisl/viz/plots/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,25 @@
import itertools

import numpy as np
import xarray as xr

import sisl
from sisl.messages import warn
from ..plot import Plot, entry_point
from ..plotutils import call_method_if_present, find_files
from ..plotutils import find_files
from ..input_fields import (
TextInput, SwitchInput, ColorPicker, DropdownInput,
IntegerInput, FloatInput, RangeInput, RangeSlider,
QueriesInput, ProgramaticInput, FunctionInput, SileInput,
PlotableInput, SpinSelect, AiidaNodeInput, BandStructureInput
TextInput, SwitchInput, ColorPicker,
FloatInput, RangeSlider,
QueriesInput, FunctionInput, SileInput,
SpinSelect, AiidaNodeInput, BandStructureInput
)
from ..input_fields.range import ErangeInput

try:
import pathos
_do_parallel_calc = True
except:
_do_parallel_calc = False

class BandsPlot(Plot):
"""
Expand Down Expand Up @@ -125,14 +132,6 @@ class BandsPlot(Plot):
help="""An aiida BandsData node."""
),

FunctionInput(key="eigenstate_map", name="Eigenstate map function",
default=None,
positional=["eigenstate", "plot"],
returns=[],
help="""This function receives the eigenstate object for each k value when the bands are being extracted from a hamiltonian.
You can do whatever you want with it, the point of this function is to avoid running the diagonalization process twice."""
),

FunctionInput(key="add_band_data", name="Add band data function",
default=lambda band, plot: {},
positional=["band", "plot"],
Expand Down Expand Up @@ -281,6 +280,14 @@ def _get_frame_names(self):
return [childPlot.get_setting("bands_file").name for childPlot in self.child_plots]

return cls.animated("bands_file", bands_files, frame_names = _get_frame_names, wdir = wdir, **kwargs)

@property
def bands(self):
return self.bands_data["E"]

@property
def spin_moments(self):
return self.bands_data["spin_moments"]

def _after_init(self):
self.spin = sisl.Spin("")
Expand Down Expand Up @@ -308,7 +315,7 @@ def _read_aiida_bands(self, aiida_bands):
tick_info["ticklabels"].append(label)

# Construct the dataarray
self.bands = xr.DataArray(
self.bands_data = xr.DataArray(
bands,
coords={
"spin": np.arange(0, bands.shape[0]),
Expand All @@ -320,12 +327,10 @@ def _read_aiida_bands(self, aiida_bands):
)

@entry_point('band structure')
def _read_from_H(self, band_structure, eigenstate_map):
def _read_from_H(self, band_structure, extra_vars=()):
"""
Uses a sisl's `BandStructure` object to calculate the bands.
"""
import xarray as xr

if band_structure is None:
raise ValueError("No band structure (k points path) was provided")

Expand All @@ -340,24 +345,21 @@ def _read_from_H(self, band_structure, eigenstate_map):

self.ticks = band_structure.lineartick()

# We define a wrapper to get the values out of the eigenstates
# to give the possibility to the user to do something inbetween
# NOTE THAT THIS IS USED BY FAT BANDS TO GET THE WEIGHTS SIMULTANEOUSLY
eig_map = eigenstate_map

# Also, in this wrapper we will get the spin moments in case it is a non_colinear
# or spin-orbit calculation
if self.spin.is_noncolinear or self.spin.is_spinorbit:
self.spin_moments = []
elif hasattr(self, "spin_moments"):
del self.spin_moments
# In case it is a non_colinear or spin-orbit calculation we will get the spin moments
if not self.spin.is_diagonal:
def _spin_moment_getter(eigenstate, plot, spin):
return eigenstate.spin_moment().real

extra_vars = ({
"coords": ("band", "axis"), "coords_values": dict(axis=["x", "y", "z"]),
"name": "spin_moments", "getter": _spin_moment_getter},
*extra_vars)

def bands_wrapper(eigenstate, spin_index):
if callable(eig_map):
eig_map(eigenstate, self, spin_index)
if hasattr(self, "spin_moments"):
self.spin_moments.append(eigenstate.spin_moment())
return eigenstate.eig
returns = []
for extra_var in extra_vars:
returns.append(extra_var["getter"](eigenstate, self, spin_index))
return (eigenstate.eig, *returns)

# Define the available spins
spin_indices = [0]
Expand All @@ -366,38 +368,36 @@ def bands_wrapper(eigenstate, spin_index):

# Get the eigenstates for all the available spin components
bands_arrays = []
name = ["E"]
coords = [('band'), ]
coords_values = {"spin": spin_indices, "k": band_structure.lineark()}
for extra_var in extra_vars:
name.append(extra_var["name"])
coords.append(extra_var["coords"])
coords_values.update(extra_var.get("coords_values", {}))

for spin_index in spin_indices:

# Non collinear routines don't accept the keyword argument "spin"
spin_kwarg = {"spin": spin_index}
if self.spin.is_noncolinear:
if not self.spin.is_diagonal:
spin_kwarg = {}

spin_bands = band_structure.apply.dataarray.eigenstate(
wrap=partial(bands_wrapper, spin_index=spin_index),
**spin_kwarg,
coords=('band',),
)

with band_structure.apply(pool=_do_parallel_calc, unzip=True) as parallel:
pfebrer marked this conversation as resolved.
Show resolved Hide resolved
spin_bands = parallel.dataarray.eigenstate(
wrap=partial(bands_wrapper, spin_index=spin_index),
**spin_kwarg,
coords=coords, name=name,
)

bands_arrays.append(spin_bands)

# Merge everything into a single dataarray with a spin dimension
self.bands = xr.concat(bands_arrays, "spin").assign_coords({"spin": spin_indices}).transpose("k", "spin", "band")
# Merge everything into a single dataset with a spin dimension
self.bands_data = xr.concat(bands_arrays, "spin").assign_coords(coords_values)

self.bands['k'] = band_structure.lineark()
# Inform of where to place the ticks
self.bands.attrs = {"ticks": self.ticks[0], "ticklabels": self.ticks[1], **bands_arrays[0].attrs}

if hasattr(self, "spin_moments"):
self.spin_moments = xr.DataArray(
self.spin_moments,
coords={
"k": self.bands.k,
"band": self.bands.band,
"axis": ["x", "y", "z"]
},
dims=("k", "band", "axis")
)
self.bands_data.attrs = {"ticks": self.ticks[0], "ticklabels": self.ticks[1], **bands_arrays[0].attrs}

@entry_point('bands file')
def _read_siesta_output(self, bands_file, band_structure):
Expand All @@ -407,13 +407,22 @@ def _read_siesta_output(self, bands_file, band_structure):
if band_structure:
raise ValueError("A path was provided, therefore we can not use the .bands file even if there is one")

self.bands = self.get_sile(bands_file or "bands_file").read_data(as_dataarray=True)
self.bands_data = self.get_sile(bands_file or "bands_file").read_data(as_dataarray=True)

# Define the spin class of the results we have retrieved
if len(self.bands.spin.values) == 2:
if len(self.bands_data.spin.values) == 2:
self.spin = sisl.Spin("p")

def _after_read(self):
if isinstance(self.bands_data, xr.DataArray):
attrs = self.bands_data.attrs
self.bands_data = xr.Dataset({"E": self.bands_data})
self.bands_data.attrs = attrs

# If the calculation is not spin polarized it makes no sense to
# retain a spin index
if "spin" in self.bands_data and not self.spin.is_polarized:
self.bands_data = self.bands_data.sel(spin=self.bands_data.spin[0], drop=True)

# Inform the spin input of what spin class are we handling
self.get_param("spin").update_options(self.spin)
Expand Down Expand Up @@ -461,15 +470,17 @@ def _set_data(self, Erange, E0, bands_range, spin, spin_texture_colorscale, band
self.update_settings(run_updates=False, bands_range=[int(filtered_bands['band'].min()), int(filtered_bands['band'].max())], no_log=True)

# Give the filtered bands the same attributes as the full bands
filtered_bands.attrs = self.bands.attrs
filtered_bands.attrs = self.bands_data.attrs

# Let's treat the spin if the user requested it
self.spin_texture = False
if spin is not None and len(spin) > 0:
if isinstance(spin[0], int):
filtered_bands = filtered_bands.sel(spin=spin)
# Only use the spin setting if there is a spin index
if "spin" in filtered_bands.coords:
filtered_bands = filtered_bands.sel(spin=spin)
elif isinstance(spin[0], str):
if not hasattr(self, "spin_moments"):
if "spin_moments" not in self.bands_data:
raise ValueError(f"You requested spin texture ({spin[0]}), but spin moments have not been calculated. The spin class is {self.spin.kind}")
self.spin_texture = True

Expand Down Expand Up @@ -559,7 +570,9 @@ def clear_equivalent(ks):
if requested_spin is None:
requested_spin = [0, 1]

for spin in self.bands.spin:
avail_spins = self.bands_data.get("spin", [0])

for spin in avail_spins:
if spin in requested_spin:
from_k = custom_gap["from"]
to_k = custom_gap["to"]
Expand Down Expand Up @@ -592,9 +605,9 @@ def _sanitize_k(self, k):
try:
san_k = float(k)
except ValueError:
if k in self.bands.attrs["ticklabels"]:
i_tick = self.bands.attrs["ticklabels"].index(k)
san_k = self.bands.attrs["ticks"][i_tick]
if k in self.bands_data.attrs["ticklabels"]:
i_tick = self.bands_data.attrs["ticklabels"].index(k)
san_k = self.bands_data.attrs["ticks"][i_tick]
else:
pass
# raise ValueError(f"We can not interpret {k} as a k-location in the current bands plot")
Expand Down Expand Up @@ -638,7 +651,8 @@ def _get_gap_coords(self, from_k, to_k=None, gap_spin=0, **kwargs):
ks[i] = self._sanitize_k(val)

VB, CB = self.gap_info["bands"]
Es = [self.bands.sel(k=k, band=band, spin=gap_spin, method="nearest") for k, band in zip(ks, (VB, CB))]
spin_bands = self.bands.sel(spin=gap_spin) if "spin" in self.bands.coords else self.bands
Es = [spin_bands.sel(k=k, band=band, method="nearest") for k, band in zip(ks, (VB, CB))]
# Get the real values of ks that have been obtained
# because we might not have exactly the ks requested
ks = [np.ravel(E.k)[0] for E in Es]
Expand Down Expand Up @@ -761,7 +775,11 @@ def effective_mass(self, band, k, k_direction, band_spin=0, n_points=10):
from sisl.unit.base import units

# Get the band that we want to fit
band_vals = self.bands.sel(band=band, spin=band_spin)
bands = self.bands
if "spin" in bands.coords:
band_vals = bands.sel(band=band, spin=band_spin)
else:
band_vals = bands.sel(band=band)

# Sanitize k to a float
k = self._sanitize_k(k)
Expand All @@ -780,7 +798,7 @@ def effective_mass(self, band, k, k_direction, band_spin=0, n_points=10):

# Grab the slice of the band that we are going to fit
sel_band = band_vals[sel_slice] * units("eV", "Hartree")
sel_k = self.bands.k[sel_slice] - k
sel_k = bands.k[sel_slice] - k

# Fit the band to a second order polynomial
polyfit = np.polynomial.Polynomial.fit(sel_k, sel_band, 2)
Expand Down
Loading