Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ruff tardis/visualization #2846

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tardis/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from tardis.simulation import Simulation
from tardis.tests.fixtures.atom_data import *
from tardis.tests.fixtures.regression_data import regression_data
from tardis.tests.test_util import monkeysession

# ensuring that regression_data is not removed by ruff
assert regression_data is not None
Expand Down
15 changes: 7 additions & 8 deletions tardis/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
"""Visualization tools and widgets for TARDIS."""

from tardis.visualization.tools.convergence_plot import ConvergencePlots

from tardis.visualization.tools.liv_plot import LIVPlotter
from tardis.visualization.tools.rpacket_plot import RPacketPlotter
from tardis.visualization.tools.sdec_plot import SDECPlotter
from tardis.visualization.widgets.custom_abundance import CustomAbundanceWidget
from tardis.visualization.widgets.grotrian import GrotrianWidget
from tardis.visualization.widgets.line_info import LineInfoWidget
from tardis.visualization.widgets.shell_info import (
shell_info_from_simulation,
shell_info_from_hdf,
shell_info_from_simulation,
)
from tardis.visualization.widgets.line_info import LineInfoWidget
from tardis.visualization.widgets.grotrian import GrotrianWidget
from tardis.visualization.widgets.custom_abundance import CustomAbundanceWidget
from tardis.visualization.tools.sdec_plot import SDECPlotter
from tardis.visualization.tools.rpacket_plot import RPacketPlotter
from tardis.visualization.tools.liv_plot import LIVPlotter
1 change: 1 addition & 0 deletions tardis/visualization/plot_util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Utility functions to be used in plotting."""

import re

import numpy as np


Expand Down
13 changes: 6 additions & 7 deletions tardis/visualization/tools/convergence_plot.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
"""Convergence Plots to see the convergence of the simulation in real time."""

from collections import defaultdict
import matplotlib.cm as cm
import matplotlib.colors as clr
from contextlib import suppress

import ipywidgets as widgets
import matplotlib as mpl
import numpy as np
import plotly.graph_objects as go
from astropy import units as u
from IPython.display import display
import matplotlib as mpl
import ipywidgets as widgets
from contextlib import suppress
from traitlets import TraitError
from astropy import units as u


def transition_colors(length, name="jet"):
Expand All @@ -36,7 +35,7 @@ def transition_colors(length, name="jet"):
return colors


class ConvergencePlots(object):
class ConvergencePlots:
"""
Create and update convergence plots for visualizing convergence of the simulation.

Expand Down
10 changes: 4 additions & 6 deletions tardis/visualization/tools/liv_plot.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import logging

import astropy.units as u
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import plotly.graph_objects as go
import numpy as np
import pandas as pd
import astropy.units as u
import plotly.graph_objects as go

import tardis.visualization.tools.sdec_plot as sdec
from tardis.util.base import (
atomic_number2element_symbol,
int_to_roman,
)
import tardis.visualization.tools.sdec_plot as sdec
from tardis.visualization import plot_util as pu

logger = logging.getLogger(__name__)
Expand All @@ -37,7 +37,6 @@ def __init__(self, data, time_explosion, velocity):
velocity : astropy.units.Quantity
Velocity array from the simulation.
"""

self.data = data
self.time_explosion = time_explosion
self.velocity = velocity
Expand All @@ -57,7 +56,6 @@ def from_simulation(cls, sim):
-------
LIVPlotter
"""

return cls(
dict(
virtual=sdec.SDECData.from_simulation(sim, "virtual"),
Expand Down
12 changes: 4 additions & 8 deletions tardis/visualization/tools/rpacket_plot.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import math
import logging
import pandas as pd
import numpy as np
import math

import astropy.units as u
import plotly.express as px
import numpy as np
import plotly.graph_objects as go


Expand Down Expand Up @@ -107,7 +106,7 @@ def from_simulation(cls, sim, no_of_packets=15):
return cls(sim, sim.last_no_of_packets)
else:
raise AttributeError(
""" There is no attribute named rpacket_tracker in the simulation object passed. Try enabling the
""" There is no attribute named rpacket_tracker in the simulation object passed. Try enabling the
rpacket tracking in the configuration. To enable rpacket tracking see: https://tardis-sn.github.io/tardis/io/output/rpacket_tracking.html#How-to-Setup-the-Tracking-for-the-RPackets?"""
)

Expand All @@ -125,7 +124,6 @@ def generate_plot(self, theme="light"):
plotly.graph_objs._figure.Figure
plot containing the packets, photosphere and the shells.
"""

self.fig = go.Figure()

# getting velocity of different shells
Expand Down Expand Up @@ -491,7 +489,6 @@ def get_coordinates_multiple_packets(self, r_packet_tracker):
numpy.ndarray
array of array containing x coordinates, y coordinates and the interactions for multiple packets
"""

# for plotting packets at equal intervals throught the circle, we choose thetas distributed uniformly
thetas = np.linspace(0, 2 * math.pi, self.no_of_packets + 1)
rpackets_x = []
Expand Down Expand Up @@ -656,7 +653,6 @@ def get_slider_steps(self, rpacket_max_array_size):
list
list of dictionaries of different steps for different frames.
"""

slider_steps = []
for step_no in range(rpacket_max_array_size):
slider_steps.append(
Expand Down
16 changes: 8 additions & 8 deletions tardis/visualization/tools/tests/test_convergence_plot.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
"""Tests for Convergence Plots."""

from collections import defaultdict
from copy import deepcopy

import plotly.graph_objects as go
import pytest
from tardis.tests.test_util import monkeysession
from astropy import units as u

from tardis import run_tardis
from tardis.visualization.tools.convergence_plot import (
ConvergencePlots,
transition_colors,
)
from collections import defaultdict
import plotly.graph_objects as go
from astropy import units as u


@pytest.fixture(scope="module", params=[0, 1, 2])
Expand Down Expand Up @@ -84,7 +84,7 @@ def test_update_t_inner_luminosities_plot(convergence_plots):
# check number of traces
assert len(convergence_plots.t_inner_luminosities_plot.data) == 5

for index in range(0, 5):
for index in range(5):
# check x and y values for all traces
assert (
len(convergence_plots.t_inner_luminosities_plot.data[index].x)
Expand All @@ -108,7 +108,7 @@ def test_update_plasma_plots(convergence_plots):
"""Test the state of plasma plots after updating."""
n_iterations = convergence_plots.iterations
expected_n_traces = 2 * n_iterations + 2
velocity = range(0, n_iterations) * u.m / u.s
velocity = range(n_iterations) * u.m / u.s

convergence_plots.fetch_data(
name="velocity", value=velocity, item_type="iterable"
Expand Down Expand Up @@ -198,7 +198,7 @@ def test_override_plot_parameters(convergence_plots):
convergence_plots.t_inner_luminosities_plot["layout"]["xaxis2"][
"showgrid"
]
== False
is False
)

# testing plot parameters for plasma plot
Expand All @@ -212,7 +212,7 @@ def test_override_plot_parameters(convergence_plots):
)
# checking layout for plasma plot
assert (
convergence_plots.plasma_plot["layout"]["xaxis2"]["showgrid"] == False
convergence_plots.plasma_plot["layout"]["xaxis2"]["showgrid"] is False
)


Expand Down
4 changes: 2 additions & 2 deletions tardis/visualization/tools/tests/test_liv_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
import astropy.units as u
import numpy as np
import pytest
from matplotlib.testing.compare import compare_images
from matplotlib.collections import PolyCollection
from matplotlib.lines import Line2D
from matplotlib.testing.compare import compare_images

from tardis.base import run_tardis
from tardis.io.util import HDFWriterMixin
from tardis.visualization.tools.liv_plot import LIVPlotter
from tardis.tests.fixtures.regression_data import RegressionData
from tardis.visualization.tools.liv_plot import LIVPlotter


class PlotDataHDF(HDFWriterMixin):
Expand Down
4 changes: 3 additions & 1 deletion tardis/visualization/tools/tests/test_rpacket_plot.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Tests for RPacketPlotter Plots"""

import math

import astropy.units as u
import pytest
import numpy as np
import numpy.testing as npt
import pytest

from tardis.visualization import RPacketPlotter


Expand Down
2 changes: 1 addition & 1 deletion tardis/visualization/tools/tests/test_sdec_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import numpy as np
import pandas as pd
import pytest
from matplotlib.testing.compare import compare_images
from matplotlib.collections import PolyCollection
from matplotlib.lines import Line2D
from matplotlib.testing.compare import compare_images

from tardis.base import run_tardis
from tardis.io.util import HDFWriterMixin
Expand Down
50 changes: 25 additions & 25 deletions tardis/visualization/widgets/custom_abundance.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,36 @@
"""Class to create and display Custom Abundance Widget."""
import os
import yaml
from pathlib import Path

import ipywidgets as ipw
import numpy as np
import pandas as pd
import ipywidgets as ipw
import plotly.graph_objects as go
import yaml
from astropy import units as u
from radioactivedecay import Nuclide
from radioactivedecay.utils import Z_DICT, elem_to_Z
from pathlib import Path

import tardis
from tardis.io.model.readers.generic_readers import read_uniform_mass_fractions
from tardis.util.base import (
quantity_linspace,
is_valid_nuclide_or_elem,
is_notebook,
)
from tardis.io.atom_data.base import AtomData
from tardis.io.configuration.config_reader import Configuration
from tardis.model import SimulationState
from tardis.io.configuration.config_validator import validate_dict
from tardis.io.model.parse_density_configuration import (
calculate_power_law_density,
calculate_exponential_density,
calculate_power_law_density,
)
from tardis.io.atom_data.base import AtomData
from tardis.io.configuration.config_validator import validate_dict
from tardis.io.model.readers.csvy import load_csvy
from tardis.io.model.readers.csvy import (
load_csvy,
parse_csv_mass_fractions,
)
from tardis.util.base import atomic_number2element_symbol, quantity_linspace
from tardis.io.model.readers.generic_readers import read_uniform_mass_fractions
from tardis.model import SimulationState
from tardis.util.base import (
atomic_number2element_symbol,
is_notebook,
is_valid_nuclide_or_elem,
quantity_linspace,
)
from tardis.visualization.tools.convergence_plot import transition_colors
from tardis.visualization.widgets.util import debounce

Expand Down Expand Up @@ -601,7 +602,7 @@ def update_input_item_value(self, index, value):
"""
self._trigger = False
# `input_items` is the list of abundance input widgets.
self.input_items[index].value = float("{:.2e}".format(value))
self.input_items[index].value = float(f"{value:.2e}")
self._trigger = True

def read_abundance(self):
Expand Down Expand Up @@ -734,12 +735,11 @@ def overwrite_existing_shells(self, v_0, v_1):
else position_1
)

if (index_1 - index_0 > 1) or (
index_1 - index_0 == 1 and not np.isclose(v_vals[index_0], v_0)
):
return True
else:
return False
return bool(
index_1 - index_0 > 1
or index_1 - index_0 == 1
and not np.isclose(v_vals[index_0], v_0)
)

def on_btn_add_shell(self, obj):
"""Add new shell with given boundary velocities. Triggered if
Expand Down Expand Up @@ -914,7 +914,7 @@ def check_eventhandler(self, obj):
"""
item_index = obj.owner.index

if obj.new == True:
if obj.new is True:
self.bound_locked_sum_to_1(item_index)

def dpd_shell_no_eventhandler(self, obj):
Expand Down Expand Up @@ -1470,7 +1470,7 @@ def write_csv_portion(self, path):
data = data.sort_index()

formatted_v = pd.Series(self.data.velocity.value).apply(
lambda x: "%.3e" % x
lambda x: f"{x:.3e}"
)
# Make sure velocity is within the boundary.
formatted_v[0] = self.data.velocity.value[0]
Expand Down Expand Up @@ -1685,7 +1685,7 @@ def read_density(self):
"""
dvalue = self.data.density[self.shell_no].value
self._trigger = False
self.input_d.value = float("{:.3e}".format(dvalue))
self.input_d.value = float(f"{dvalue:.3e}")
self._trigger = True

def update_density_plot(self):
Expand Down
Loading
Loading