Skip to content

Commit

Permalink
Ruff tardis/spectrum (#2848)
Browse files Browse the repository at this point in the history
* Ruff on spectrum safe fixes

* black on tardis/spectru
  • Loading branch information
atharva-2001 authored Oct 21, 2024
1 parent 6ae9446 commit 03d970d
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 44 deletions.
39 changes: 18 additions & 21 deletions tardis/spectrum/formal_integral.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,27 @@
import warnings

import numpy as np
import pandas as pd
import scipy.sparse as sp
import scipy.sparse.linalg as linalg
from scipy.interpolate import interp1d
from astropy import units as u
from tardis import constants as const
from numba import njit, char, float64, int64, typeof, byte, prange
from numba.experimental import jitclass

from numba import njit, prange
from scipy.interpolate import interp1d

from tardis import constants as const
from tardis.opacities.opacity_state import (
OpacityState,
opacity_state_initialize,
)
from tardis.transport.montecarlo.configuration.constants import SIGMA_THOMSON
from tardis.transport.montecarlo.configuration import montecarlo_globals
from tardis.transport.montecarlo import njit_dict, njit_dict_no_parallel
from tardis.transport.montecarlo.numba_interface import (
opacity_state_initialize,
OpacityState,
)
from tardis.spectrum.formal_integral_cuda import (
CudaFormalIntegrator,
)

from tardis.spectrum.spectrum import TARDISSpectrum
from tardis.transport.montecarlo import njit_dict, njit_dict_no_parallel
from tardis.transport.montecarlo.configuration import montecarlo_globals
from tardis.transport.montecarlo.configuration.constants import SIGMA_THOMSON
from tardis.transport.montecarlo.numba_interface import (
opacity_state_initialize,
)

C_INV = 3.33564e-11
M_PI = np.arccos(-1)
Expand Down Expand Up @@ -61,8 +57,7 @@ def numba_formal_integral(
intensities at each p-ray multiplied by p
frequency x p-ray grid
"""

# todo: add all the original todos
# TODO: add all the original todos
# Initialize the output which is shared among threads
L = np.zeros(inu_size, dtype=np.float64)
# global read-only values
Expand Down Expand Up @@ -215,7 +210,7 @@ def numba_formal_integral(


# @jitclass(integrator_spec)
class NumbaFormalIntegrator(object):
class NumbaFormalIntegrator:
"""
Helper class for performing the formal integral
with numba.
Expand Down Expand Up @@ -258,7 +253,7 @@ def formal_integral(
)


class FormalIntegrator(object):
class FormalIntegrator:
"""
Class containing the formal integrator.
Expand Down Expand Up @@ -297,7 +292,8 @@ def __init__(self, simulation_state, plasma, transport, points=1000):

def generate_numba_objects(self):
"""instantiate the numba interface objects
needed for computing the formal integral"""
needed for computing the formal integral
"""
from tardis.model.geometry.radial1d import NumbaRadial1DGeometry

self.numba_radial_1d_geometry = NumbaRadial1DGeometry(
Expand Down Expand Up @@ -353,7 +349,7 @@ def raise_or_return(message):
"FormalIntegrator."
)

if not self.transport.line_interaction_type in [
if self.transport.line_interaction_type not in [
"downbranch",
"macroatom",
]:
Expand Down Expand Up @@ -613,7 +609,8 @@ def interpolate_integrator_quantities(

def formal_integral(self, nu, N):
"""Do the formal integral with the numba
routines"""
routines
"""
# TODO: get rid of storage later on

res = self.make_source_function()
Expand Down
15 changes: 4 additions & 11 deletions tardis/spectrum/formal_integral_cuda.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import sys
import numpy as np
from astropy import units as u
from numba import float64, int64, cuda
import math

import numpy as np
from numba import cuda

from tardis.transport.montecarlo.configuration.constants import SIGMA_THOMSON

C_INV = 3.33564e-11
Expand All @@ -27,7 +26,6 @@ def cuda_vector_integrator(L, I_nu, N, R_max):
R_max : float64
"""

nu_idx = cuda.grid(1)
L[nu_idx] = (
8 * M_PI * M_PI * trapezoid_integration_cuda(I_nu[nu_idx], R_max / N)
Expand Down Expand Up @@ -101,7 +99,6 @@ def cuda_formal_integral(
shell_id : array(int64, 2d, C)
List of shells for each thread
"""

# global read-only values
size_line, size_shell = tau_sobolev.shape
R_ph = r_inner[0] # make sure these are cgs
Expand Down Expand Up @@ -234,7 +231,7 @@ def cuda_formal_integral(
I_nu_thread[p_idx] *= p


class CudaFormalIntegrator(object):
class CudaFormalIntegrator:
"""
Helper class for performing the formal integral
with CUDA.
Expand Down Expand Up @@ -430,8 +427,6 @@ class BoundsError(IndexError):
binary search
"""

pass


@cuda.jit(device=True)
def line_search_cuda(nu, nu_insert, number_of_lines):
Expand Down Expand Up @@ -519,7 +514,6 @@ def trapezoid_integration_cuda(arr, dx):
arr : (array(float64, 1d, C)
dx : np.float64
"""

result = arr[0] + arr[-1]

for x in range(1, len(arr) - 1):
Expand Down Expand Up @@ -564,5 +558,4 @@ def calculate_p_values(R_max, N):
-------
float64
"""

return np.arange(N).astype(np.float64) * R_max / (N - 1)
3 changes: 0 additions & 3 deletions tardis/spectrum/tests/test_cuda_formal_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
NumbaFormalIntegrator,
)

from tardis.transport.montecarlo.base import MonteCarloTransportSolver


# Test cases must also take into account use of a GPU to run. If there is no GPU then the test cases will fail.
GPUs_available = cuda.is_available()

Expand Down
8 changes: 3 additions & 5 deletions tardis/spectrum/tests/test_numba_formal_integral.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import pytest
import numpy as np
from tardis import constants as c

from copy import deepcopy
import numpy.testing as ntest
import pytest

from tardis.util.base import intensity_black_body
import tardis.spectrum.formal_integral as formal_integral
from tardis import constants as c
from tardis.model.geometry.radial1d import NumbaRadial1DGeometry
from tardis.util.base import intensity_black_body


@pytest.mark.parametrize(
Expand Down
10 changes: 6 additions & 4 deletions tardis/spectrum/tests/test_spectrum.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import pytest
import os

import astropy.tests.helper as test_helper
import numpy as np
import pandas as pd
import os
import pytest
from astropy import units as u
from tardis import constants as c
import astropy.tests.helper as test_helper
from numpy.testing import assert_almost_equal

from tardis import constants as c
from tardis.spectrum.spectrum import (
TARDISSpectrum,
)
Expand Down

0 comments on commit 03d970d

Please sign in to comment.