Skip to content

Commit

Permalink
adding a logic for raman spectrum plot in model and widget
Browse files Browse the repository at this point in the history
  • Loading branch information
AndresOrtegaGuerrero committed Dec 1, 2024
1 parent 3321f58 commit eaa337c
Show file tree
Hide file tree
Showing 5 changed files with 498 additions and 312 deletions.
6 changes: 4 additions & 2 deletions src/aiidalab_qe_vibroscopy/app/result/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import traitlets as tl


from aiidalab_qe_vibroscopy.utils.raman.result import export_iramanworkchain_data
from aiidalab_qe_vibroscopy.utils.phonons.result import export_phononworkchain_data
from aiidalab_qe_vibroscopy.utils.euphonic import export_euphonic_data

Expand All @@ -24,7 +23,10 @@ def needs_dielectric_tab(self):
return True

def needs_raman_tab(self):
return export_iramanworkchain_data(self.get_vibro_node())
node = self.get_vibro_node()
if not any(key in node for key in ["iraman", "harmonic"]):
return False
return True

# Here we use _fetch_child_process_node() since the function needs the input_structure in inputs
def needs_phonons_tab(self):
Expand Down
13 changes: 11 additions & 2 deletions src/aiidalab_qe_vibroscopy/app/result/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from aiidalab_qe_vibroscopy.app.widgets.dielectricmodel import DielectricModel
import ipywidgets as ipw

from aiidalab_qe_vibroscopy.app.widgets.ramanwidget import RamanWidget
from aiidalab_qe_vibroscopy.app.widgets.ramanmodel import RamanModel


class VibroResultsPanel(ResultsPanel[VibroResultsModel]):
title = "Vibronic"
Expand All @@ -32,8 +35,14 @@ def render(self):
if self._model.needs_phonons_tab():
tab_data.append(("Phonons", ipw.HTML("phonon_data")))

if self._model.needs_raman_tab():
tab_data.append(("Raman", ipw.HTML("raman_data")))
needs_raman_tab = self._model.needs_raman_tab()
if needs_raman_tab:
raman_model = RamanModel()
raman_widget = RamanWidget(
model=raman_model,
node=vibro_node,
)
tab_data.append(("Raman", raman_widget))

needs_dielectri_tab = self._model.needs_dielectric_tab()

Expand Down
307 changes: 307 additions & 0 deletions src/aiidalab_qe_vibroscopy/app/widgets/ramanmodel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,307 @@
from __future__ import annotations
from aiidalab_qe.common.mvc import Model
import traitlets as tl
from aiida.common.extendeddicts import AttributeDict
from IPython.display import display
import numpy as np
from aiida_vibroscopy.utils.broadenings import multilorentz
import plotly.graph_objects as go
import base64
import json


class RamanModel(Model):
vibro = tl.Instance(AttributeDict, allow_none=True)

raman_plot_type_options = tl.List(
trait=tl.List(tl.Unicode()),
default_value=[
("Powder", "powder"),
("Single Crystal", "single_crystal"),
],
)
raman_plot_type = tl.Unicode("powder")
raman_temperature = tl.Float(300)
raman_frequency_laser = tl.Float(532)
raman_pol_incoming = tl.Unicode("0 0 1")
raman_pol_outgoing = tl.Unicode("0 0 1")
raman_broadening = tl.Float(10.0)
raman_separate_polarizations = tl.Bool(False)

frequencies = []
intensities = []

frequencies_depolarized = []
intensities_depolarized = []

def fetch_data(self):
"""Fetch the Raman data from the VibroWorkChain"""
self.raman_data = self.get_vibrational_data(self.vibro)

def update_data(self):
"""
Update the Raman plot data based on the selected plot type and configuration.
"""
if self.raman_plot_type == "powder":
self._update_powder_data()
else:
self._update_single_crystal_data()

def _update_powder_data(self):
"""
Update data for the powder Raman plot.
"""
(
polarized_intensities,
depolarized_intensities,
frequencies,
_,
) = self.raman_data.run_powder_raman_intensities(
frequencies=self.raman_frequency_laser,
temperature=self.raman_temperature,
)

if self.raman_separate_polarizations:
self.frequencies, self.intensities = self.generate_plot_data(
frequencies,
polarized_intensities,
self.raman_broadening,
)
self.frequencies_depolarized, self.intensities_depolarized = (
self.generate_plot_data(
frequencies,
depolarized_intensities,
self.raman_broadening,
)
)
else:
combined_intensities = polarized_intensities + depolarized_intensities
self.frequencies, self.intensities = self.generate_plot_data(
frequencies,
combined_intensities,
self.raman_broadening,
)
self.frequencies_depolarized, self.intensities_depolarized = [], []

def _update_single_crystal_data(self):
"""
Update data for the single crystal Raman plot.
"""
dir_incoming, _ = self._check_inputs_correct(self.raman_pol_incoming)
dir_outgoing, _ = self._check_inputs_correct(self.raman_pol_outgoing)

(
intensities,
frequencies,
labels,
) = self.raman_data.run_single_crystal_raman_intensities(
pol_incoming=dir_incoming,
pol_outgoing=dir_outgoing,
frequencies=self.raman_frequency_laser,
temperature=self.raman_temperature,
)
self.frequencies, self.intensities = self.generate_plot_data(
frequencies, intensities
)
self.frequencies_depolarized, self.intensities_depolarized = [], []

def update_plot(self, plot):
"""
Update the Raman plot based on the selected plot type and configuration.
Parameters:
plot: The plotly.graph_objs.Figure widget to update.
"""
update_function = (
self._update_powder_plot
if self.raman_plot_type == "powder"
else self._update_single_crystal_plot
)
update_function(plot)

def _update_powder_plot(self, plot):
"""
Update the powder Raman plot.
Parameters:
plot: The plotly.graph_objs.Figure widget to update.
"""
if self.raman_separate_polarizations:
self._update_polarized_and_depolarized(plot)
else:
self._clear_depolarized_and_update(plot)

def _update_polarized_and_depolarized(self, plot):
"""
Update the plot when polarized and depolarized data are separate.
Parameters:
plot: The plotly.graph_objs.Figure widget to update.
"""
if len(plot.data) == 1:
self._update_trace(
plot.data[0], self.frequencies, self.intensities, "Polarized"
)
plot.add_trace(
go.Scatter(
x=self.frequencies_depolarized,
y=self.intensities_depolarized,
name="Depolarized",
)
)
plot.layout.title.text = "Powder Raman Spectrum"
elif len(plot.data) == 2:
self._update_trace(
plot.data[0], self.frequencies, self.intensities, "Polarized"
)
self._update_trace(
plot.data[1],
self.frequencies_depolarized,
self.intensities_depolarized,
"Depolarized",
)
plot.layout.title.text = "Powder Raman Spectrum"

def _clear_depolarized_and_update(self, plot):
"""
Clear depolarized data and update the plot.
Parameters:
plot: The plotly.graph_objs.Figure widget to update.
"""
if len(plot.data) == 2:
self._update_trace(plot.data[0], self.frequencies, self.intensities, "")
plot.data[1].x = []
plot.data[1].y = []
plot.layout.title.text = "Powder Raman Spectrum"
elif len(plot.data) == 1:
self._update_trace(plot.data[0], self.frequencies, self.intensities, "")
plot.layout.title.text = "Powder Raman Spectrum"

def _update_single_crystal_plot(self, plot):
"""
Update the single crystal Raman plot.
Parameters:
plot: The plotly.graph_objs.Figure widget to update.
"""
if len(plot.data) == 2:
self._update_trace(plot.data[0], self.frequencies, self.intensities, "")
plot.data[1].x = []
plot.data[1].y = []
plot.layout.title.text = "Single Crystal Raman Spectrum"
elif len(plot.data) == 1:
self._update_trace(plot.data[0], self.frequencies, self.intensities, "")
plot.layout.title.text = "Single Crystal Raman Spectrum"

def _update_trace(self, trace, x_data, y_data, name):
"""
Helper function to update a single trace in the plot.
Parameters:
trace: The trace to update.
x_data: The new x-axis data.
y_data: The new y-axis data.
name: The name of the trace.
"""
trace.x = x_data
trace.y = y_data
trace.name = name

def get_vibrational_data(self, node):
"""
Extract vibrational data from an IRamanWorkChain or HarmonicWorkChain node.
Parameters:
node: The workchain node containing IRaman or Harmonic data.
Returns:
The vibrational accuracy data (vibro) or None if not available.
"""
# Determine the output node
output_node = getattr(node, "iraman", None) or getattr(node, "harmonic", None)
if not output_node:
return None

# Check for vibrational data and extract accuracy
vibrational_data = getattr(output_node, "vibrational_data", None)
if not vibrational_data:
return None

# Extract vibrational accuracy (prefer numerical_accuracy_4 if available)
vibro = getattr(vibrational_data, "numerical_accuracy_4", None) or getattr(
vibrational_data, "numerical_accuracy_2", None
)

return vibro

def _check_inputs_correct(self, polarization):
# Check if the polarization vectors are correct
input_text = polarization
input_values = input_text.split()
dir_values = []
if len(input_values) == 3:
try:
dir_values = [float(i) for i in input_values]
return dir_values, True
except: # noqa: E722
return dir_values, False
else:
return dir_values, False

def generate_plot_data(
self,
frequencies: list[float],
intensities: list[float],
broadening: float = 10.0,
x_range: list[float] | str = "auto",
broadening_function=multilorentz,
normalize: bool = True,
):
frequencies = np.array(frequencies)
intensities = np.array(intensities)

if x_range == "auto":
xi = max(0, frequencies.min() - 200)
xf = frequencies.max() + 200
x_range = np.arange(xi, xf, 1.0)

y_range = broadening_function(x_range, frequencies, intensities, broadening)

if normalize:
y_range /= y_range.max()

return x_range, y_range

def download_data(self, _=None):
filename = "spectra.json"
if self.raman_separate_polarizations:
my_dict = {
"Frequencies cm-1": self.frequencies.tolist(),
"Polarized intensities": self.intensities.tolist(),
"Depolarized intensities": self.intensities_depolarized.tolist(),
}
else:
my_dict = {
"Frequencies cm-1": self.frequencies.tolist(),
"Intensities": self.intensities.tolist(),
}
json_str = json.dumps(my_dict)
b64_str = base64.b64encode(json_str.encode()).decode()
self._download(payload=b64_str, filename=filename)

@staticmethod
def _download(payload, filename):
from IPython.display import Javascript

javas = Javascript(
"""
var link = document.createElement('a');
link.href = 'data:text/json;charset=utf-8;base64,{payload}'
link.download = "{filename}"
document.body.appendChild(link);
link.click();
document.body.removeChild(link);
""".format(payload=payload, filename=filename)
)
display(javas)
Loading

0 comments on commit eaa337c

Please sign in to comment.