Skip to content

Commit

Permalink
Merge branch 'dev' into cosmology
Browse files Browse the repository at this point in the history
  • Loading branch information
grburgess committed Apr 28, 2023
2 parents db12a86 + 0410d1c commit 15e21c7
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 34 deletions.
1 change: 1 addition & 0 deletions astromodels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
Quartic,
Sin,
SmoothlyBrokenPowerLaw,
DoubleSmoothlyBrokenPowerlaw,
SpatialTemplate_2D,
Standard_Rv,
StepFunction,
Expand Down
117 changes: 86 additions & 31 deletions astromodels/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

log = setup_logger(__name__)

pd.options.display.float_format = '{:.6g}'.format
pd.options.display.float_format = "{:.6g}".format


class ModelFileExists(IOError):
Expand All @@ -47,9 +47,12 @@ def __init__(self, directory, message):

free_space = disk_usage(directory).free

message += "\nFree space on the file system hosting %s was %.2f Mbytes" % (
directory,
free_space / 1024.0 / 1024.0,
message += (
"\nFree space on the file system hosting %s was %.2f Mbytes"
% (
directory,
free_space / 1024.0 / 1024.0,
)
)

super(CannotWriteModel, self).__init__(message)
Expand Down Expand Up @@ -99,11 +102,15 @@ def __init__(self, *sources):

# Dictionary to keep extended sources

self._extended_sources: Dict[str, ExtendedSource] = collections.OrderedDict()
self._extended_sources: Dict[
str, ExtendedSource
] = collections.OrderedDict()

# Dictionary to keep particle sources

self._particle_sources: Dict[str, ParticleSource] = collections.OrderedDict()
self._particle_sources: Dict[
str, ParticleSource
] = collections.OrderedDict()

# Loop over the provided sources and process them

Expand Down Expand Up @@ -210,7 +217,9 @@ def _find_properties(self, node) -> Dict[str, FunctionProperty]:
def _update_parameters(self) -> None:

self._parameters: Dict[str, Parameter] = self._find_parameters(self)
self._properties: Dict[str, FunctionProperty] = self._find_properties(self)
self._properties: Dict[str, FunctionProperty] = self._find_properties(
self
)

@property
def parameters(self) -> Dict[str, Parameter]:
Expand All @@ -237,7 +246,9 @@ def free_parameters(self) -> Dict[str, Parameter]:

# Filter selecting only free parameters

free_parameters_dictionary: Dict[str, Parameter] = collections.OrderedDict()
free_parameters_dictionary: Dict[
str, Parameter
] = collections.OrderedDict()

for parameter_name, parameter in list(self._parameters.items()):

Expand All @@ -247,6 +258,19 @@ def free_parameters(self) -> Dict[str, Parameter]:

return free_parameters_dictionary

@property
def has_free_parameters(self) -> bool:
"""
Returns True or False depending on if any parameters are free
"""

self._update_parameters()

for p in self.parameters.values():
if p.free:
return True
return False

@property
def linked_parameters(self) -> Dict[str, Parameter]:
"""
Expand All @@ -263,7 +287,9 @@ def linked_parameters(self) -> Dict[str, Parameter]:

# Filter selecting only free parameters

linked_parameter_dictionary: Dict[str, Parameter] = collections.OrderedDict()
linked_parameter_dictionary: Dict[
str, Parameter
] = collections.OrderedDict()

for parameter_name, parameter in list(self._parameters.items()):

Expand Down Expand Up @@ -363,7 +389,9 @@ def set_free_parameters(self, values: Iterable[float]) -> None:

raise AssertionError()

for parameter, this_value in zip(list(self.free_parameters.values()), values):
for parameter, this_value in zip(
list(self.free_parameters.values()), values
):

parameter.value = this_value

Expand Down Expand Up @@ -501,7 +529,9 @@ def remove_source(self, source_name: str) -> None:

self._update_parameters()

def unlink_all_from_source(self, source_name: str, warn: bool = False) -> None:
def unlink_all_from_source(
self, source_name: str, warn: bool = False
) -> None:
"""
Unlink all parameters of the current model that are linked to a parameter of a given source.
To be called before removing a source from the model.
Expand Down Expand Up @@ -686,7 +716,9 @@ def unlink(self, parameter: Parameter) -> None:

warnings.simplefilter("always", RuntimeWarning)

log.warning("Parameter %s has no link to be removed." % param.path)
log.warning(
"Parameter %s has no link to be removed." % param.path
)

def display(self, complete: bool = False) -> None:
"""
Expand Down Expand Up @@ -807,7 +839,9 @@ def _repr__base(self, rich_output=False):

fixed_parameter_dict[this_name][key] = d[key]

fixed_parameters_summary = pd.DataFrame.from_dict(fixed_parameter_dict).T
fixed_parameters_summary = pd.DataFrame.from_dict(
fixed_parameter_dict
).T

# Re-order it
fixed_parameters_summary = fixed_parameters_summary[
Expand Down Expand Up @@ -930,7 +964,9 @@ def _repr__base(self, rich_output=False):

else:

free_parameters_representation = free_parameters_summary._repr_html_()
free_parameters_representation = (
free_parameters_summary._repr_html_()
)

if len(linked_frames) == 0:

Expand Down Expand Up @@ -963,8 +999,8 @@ def _repr__base(self, rich_output=False):

for linked_function in linked_functions:

linked_function_summary_representation += linked_function.output(
rich=True
linked_function_summary_representation += (
linked_function.output(rich=True)
)
linked_function_summary_representation += new_line

Expand All @@ -987,7 +1023,9 @@ def _repr__base(self, rich_output=False):

else:

fixed_parameters_representation = fixed_parameters_summary._repr_html_()
fixed_parameters_representation = (
fixed_parameters_summary._repr_html_()
)

else:

Expand All @@ -999,7 +1037,9 @@ def _repr__base(self, rich_output=False):

else:

free_parameters_representation = free_parameters_summary.__repr__()
free_parameters_representation = (
free_parameters_summary.__repr__()
)

if properties_summary.empty:

Expand Down Expand Up @@ -1035,8 +1075,8 @@ def _repr__base(self, rich_output=False):

for linked_function in linked_functions:

linked_function_summary_representation += linked_function.output(
rich=False
linked_function_summary_representation += (
linked_function.output(rich=False)
)
linked_function_summary_representation += "%s%s" % (
new_line,
Expand Down Expand Up @@ -1065,7 +1105,9 @@ def _repr__base(self, rich_output=False):

else:

fixed_parameters_representation = fixed_parameters_summary.__repr__()
fixed_parameters_representation = (
fixed_parameters_summary.__repr__()
)

# Build the representation

Expand Down Expand Up @@ -1133,7 +1175,8 @@ def _repr__base(self, rich_output=False):
else:

representation += (
"(abridged. Use complete=True to see all fixed parameters)%s" % new_line
"(abridged. Use complete=True to see all fixed parameters)%s"
% new_line
)

representation += new_line
Expand Down Expand Up @@ -1250,7 +1293,9 @@ def to_dict_with_types(self):

elif isinstance(element, IndependentVariable):

data["%s (%s)" % (key, "IndependentVariable")] = data.pop(key)
data["%s (%s)" % (key, "IndependentVariable")] = data.pop(
key
)

elif isinstance(element, Parameter):

Expand All @@ -1262,7 +1307,9 @@ def to_dict_with_types(self):

else: # pragma: no cover

raise ModelInternalError("Found an unknown class at the top level")
raise ModelInternalError(
"Found an unknown class at the top level"
)

return data

Expand Down Expand Up @@ -1341,7 +1388,9 @@ def get_point_source_fluxes(
:return: fluxes
"""

return list(self._point_sources.values())[id](energies, tag=tag, stokes=stokes)
return list(self._point_sources.values())[id](
energies, tag=tag, stokes=stokes
)

def get_point_source_name(self, id: int) -> str:

Expand All @@ -1368,7 +1417,9 @@ def get_extended_source_fluxes(
:return: flux array
"""

return list(self._extended_sources.values())[id](j2000_ra, j2000_dec, energies)
return list(self._extended_sources.values())[id](
j2000_ra, j2000_dec, energies
)

def get_extended_source_name(self, id: int) -> str:
"""
Expand All @@ -1382,13 +1433,15 @@ def get_extended_source_name(self, id: int) -> str:

def get_extended_source_boundaries(self, id: int):

(ra_min, ra_max), (dec_min, dec_max) = list(self._extended_sources.values())[
id
].get_boundaries()
(ra_min, ra_max), (dec_min, dec_max) = list(
self._extended_sources.values()
)[id].get_boundaries()

return ra_min, ra_max, dec_min, dec_max

def is_inside_any_extended_source(self, j2000_ra: float, j2000_dec: float) -> bool:
def is_inside_any_extended_source(
self, j2000_ra: float, j2000_dec: float
) -> bool:

for ext_source in list(self.extended_sources.values()):

Expand Down Expand Up @@ -1417,7 +1470,9 @@ def get_number_of_particle_sources(self) -> int:

return len(self._particle_sources)

def get_particle_source_fluxes(self, id: int, energies: np.ndarray) -> np.ndarray:
def get_particle_source_fluxes(
self, id: int, energies: np.ndarray
) -> np.ndarray:
"""
Get the fluxes from the id-th point source
Expand Down
3 changes: 2 additions & 1 deletion astromodels/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Quartic,
Sin,
SmoothlyBrokenPowerLaw,
DoubleSmoothlyBrokenPowerlaw,
Standard_Rv,
StepFunction,
StepFunctionUpper,
Expand All @@ -43,7 +44,6 @@
from .functions_1D import Synchrotron

if has_gsl:

from .functions_1D import Cutoff_powerlaw_flux

if has_ebltable:
Expand Down Expand Up @@ -108,6 +108,7 @@
"Powerlaw_Eflux",
"Powerlaw_flux",
"SmoothlyBrokenPowerLaw",
"DoubleSmoothlyBrokenPowerlaw",
"Super_cutoff_powerlaw",
"Constant",
"Cubic",
Expand Down
2 changes: 2 additions & 0 deletions astromodels/functions/functions_1D/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
Powerlaw_flux,
SmoothlyBrokenPowerLaw,
Super_cutoff_powerlaw,
DoubleSmoothlyBrokenPowerlaw,
)

if has_atomdb:
Expand All @@ -71,6 +72,7 @@
"Powerlaw_Eflux",
"Powerlaw_flux",
"SmoothlyBrokenPowerLaw",
"DoubleSmoothlyBrokenPowerlaw",
"Super_cutoff_powerlaw",
"Constant",
"Cubic",
Expand Down
4 changes: 2 additions & 2 deletions astromodels/functions/numba_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,9 @@ def dbl_sbpl(x, K, a1, a2, b1, xp, xb, n1, n2, xpiv):
arg2 = x / xb
arg3 = x / xj

inner1 = _pow(arg2, -a1 * n1) + _pow(arg2, -a2 * n2)
inner1 = _pow(arg2, -a1 * n1) + _pow(arg2, -a2 * n1)

inner2 = _pow(arg1, -a1 * n1) + _pow(arg1, -a2 * n2)
inner2 = _pow(arg1, -a1 * n1) + _pow(arg1, -a2 * n1)

out = _pow(xb / xpiv, a1) * _pow(
_pow(inner1, n2 / n1) + _pow(arg3, -b1 * n2) * _pow(inner2, n2 / n1),
Expand Down

0 comments on commit 15e21c7

Please sign in to comment.