Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Apr 30, 2024
1 parent ad7e0ee commit 4f3619e
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 28 deletions.
3 changes: 1 addition & 2 deletions app/host/SBI/sbi_pp.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,15 +311,14 @@ def sbi_missingband(obs, run_params, sbi_params, seconditer=False):
for tmax, npost in zip(
[1, run_params["tmax_per_iter"]], [1, run_params["nposterior"]]
):

signal.alarm(tmax) # max time spent on one object in sec
try:
noiseless_theta = hatp_x_y.sample(
(npost,),
x=torch.as_tensor(x.astype(np.float32)).to(device),
show_progress_bars=False,
)
print(f'success for tmax = {tmax}!')
print(f"success for tmax = {tmax}!")
except TimeoutException:
signal.alarm(0)
do_continue = True
Expand Down
5 changes: 2 additions & 3 deletions app/host/plotting_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
import os
from math import pi
import time
from math import pi

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -59,7 +59,7 @@ def plot_image(image_data, figure):

image_data = np.nan_to_num(image_data, nan=perc01)
image_data = image_data + abs(np.amin(image_data)) + 0.1

scaled_image = scale_image(image_data)

figure.image(image=[scaled_image])
Expand Down Expand Up @@ -233,7 +233,6 @@ def plot_sed(transient=None, sed_results_file=None, type=""):
except AssertionError:
obs = {"filters": [], "maggies": [], "maggies_unc": []}


def maggies_to_asinh(x):
"""asinh magnitudes"""
a = 2.50 * np.log10(np.e)
Expand Down
19 changes: 11 additions & 8 deletions app/host/prospector.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
all_filters = [filt for filt in Filter.objects.all().select_related()]
trans_curves = [f.transmission_curve() for f in all_filters]


# add redshift scaling to agebins, such that
# t_max = t_univ
def zred_to_agebins(zred=None, **extras):
Expand Down Expand Up @@ -91,11 +92,15 @@ def build_obs(transient, aperture_type, use_mag_offset=True):
"""

photometry = AperturePhotometry.objects.filter(
transient=transient, aperture__type__exact=aperture_type
).filter(Q(is_validated="true") | Q(is_validated="contamination warning")).prefetch_related()
filter_names = photometry.values_list('filter__name',flat=True)

photometry = (
AperturePhotometry.objects.filter(
transient=transient, aperture__type__exact=aperture_type
)
.filter(Q(is_validated="true") | Q(is_validated="contamination warning"))
.prefetch_related()
)
filter_names = photometry.values_list("filter__name", flat=True)

if not photometry.exists():
raise ValueError(f"No host photometry of type {aperture_type}")

Expand All @@ -111,7 +116,7 @@ def build_obs(transient, aperture_type, use_mag_offset=True):

filters, flux_maggies, flux_maggies_error = [], [], []

for filter,trans_curve in zip(all_filters,trans_curves):
for filter, trans_curve in zip(all_filters, trans_curves):
try:
if filter.name in filter_names:
datapoint = photometry.get(filter=filter)
Expand Down Expand Up @@ -167,7 +172,6 @@ def build_obs(transient, aperture_type, use_mag_offset=True):
mJy_to_maggies(fluxerr_mwcorr * 10 ** (-0.4 * mag_offset))
)


obs_data = dict(
wavelength=None,
spectrum=None,
Expand All @@ -178,7 +182,6 @@ def build_obs(transient, aperture_type, use_mag_offset=True):
filters=filters,
)


return fix_obs(obs_data)


Expand Down
20 changes: 13 additions & 7 deletions app/host/transient_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ def _prerequisites(self):
"""
Only prerequisite is that the host match task is not processed.
"""
return {"Host match": "not processed",
"Cutout download": "processed",
"Transient MWEBV": "processed"}
return {
"Host match": "not processed",
"Cutout download": "processed",
"Transient MWEBV": "processed",
}

@property
def task_name(self):
Expand Down Expand Up @@ -94,8 +96,10 @@ def _prerequisites(self):
"""
Only prerequisite is that the transient MWEBV task is not processed.
"""
return {"Transient MWEBV": "not processed",
"Transient information": "processed"}
return {
"Transient MWEBV": "not processed",
"Transient information": "processed",
}

@property
def task_name(self):
Expand Down Expand Up @@ -615,8 +619,10 @@ class TransientInformation(TransientTaskRunner):
"""Task Runner to gather information about the Transient"""

def _prerequisites(self):
return {"Transient information": "not processed",
"Cutout download": "processed"}
return {
"Transient information": "not processed",
"Cutout download": "processed",
}

@property
def task_name(self):
Expand Down
21 changes: 13 additions & 8 deletions app/host/views.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import time

import django_filters
from django.contrib.auth.decorators import login_required
from django.contrib.auth.decorators import user_passes_test
Expand Down Expand Up @@ -34,7 +36,7 @@
from host.transient_name_server import get_transients_from_tns_by_name
from revproxy.views import ProxyView
from silk.profiling.profiler import silk_profile
import time


def filter_transient_categories(qs, value, task_register=None):
if task_register is None:
Expand Down Expand Up @@ -84,9 +86,9 @@ def filter_transient_categories(qs, value, task_register=None):
elif value == "Finished Transients":
qs = qs.filter(
~Q(
pk__in=task_register.filter(
~Q(status__message="processed")
).values("transient")
pk__in=task_register.filter(~Q(status__message="processed")).values(
"transient"
)
)
)

Expand Down Expand Up @@ -218,7 +220,6 @@ def analytics(request):


def results(request, slug):

transients = Transient.objects.all()
transient = transients.get(name__exact=slug)

Expand Down Expand Up @@ -416,8 +417,10 @@ def acknowledgements(request):
def home(request):
analytics_results = {}

task_register_qs = TaskRegister.objects.filter(status__message='processed').prefetch_related()
for aggregate, qs_value in zip(
task_register_qs = TaskRegister.objects.filter(
status__message="processed"
).prefetch_related()
for aggregate, qs_value in zip(
[
"Basic Information",
"Host Identification",
Expand All @@ -432,7 +435,9 @@ def home(request):
],
):
analytics_results[aggregate] = len(
filter_transient_categories(Transient.objects.all(), qs_value, task_register=task_register_qs)
filter_transient_categories(
Transient.objects.all(), qs_value, task_register=task_register_qs
)
)

# transients = TaskRegisterSnapshot.objects.filter(
Expand Down

0 comments on commit 4f3619e

Please sign in to comment.