Skip to content

Commit

Permalink
Add ability to plot only top contributing elements to SDEC plot (#1506)
Browse files Browse the repository at this point in the history
* Added nelements functionality

Added ability to specify the number of elements included in the plot. The elements are sorted by their total contribution to the absorption and emission. The top nelements are then given unique colours in the plot. All other elements are shown in silver.

* Changed other element grouping

Changed grouping for other elements. Before it would run through each 'other' element and then all the plot elements. Now all of the 'other' elements are grouped into a single column that is plotted first. Then it goes onto the plot elements. This significantly simplifies the plotting

* Fixed black formatting

* Added comments

* Fixed typos

* Black formatting

* Updated comments and total_luminosity_df

Changed comments based on reviews.

Also changed how the total luminosity df works. Instead of adding 'noint' and 'escatter' columns to the absorption df, then adding to the emission df to get the total, it now drops those columns from the emission df when doing the addition.

* Black formatting

* Update comments

Co-authored-by: Jaladh Singhal <[email protected]>

* Fixed typo

Co-authored-by: Jaladh Singhal <[email protected]>
  • Loading branch information
MarkMageeAstro and jaladh-singhal authored Mar 26, 2021
1 parent 3534d39 commit 8239d29
Showing 1 changed file with 166 additions and 22 deletions.
188 changes: 166 additions & 22 deletions tardis/visualization/tools/sdec_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ def from_hdf(cls, hdf_fpath):
)

def _calculate_plotting_data(
self, packets_mode, packet_wvl_range, distance
self, packets_mode, packet_wvl_range, distance, nelements
):
"""
Calculate data to be used in plotting based on parameters passed.
Expand All @@ -442,6 +442,11 @@ def _calculate_plotting_data(
distance : astropy.Quantity
Distance used to calculate flux instead of luminosity in the plot.
It should have a length unit like m, Mpc, etc.
nelements: int
Number of elements to include in plot. Determined by the
largest contribution to the total luminosity absorbed and emitted.
Other elements are shown in silver. Default value is
None, which displays all elements
Notes
-----
Expand Down Expand Up @@ -514,15 +519,78 @@ def _calculate_plotting_data(
# Calculate luminosities to be shown in plot
(
self.emission_luminosities_df,
self.elements,
self.emission_elements,
) = self._calculate_emission_luminosities(
packets_mode=packets_mode, packet_wvl_range=packet_wvl_range
)
self.absorption_luminosities_df = (
self._calculate_absorption_luminosities(
packets_mode=packets_mode, packet_wvl_range=packet_wvl_range
)
(
self.absorption_luminosities_df,
self.absorption_elements,
) = self._calculate_absorption_luminosities(
packets_mode=packets_mode, packet_wvl_range=packet_wvl_range
)

# Calculate the total contribution of elements
# by summing absorption and emission
# Only care about elements, so drop no interaction and electron scattering
# contributions from the emitted luminosities
self.total_luminosities_df = (
self.absorption_luminosities_df
+ self.emission_luminosities_df.drop(["noint", "escatter"], axis=1)
)

# Sort the element list based on the total contribution
sorted_list = self.total_luminosities_df.sum().sort_values(
ascending=False
)

# If nelements is not included, the list of elements is just all elements
if nelements is None:
self.elements = np.array(list(self.total_luminosities_df.keys()))
else:
# If nelements is included then create a new column which is the sum
# of all other elements, i.e. those that aren't in the top contributing nelements
self.total_luminosities_df.insert(
loc=0,
column="other",
value=self.total_luminosities_df[
sorted_list.keys()[nelements:]
].sum(axis=1),
)
# Then drop all of the individual columns for elements included in 'other'
self.total_luminosities_df.drop(
sorted_list.keys()[nelements:], inplace=True, axis=1
)
# If nelements is included then create a new column which is the sum
# of all other elements, i.e. those that aren't in the top contributing nelements
self.emission_luminosities_df.insert(
loc=2,
column="other",
value=self.emission_luminosities_df[
sorted_list.keys()[nelements:]
].sum(axis=1),
)
# Then drop all of the individual columns for elements included in 'other'
self.emission_luminosities_df.drop(
sorted_list.keys()[nelements:], inplace=True, axis=1
)
# If nelements is included then create a new column which is the sum
# of all other elements, i.e. those that aren't in the top contributing nelements
self.absorption_luminosities_df.insert(
loc=2,
column="other",
value=self.absorption_luminosities_df[
sorted_list.keys()[nelements:]
].sum(axis=1),
)
# Then drop all of the individual columns for elements included in 'other'
self.absorption_luminosities_df.drop(
sorted_list.keys()[nelements:], inplace=True, axis=1
)

# Index from 1: to avoid the 'other' column
self.elements = np.sort(self.total_luminosities_df.keys()[1:])

self.photosphere_luminosity = self._calculate_photosphere_luminosity(
packets_mode=packets_mode
)
Expand Down Expand Up @@ -683,9 +751,11 @@ def _calculate_emission_luminosities(self, packets_mode, packet_wvl_range):
luminosities_df[atomic_number] = L_lambda_el.value

# Create an array of the elements with which packets interacted
elements_present = np.array(list(packets_df_grouped.groups.keys()))
emission_elements_present = np.array(
list(packets_df_grouped.groups.keys())
)

return luminosities_df, elements_present
return luminosities_df, emission_elements_present

def _calculate_absorption_luminosities(
self, packets_mode, packet_wvl_range
Expand Down Expand Up @@ -759,7 +829,11 @@ def _calculate_absorption_luminosities(

luminosities_df[atomic_number] = L_lambda_el.value

return luminosities_df
absorption_elements_present = np.array(
list(packets_df_grouped.groups.keys())
)

return luminosities_df, absorption_elements_present

def _calculate_photosphere_luminosity(self, packets_mode):
"""
Expand Down Expand Up @@ -798,6 +872,7 @@ def generate_plot_mpl(
ax=None,
figsize=(12, 7),
cmapname="jet",
nelements=None,
):
"""
Generate Spectral element DEComposition (SDEC) Plot using matplotlib.
Expand Down Expand Up @@ -826,6 +901,11 @@ def generate_plot_mpl(
cmapname : str, optional
Name of matplotlib colormap to be used for showing elements.
Default value is "jet"
nelements: int
Number of elements to include in plot. Determined by the
largest contribution to total luminosity absorbed and emitted.
Other elements are shown in silver. Default value is
None, which displays all elements
Returns
-------
Expand All @@ -838,6 +918,7 @@ def generate_plot_mpl(
packets_mode=packets_mode,
packet_wvl_range=packet_wvl_range,
distance=distance,
nelements=nelements,
)

if ax is None:
Expand Down Expand Up @@ -917,9 +998,22 @@ def _plot_emission_mpl(self):
label="Electron Scatter Only",
)

elements_z = self.emission_luminosities_df.columns[2:].to_list()
nelements = len(elements_z)
# If the 'other' column exists then plot it as silver
if "other" in self.emission_luminosities_df.keys():
lower_level = upper_level
upper_level = (
lower_level + self.emission_luminosities_df.other.to_numpy()
)

self.ax.fill_between(
self.plot_wavelength,
lower_level,
upper_level,
color="silver",
label="Other elements",
)

elements_z = self.elements
# Contribution from each element
for i, atomic_number in enumerate(elements_z):
lower_level = upper_level
Expand All @@ -932,7 +1026,7 @@ def _plot_emission_mpl(self):
self.plot_wavelength,
lower_level,
upper_level,
color=self.cmap(i / nelements),
color=self.cmap(i / len(self.elements)),
cmap=self.cmap,
linewidth=0,
)
Expand All @@ -941,11 +1035,25 @@ def _plot_absorption_mpl(self):
"""Plot absorption part of the SDEC Plot using matplotlib."""
lower_level = np.zeros(self.absorption_luminosities_df.shape[0])

elements_z = self.absorption_luminosities_df.columns.to_list()
# To plot absorption part along -ve X-axis, we will start with
# zero upper level and keep subtracting luminosities to it (lower
# level) - fill from upper to lower level
# If the 'other' column exists then plot it as silver
if "other" in self.absorption_luminosities_df.keys():
upper_level = lower_level
lower_level = (
upper_level - self.absorption_luminosities_df.other.to_numpy()
)

self.ax.fill_between(
self.plot_wavelength,
upper_level,
lower_level,
color="silver",
)

elements_z = self.elements
for i, atomic_number in enumerate(elements_z):
# To plot absorption part along -ve X-axis, we will start with
# zero upper level and keep subtracting luminosities to it (lower
# level) - fill from upper to lower level
upper_level = lower_level
lower_level = (
upper_level
Expand Down Expand Up @@ -991,6 +1099,7 @@ def generate_plot_ply(
fig=None,
graph_height=600,
cmapname="jet",
nelements=None,
):
"""
Generate interactive Spectral element DEComposition (SDEC) Plot using plotly.
Expand Down Expand Up @@ -1019,6 +1128,11 @@ def generate_plot_ply(
cmapname : str, optional
Name of the colormap to be used for showing elements.
Default value is "jet"
nelements: int
Number of elements to include in plot. Determined by the
largest contribution to total luminosity absorbed and emitted.
Other elements are shown in silver. Default value is
None, which displays all elements
Returns
-------
Expand All @@ -1031,6 +1145,7 @@ def generate_plot_ply(
packets_mode=packets_mode,
packet_wvl_range=packet_wvl_range,
distance=distance,
nelements=nelements,
)

if fig is None:
Expand Down Expand Up @@ -1138,26 +1253,53 @@ def _plot_emission_ply(self):
)
)

elements_z = self.emission_luminosities_df.columns[2:]
nelements = len(elements_z)
# If 'other' column exists then plot as silver
if "other" in self.emission_luminosities_df.keys():
self.fig.add_trace(
go.Scatter(
x=self.emission_luminosities_df.index,
y=self.emission_luminosities_df.other,
mode="none",
name="Other elements",
fillcolor="silver",
stackgroup="emission",
)
)

elements_z = self.elements
for i, atomic_num in enumerate(elements_z):
self.fig.add_trace(
go.Scatter(
x=self.emission_luminosities_df.index,
y=self.emission_luminosities_df[atomic_num],
mode="none",
name=atomic_number2element_symbol(atomic_num),
fillcolor=self.to_rgb255_string(self.cmap(i / nelements)),
fillcolor=self.to_rgb255_string(
self.cmap(i / len(self.elements))
),
stackgroup="emission",
showlegend=False,
)
)

def _plot_absorption_ply(self):
"""Plot absorption part of the SDEC Plot using plotly."""
elements_z = self.absorption_luminosities_df.columns
nelements = len(elements_z)

# If 'other' column exists then plot as silver
if "other" in self.absorption_luminosities_df.keys():
self.fig.add_trace(
go.Scatter(
x=self.absorption_luminosities_df.index,
y=self.absorption_luminosities_df.other * -1,
mode="none",
name="Other elements",
fillcolor="silver",
stackgroup="absorption",
showlegend=False,
)
)

elements_z = self.elements

for i, atomic_num in enumerate(elements_z):
self.fig.add_trace(
Expand All @@ -1167,7 +1309,9 @@ def _plot_absorption_ply(self):
y=self.absorption_luminosities_df[atomic_num] * -1,
mode="none",
name=atomic_number2element_symbol(atomic_num),
fillcolor=self.to_rgb255_string(self.cmap(i / nelements)),
fillcolor=self.to_rgb255_string(
self.cmap(i / len(self.elements))
),
stackgroup="absorption",
showlegend=False,
)
Expand Down

0 comments on commit 8239d29

Please sign in to comment.