diff --git a/sisl/viz/backends/templates/_plots/bands.py b/sisl/viz/backends/templates/_plots/bands.py index 596aee5812..dccaf766ac 100644 --- a/sisl/viz/backends/templates/_plots/bands.py +++ b/sisl/viz/backends/templates/_plots/bands.py @@ -51,8 +51,10 @@ def draw_bands(self, filtered_bands, line, spindown_line, spin, spin_texture, ad else: draw_band_func = self._draw_band + if "spin" not in filtered_bands.coords: + filtered_bands = filtered_bands.expand_dims("spin") # Now loop through all bands to draw them - for spin_bands, ispin in zip(filtered_bands.transpose('spin', 'band', 'k'), filtered_bands.spin.values): + for ispin, spin_bands in enumerate(filtered_bands.transpose('spin', 'band', 'k')): line_style = line if ispin == 1: line_style.update(spindown_line) diff --git a/sisl/viz/backends/templates/_plots/fatbands.py b/sisl/viz/backends/templates/_plots/fatbands.py index e63c39b933..7d27aeaf01 100644 --- a/sisl/viz/backends/templates/_plots/fatbands.py +++ b/sisl/viz/backends/templates/_plots/fatbands.py @@ -56,10 +56,12 @@ def draw_group_weights(self, weights, metadata, name, bands, x): bands: xarray.DataArray with indices (spin, band, k) Contains all the eigenvalues of the band structure. """ + if "spin" not in bands.coords: + bands = bands.expand_dims("spin") - for ispin, spin_weights in enumerate(weights): + for ispin, spin_weights in enumerate(weights.transpose("spin", "band", "k")): for i, band_weights in enumerate(spin_weights): - band_values = bands.sel(band=band_weights.band, spin=band_weights.spin) + band_values = bands.sel(band=band_weights.band, spin=ispin) self._draw_band_weights( x=x, y=band_values, weights=band_weights.values, diff --git a/sisl/viz/input_fields/dropdown.py b/sisl/viz/input_fields/dropdown.py index 03b292b835..96cf830c1e 100644 --- a/sisl/viz/input_fields/dropdown.py +++ b/sisl/viz/input_fields/dropdown.py @@ -156,10 +156,10 @@ class SpinSelect(DropdownInput): } _options = { - Spin.UNPOLARIZED: [{"label": "Total", "value": 0}], + Spin.UNPOLARIZED: [], Spin.POLARIZED: [{"label": "↑", "value": 0}, {"label": "↓", "value": 1}], - Spin.NONCOLINEAR: [{"label": val, "value": val} for val in ("sum", "x", "y", "z")], - Spin.SPINORBIT: [{"label": val, "value": val} for val in ("sum", "x", "y", "z")] + Spin.NONCOLINEAR: [{"label": val, "value": val} for val in ("total", "x", "y", "z")], + Spin.SPINORBIT: [{"label": val, "value": val} for val in ("total", "x", "y", "z")] } def __init__(self, *args, only_if_polarized=False, **kwargs): diff --git a/sisl/viz/input_fields/queries.py b/sisl/viz/input_fields/queries.py index 213f5ec138..ce3f9c3838 100644 --- a/sisl/viz/input_fields/queries.py +++ b/sisl/viz/input_fields/queries.py @@ -299,7 +299,7 @@ def get_options(self, key, **kwargs): # If "spin" was one of the keys, we are going to incorporate the spin options, taking into # account the position (column index) where they are expected to be returned. - if spin_in_keys: + if spin_in_keys and len(spin_options) > 0: options = np.concatenate([np.insert(options, spin_key_i, spin, axis=1) for spin in spin_options]) # Squeeze the options array, just in case there is only one key diff --git a/sisl/viz/plots/bands.py b/sisl/viz/plots/bands.py index 065ff79a87..b132935133 100644 --- a/sisl/viz/plots/bands.py +++ b/sisl/viz/plots/bands.py @@ -6,18 +6,25 @@ import itertools import numpy as np +import xarray as xr import sisl +from sisl.messages import warn from ..plot import Plot, entry_point -from ..plotutils import call_method_if_present, find_files +from ..plotutils import find_files from ..input_fields import ( - TextInput, SwitchInput, ColorPicker, DropdownInput, - IntegerInput, FloatInput, RangeInput, RangeSlider, - QueriesInput, ProgramaticInput, FunctionInput, SileInput, - PlotableInput, SpinSelect, AiidaNodeInput, BandStructureInput + TextInput, SwitchInput, ColorPicker, + FloatInput, RangeSlider, + QueriesInput, FunctionInput, SileInput, + SpinSelect, AiidaNodeInput, BandStructureInput ) from ..input_fields.range import ErangeInput +try: + import pathos + _do_parallel_calc = True +except: + _do_parallel_calc = False class BandsPlot(Plot): """ @@ -125,14 +132,6 @@ class BandsPlot(Plot): help="""An aiida BandsData node.""" ), - FunctionInput(key="eigenstate_map", name="Eigenstate map function", - default=None, - positional=["eigenstate", "plot"], - returns=[], - help="""This function receives the eigenstate object for each k value when the bands are being extracted from a hamiltonian. - You can do whatever you want with it, the point of this function is to avoid running the diagonalization process twice.""" - ), - FunctionInput(key="add_band_data", name="Add band data function", default=lambda band, plot: {}, positional=["band", "plot"], @@ -281,6 +280,14 @@ def _get_frame_names(self): return [childPlot.get_setting("bands_file").name for childPlot in self.child_plots] return cls.animated("bands_file", bands_files, frame_names = _get_frame_names, wdir = wdir, **kwargs) + + @property + def bands(self): + return self.bands_data["E"] + + @property + def spin_moments(self): + return self.bands_data["spin_moments"] def _after_init(self): self.spin = sisl.Spin("") @@ -308,7 +315,7 @@ def _read_aiida_bands(self, aiida_bands): tick_info["ticklabels"].append(label) # Construct the dataarray - self.bands = xr.DataArray( + self.bands_data = xr.DataArray( bands, coords={ "spin": np.arange(0, bands.shape[0]), @@ -320,12 +327,10 @@ def _read_aiida_bands(self, aiida_bands): ) @entry_point('band structure') - def _read_from_H(self, band_structure, eigenstate_map): + def _read_from_H(self, band_structure, extra_vars=()): """ Uses a sisl's `BandStructure` object to calculate the bands. """ - import xarray as xr - if band_structure is None: raise ValueError("No band structure (k points path) was provided") @@ -340,24 +345,21 @@ def _read_from_H(self, band_structure, eigenstate_map): self.ticks = band_structure.lineartick() - # We define a wrapper to get the values out of the eigenstates - # to give the possibility to the user to do something inbetween - # NOTE THAT THIS IS USED BY FAT BANDS TO GET THE WEIGHTS SIMULTANEOUSLY - eig_map = eigenstate_map - - # Also, in this wrapper we will get the spin moments in case it is a non_colinear - # or spin-orbit calculation - if self.spin.is_noncolinear or self.spin.is_spinorbit: - self.spin_moments = [] - elif hasattr(self, "spin_moments"): - del self.spin_moments + # In case it is a non_colinear or spin-orbit calculation we will get the spin moments + if not self.spin.is_diagonal: + def _spin_moment_getter(eigenstate, plot, spin): + return eigenstate.spin_moment().real + + extra_vars = ({ + "coords": ("band", "axis"), "coords_values": dict(axis=["x", "y", "z"]), + "name": "spin_moments", "getter": _spin_moment_getter}, + *extra_vars) def bands_wrapper(eigenstate, spin_index): - if callable(eig_map): - eig_map(eigenstate, self, spin_index) - if hasattr(self, "spin_moments"): - self.spin_moments.append(eigenstate.spin_moment()) - return eigenstate.eig + returns = [] + for extra_var in extra_vars: + returns.append(extra_var["getter"](eigenstate, self, spin_index)) + return (eigenstate.eig, *returns) # Define the available spins spin_indices = [0] @@ -366,38 +368,36 @@ def bands_wrapper(eigenstate, spin_index): # Get the eigenstates for all the available spin components bands_arrays = [] + name = ["E"] + coords = [('band'), ] + coords_values = {"spin": spin_indices, "k": band_structure.lineark()} + for extra_var in extra_vars: + name.append(extra_var["name"]) + coords.append(extra_var["coords"]) + coords_values.update(extra_var.get("coords_values", {})) + for spin_index in spin_indices: # Non collinear routines don't accept the keyword argument "spin" spin_kwarg = {"spin": spin_index} - if self.spin.is_noncolinear: + if not self.spin.is_diagonal: spin_kwarg = {} - - spin_bands = band_structure.apply.dataarray.eigenstate( - wrap=partial(bands_wrapper, spin_index=spin_index), - **spin_kwarg, - coords=('band',), - ) + + with band_structure.apply(pool=_do_parallel_calc, unzip=True) as parallel: + spin_bands = parallel.dataarray.eigenstate( + wrap=partial(bands_wrapper, spin_index=spin_index), + **spin_kwarg, + coords=coords, name=name, + ) bands_arrays.append(spin_bands) - # Merge everything into a single dataarray with a spin dimension - self.bands = xr.concat(bands_arrays, "spin").assign_coords({"spin": spin_indices}).transpose("k", "spin", "band") + # Merge everything into a single dataset with a spin dimension + self.bands_data = xr.concat(bands_arrays, "spin").assign_coords(coords_values) self.bands['k'] = band_structure.lineark() # Inform of where to place the ticks - self.bands.attrs = {"ticks": self.ticks[0], "ticklabels": self.ticks[1], **bands_arrays[0].attrs} - - if hasattr(self, "spin_moments"): - self.spin_moments = xr.DataArray( - self.spin_moments, - coords={ - "k": self.bands.k, - "band": self.bands.band, - "axis": ["x", "y", "z"] - }, - dims=("k", "band", "axis") - ) + self.bands_data.attrs = {"ticks": self.ticks[0], "ticklabels": self.ticks[1], **bands_arrays[0].attrs} @entry_point('bands file') def _read_siesta_output(self, bands_file, band_structure): @@ -407,13 +407,22 @@ def _read_siesta_output(self, bands_file, band_structure): if band_structure: raise ValueError("A path was provided, therefore we can not use the .bands file even if there is one") - self.bands = self.get_sile(bands_file or "bands_file").read_data(as_dataarray=True) + self.bands_data = self.get_sile(bands_file or "bands_file").read_data(as_dataarray=True) # Define the spin class of the results we have retrieved - if len(self.bands.spin.values) == 2: + if len(self.bands_data.spin.values) == 2: self.spin = sisl.Spin("p") def _after_read(self): + if isinstance(self.bands_data, xr.DataArray): + attrs = self.bands_data.attrs + self.bands_data = xr.Dataset({"E": self.bands_data}) + self.bands_data.attrs = attrs + + # If the calculation is not spin polarized it makes no sense to + # retain a spin index + if "spin" in self.bands_data and not self.spin.is_polarized: + self.bands_data = self.bands_data.sel(spin=self.bands_data.spin[0], drop=True) # Inform the spin input of what spin class are we handling self.get_param("spin").update_options(self.spin) @@ -461,15 +470,17 @@ def _set_data(self, Erange, E0, bands_range, spin, spin_texture_colorscale, band self.update_settings(run_updates=False, bands_range=[int(filtered_bands['band'].min()), int(filtered_bands['band'].max())], no_log=True) # Give the filtered bands the same attributes as the full bands - filtered_bands.attrs = self.bands.attrs + filtered_bands.attrs = self.bands_data.attrs # Let's treat the spin if the user requested it self.spin_texture = False if spin is not None and len(spin) > 0: if isinstance(spin[0], int): - filtered_bands = filtered_bands.sel(spin=spin) + # Only use the spin setting if there is a spin index + if "spin" in filtered_bands.coords: + filtered_bands = filtered_bands.sel(spin=spin) elif isinstance(spin[0], str): - if not hasattr(self, "spin_moments"): + if "spin_moments" not in self.bands_data: raise ValueError(f"You requested spin texture ({spin[0]}), but spin moments have not been calculated. The spin class is {self.spin.kind}") self.spin_texture = True @@ -559,7 +570,9 @@ def clear_equivalent(ks): if requested_spin is None: requested_spin = [0, 1] - for spin in self.bands.spin: + avail_spins = self.bands_data.get("spin", [0]) + + for spin in avail_spins: if spin in requested_spin: from_k = custom_gap["from"] to_k = custom_gap["to"] @@ -592,9 +605,9 @@ def _sanitize_k(self, k): try: san_k = float(k) except ValueError: - if k in self.bands.attrs["ticklabels"]: - i_tick = self.bands.attrs["ticklabels"].index(k) - san_k = self.bands.attrs["ticks"][i_tick] + if k in self.bands_data.attrs["ticklabels"]: + i_tick = self.bands_data.attrs["ticklabels"].index(k) + san_k = self.bands_data.attrs["ticks"][i_tick] else: pass # raise ValueError(f"We can not interpret {k} as a k-location in the current bands plot") @@ -638,7 +651,8 @@ def _get_gap_coords(self, from_k, to_k=None, gap_spin=0, **kwargs): ks[i] = self._sanitize_k(val) VB, CB = self.gap_info["bands"] - Es = [self.bands.sel(k=k, band=band, spin=gap_spin, method="nearest") for k, band in zip(ks, (VB, CB))] + spin_bands = self.bands.sel(spin=gap_spin) if "spin" in self.bands.coords else self.bands + Es = [spin_bands.sel(k=k, band=band, method="nearest") for k, band in zip(ks, (VB, CB))] # Get the real values of ks that have been obtained # because we might not have exactly the ks requested ks = [np.ravel(E.k)[0] for E in Es] @@ -761,7 +775,11 @@ def effective_mass(self, band, k, k_direction, band_spin=0, n_points=10): from sisl.unit.base import units # Get the band that we want to fit - band_vals = self.bands.sel(band=band, spin=band_spin) + bands = self.bands + if "spin" in bands.coords: + band_vals = bands.sel(band=band, spin=band_spin) + else: + band_vals = bands.sel(band=band) # Sanitize k to a float k = self._sanitize_k(k) @@ -780,7 +798,7 @@ def effective_mass(self, band, k, k_direction, band_spin=0, n_points=10): # Grab the slice of the band that we are going to fit sel_band = band_vals[sel_slice] * units("eV", "Hartree") - sel_k = self.bands.k[sel_slice] - k + sel_k = bands.k[sel_slice] - k # Fit the band to a second order polynomial polyfit = np.polynomial.Polynomial.fit(sel_k, sel_band, 2) diff --git a/sisl/viz/plots/fatbands.py b/sisl/viz/plots/fatbands.py index de2a9e82ab..8b5137a456 100644 --- a/sisl/viz/plots/fatbands.py +++ b/sisl/viz/plots/fatbands.py @@ -2,7 +2,7 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at https://mozilla.org/MPL/2.0/. import numpy as np -from xarray import DataArray +from xarray import DataArray, Dataset import sisl from ..plot import entry_point @@ -186,6 +186,10 @@ class FatbandsPlot(BandsPlot): ) + @property + def weights(self): + return self.bands_data["weight"] + @entry_point("siesta output") def _read_siesta_output(self, wfsx_file, bands_file, root_fdf): """Generates fatbands from SIESTA output. @@ -235,7 +239,7 @@ def _read_siesta_output(self, wfsx_file, bands_file, root_fdf): weights = np.array(weights).real # Finally, build the weights dataarray so that it can be used by _set_data - self.weights = DataArray( + weights = DataArray( weights, coords={ "k": self.bands.k, @@ -247,7 +251,12 @@ def _read_siesta_output(self, wfsx_file, bands_file, root_fdf): # Add the spin dimension so that the weights array is normalized, # even though spin is not yet supported by this entrypoint - self.weights = self.weights.expand_dims("spin") + weights = weights.expand_dims("spin") + + # Merge everything into a dataset + attrs = self.bands_data.attrs + self.bands_data = Dataset({"E": self.bands_data, "weight": weights}) + self.bands_data.attrs = attrs # Set up the options for the 'groups' setting based on the plot's associated geometry self._set_group_options() @@ -257,9 +266,6 @@ def _read_from_H(self): """ Calculates the fatbands from a sisl hamiltonian. """ - - self.weights = [[], []] - # Define the function that will "catch" each eigenstate and # build the weights array. See BandsPlot._read_from_H to understand where # this will go exactly @@ -267,19 +273,19 @@ def _weights_from_eigenstate(eigenstate, plot, spin_index): weights = eigenstate.norm2(sum=False) - if plot.spin.has_noncolinear: + if not plot.spin.is_diagonal: # If it is a non-colinear or spin orbit calculation, we have two weights for each # orbital (one for each spin component of the state), so we just pair them together # and sum their contributions to get the weight of the orbital. weights = weights.reshape(len(weights), -1, 2).sum(2) - plot.weights[spin_index].append(weights) + return weights.real # We make bands plot read the bands, which will also populate the weights # thanks to the above step bands_read = False; err = None try: - super()._read_from_H(eigenstate_map=_weights_from_eigenstate) + super()._read_from_H(extra_vars=[{"coords": ("band", "orb"), "name": "weight", "getter": _weights_from_eigenstate}]) bands_read = True except Exception as e: # Let's keep this error, we are going to at least set the group options so that the @@ -290,24 +296,6 @@ def _weights_from_eigenstate(eigenstate, plot, spin_index): if not bands_read: raise err - # If there was only one spin component then we just take the first item in self.weights - if not self.weights[1]: - self.weights = [self.weights[0]] - - # Then we just convert the weights to a DataArray - self.weights = np.array(self.weights).real - - self.weights = DataArray( - self.weights, - coords={ - "k": self.bands.k, - "spin": np.arange(self.weights.shape[0]), - "band": np.arange(self.weights.shape[2]), - "orb": np.arange(self.weights.shape[3]), - }, - dims=("spin", "k", "band", "orb") - ) - def _set_group_options(self): # Try to find a geometry if there isn't already one @@ -398,6 +386,8 @@ def _get_group_weights(self, group, weights=None, values_storage=None, metadata_ if weights is None: weights = self.weights + if "spin" not in weights.coords: + weights = weights.expand_dims("spin") groups_param = self.get_param("groups") diff --git a/sisl/viz/plots/pdos.py b/sisl/viz/plots/pdos.py index 0815f24964..2c184cc117 100644 --- a/sisl/viz/plots/pdos.py +++ b/sisl/viz/plots/pdos.py @@ -3,19 +3,25 @@ # file, You can obtain one at https://mozilla.org/MPL/2.0/. from sisl.viz.input_fields.sisl_obj import DistributionInput import numpy as np +import os import sisl from sisl.messages import warn -from sisl.physics import distribution as sisl_distribution from ..plot import Plot, entry_point from ..plotutils import find_files, random_color from ..input_fields import ( - TextInput, SileInput, SwitchInput, ColorPicker, DropdownInput, CreatableDropdown, - IntegerInput, FloatInput, RangeInput, RangeSlider, OrbitalQueries, - ProgramaticInput, Array1DInput, ListInput, GeometryInput + TextInput, SileInput, SwitchInput, ColorPicker, DropdownInput, + IntegerInput, FloatInput, OrbitalQueries, + Array1DInput, GeometryInput ) from ..input_fields.range import ErangeInput +try: + import pathos + _do_parallel_calc = True +except: + _do_parallel_calc = False + class PdosPlot(Plot): """ @@ -339,15 +345,17 @@ def _read_from_H(self, kgrid, kgrid_displ, Erange, nE, E0, distribution): # Calculate the PDOS for all available spins PDOS = [] for spin in spin_indices: - spin_PDOS = self.mp.apply.average.eigenstate( - spin=spin, - wrap=lambda eig: eig.PDOS(self.E, distribution=distribution) + with self.mp.apply(pool=_do_parallel_calc) as parallel: + spin_PDOS = parallel.average.eigenstate( + spin=spin, + wrap=lambda eig: eig.PDOS(self.E, distribution=distribution) ) PDOS.append(spin_PDOS) - if self.H.spin.is_noncolinear or self.H.spin.is_spinorbit: + if not self.H.spin.is_diagonal: PDOS = PDOS[0] + self.PDOS = np.array(PDOS) def _after_read(self, geometry): @@ -366,11 +374,6 @@ def _after_read(self, geometry): }[self.PDOS.shape[0]] self.spin = sisl.Spin(self.spin) - # Normalize the PDOS array so that we ensure a first dimension for spin even if - # there is no spin resolution - if self.PDOS.ndim == 2: - self.PDOS = np.expand_dims(self.PDOS, axis=0) - # Set the geometry. if geometry is not None: if geometry.no != self.PDOS.shape[1]: @@ -379,15 +382,18 @@ def _after_read(self, geometry): self.get_param('requests').update_options(self.geometry, self.spin) - self.PDOS = DataArray( - self.PDOS, - coords={ - 'spin': self.get_param('requests').get_options("spin"), - 'orb': range(self.PDOS.shape[1]), - 'E': self.E - }, - dims=('spin', 'orb', 'E') - ) + # If there's one dimension for spin but the calculation is spin unpolarized, + # remove the spurious spin dimension + if self.spin.is_unpolarized and self.PDOS.ndim == 3: + self.PDOS = self.PDOS[0] + + coords = {'E': self.E} + dims = ('orb', 'E') + if not self.spin.is_unpolarized: + coords['spin'] = self.get_param('requests').get_options("spin") + dims = ('spin', 'orb', 'E') + + self.PDOS = DataArray(self.PDOS, coords=coords, dims=dims) def _set_data(self, requests, E0, Erange): @@ -485,13 +491,15 @@ def query_gen(i=[-1], **kwargs): return req_PDOS = E_PDOS.sel(orb=orb) - if request['spin'] is not None: + if request['spin'] is not None and 'spin' in req_PDOS.dims: req_PDOS = req_PDOS.sel(spin=request['spin']) + reduce_coords = set(["orb", "spin"]).intersection(req_PDOS.dims) + if request["normalize"]: - req_PDOS = req_PDOS.mean(["orb", "spin"]) + req_PDOS = req_PDOS.mean(reduce_coords) else: - req_PDOS = req_PDOS.sum(["orb", "spin"]) + req_PDOS = req_PDOS.sum(reduce_coords) # Finally, multiply the values by the scale factor values = req_PDOS.values * request["scale"] @@ -534,9 +542,9 @@ def _new_request(self, **kwargs): complete_req = self.get_param("requests").complete_query - if "spin" not in kwargs and self.spin.is_noncolinear: + if "spin" not in kwargs and not self.spin.is_diagonal: if "spin" not in kwargs.get("split_on", ""): - kwargs["spin"] = ["sum"] + kwargs["spin"] = ["total"] return complete_req({"name": str(len(self.settings["requests"])), **kwargs}) diff --git a/sisl/viz/plots/tests/test_bands.py b/sisl/viz/plots/tests/test_bands.py index adc94fe583..78ec919335 100644 --- a/sisl/viz/plots/tests/test_bands.py +++ b/sisl/viz/plots/tests/test_bands.py @@ -28,7 +28,8 @@ class TestBandsPlot(_TestPlot): "gap", # Float. The value of the gap in eV "ticklabels", # Array-like with the tick labels "tickvals", # Array-like with the expected positions of the ticks - "spin_texture" # Whether spin texture should be possible to draw or not. + "spin_texture", # Whether spin texture should be possible to draw or not. + "spin", # The spin class of the calculation ] @pytest.fixture(params=BandsPlot.get_class_param("backend").options) @@ -47,11 +48,12 @@ def init_func_and_attrs(self, request, siesta_test_files): # From a siesta .bands file init_func = sisl.get_sile(siesta_test_files("SrTiO3.bands")).plot attrs = { - "bands_shape": (150, 1, 72), + "bands_shape": (150, 72), "ticklabels": ('Gamma', 'X', 'M', 'Gamma', 'R', 'X'), "tickvals": [0.0, 0.429132, 0.858265, 1.465149, 2.208428, 2.815313], "gap": 1.677, - "spin_texture": False + "spin_texture": False, + "spin": sisl.Spin("") } elif name.startswith("sisl_H"): gr = sisl.geom.graphene() @@ -60,14 +62,14 @@ def init_func_and_attrs(self, request, siesta_test_files): spin_type = name.split("_")[-1] n_spin, H = { - "unpolarized": (1, H), + "unpolarized": (0, H), "polarized": (2, H.transform(spin=sisl.Spin.POLARIZED)), - "noncolinear": (1, H.transform(spin=sisl.Spin.NONCOLINEAR)), - "spinorbit": (1, H.transform(spin=sisl.Spin.SPINORBIT)) + "noncolinear": (0, H.transform(spin=sisl.Spin.NONCOLINEAR)), + "spinorbit": (0, H.transform(spin=sisl.Spin.SPINORBIT)) }.get(spin_type) n_states = 2 - if H.spin.is_spinorbit or H.spin.is_noncolinear: + if not H.spin.is_diagonal: n_states *= 2 # Let's create the same graphene bands plot using the hamiltonian @@ -84,11 +86,12 @@ def init_func_and_attrs(self, request, siesta_test_files): init_func = bz.plot attrs = { - "bands_shape": (6, n_spin, n_states), + "bands_shape": (6, n_spin, n_states) if n_spin != 0 else (6, n_states), "ticklabels": ["Gamma", "M", "K"], "tickvals": [0., 1.70309799, 2.55464699], "gap": 0, - "spin_texture": H.spin.is_spinorbit or H.spin.is_noncolinear + "spin_texture": not H.spin.is_diagonal, + "spin": H.spin } return init_func, attrs @@ -104,8 +107,14 @@ def test_bands_dataarray(self, plot, test_attrs): # Check that it is a dataarray containing the right information bands = plot.bands assert isinstance(bands, DataArray) - assert bands.dims == ('k', 'spin', 'band') - assert bands.shape == test_attrs['bands_shape'] + + if test_attrs["spin"].is_polarized: + expected_coords = ('k', 'spin', 'band') + else: + expected_coords = ('k', 'band') + + assert set(bands.dims) == set(expected_coords) + assert bands.transpose(*expected_coords).shape == test_attrs['bands_shape'] def test_bands_in_figure(self, plot, test_attrs): @@ -178,7 +187,7 @@ def test_spin_moments(self, plot, test_attrs): # Check that it is a dataarray containing the right information spin_moments = plot.spin_moments assert isinstance(spin_moments, DataArray) - assert spin_moments.dims == ('k', 'band', 'axis') + assert set(spin_moments.dims) == set(('k', 'band', 'axis')) assert spin_moments.shape == (test_attrs['bands_shape'][0], test_attrs['bands_shape'][-1], 3) def test_spin_texture(self, plot, test_attrs): diff --git a/sisl/viz/plots/tests/test_fatbands.py b/sisl/viz/plots/tests/test_fatbands.py index 17d0df62e6..22a5c76004 100644 --- a/sisl/viz/plots/tests/test_fatbands.py +++ b/sisl/viz/plots/tests/test_fatbands.py @@ -56,13 +56,13 @@ def init_func_and_attrs(self, request, siesta_test_files): init_func = bz.plot.fatbands attrs = { - "bands_shape": (6, n_spin, n_states), - "weights_shape": (n_spin, 6, n_states, 2), + "bands_shape": (6, n_spin, n_states) if H.spin.is_polarized else (6, n_states), + "weights_shape": (n_spin, 6, n_states, 2) if H.spin.is_polarized else (6, n_states, 2), "ticklabels": ["Gamma", "M", "K"], "tickvals": [0., 1.70309799, 2.55464699], "gap": 0, - "spin_texture": H.spin.is_spinorbit or H.spin.is_noncolinear, - "soc_or_nc": H.spin.is_spinorbit or H.spin.is_noncolinear, + "spin_texture": not H.spin.is_diagonal, + "spin": H.spin } return init_func, attrs @@ -78,12 +78,24 @@ def test_weights_dataarray_avail(self, plot, test_attrs): # Check that it is a dataarray containing the right information weights = plot.weights assert isinstance(weights, DataArray) - assert weights.dims == ("spin", "k", "band", "orb") + + if test_attrs["spin"].is_polarized: + expected_dims = ("spin", "k", "band", "orb") + else: + expected_dims = ("k", "band", "orb") + assert weights.dims == expected_dims assert weights.shape == test_attrs["weights_shape"] + + def test_group_weights(self, plot): + + total_weights = plot._get_group_weights({}) + + assert isinstance(total_weights, DataArray) + assert set(total_weights.dims) == set(("spin", "band", "k")) def test_weights_values(self, plot, test_attrs): assert np.allclose(plot.weights.sum("orb"), 1), "Weight values do not sum 1 for all states." - assert np.allclose(plot.weights.sum("band"), 2 if test_attrs["soc_or_nc"] else 1) + assert np.allclose(plot.weights.sum("band"), 2 if not test_attrs["spin"].is_diagonal else 1) def test_groups(self, plot): """ diff --git a/sisl/viz/plots/tests/test_pdos.py b/sisl/viz/plots/tests/test_pdos.py index 24f6e36e7d..ec546651fb 100644 --- a/sisl/viz/plots/tests/test_pdos.py +++ b/sisl/viz/plots/tests/test_pdos.py @@ -91,6 +91,12 @@ def test_dataarray(self, plot, test_attrs): # Check if we have the correct number of orbitals assert len(PDOS.orb) == test_attrs["no"] == geom.no + + def test_request_PDOS(self, plot): + total_DOS = plot._get_request_PDOS({}) + + assert total_DOS.ndim == 1 + assert total_DOS.shape == (plot.PDOS.E.shape) def test_splitDOS(self, plot, test_attrs, inplace_split): if inplace_split: