From 4f3619edad90871afd3e84b17e0244592aa1101b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Apr 2024 02:37:13 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- app/host/SBI/sbi_pp.py | 3 +-- app/host/plotting_utils.py | 5 ++--- app/host/prospector.py | 19 +++++++++++-------- app/host/transient_tasks.py | 20 +++++++++++++------- app/host/views.py | 21 +++++++++++++-------- 5 files changed, 40 insertions(+), 28 deletions(-) diff --git a/app/host/SBI/sbi_pp.py b/app/host/SBI/sbi_pp.py index dce3f13f..df71bb95 100644 --- a/app/host/SBI/sbi_pp.py +++ b/app/host/SBI/sbi_pp.py @@ -311,7 +311,6 @@ def sbi_missingband(obs, run_params, sbi_params, seconditer=False): for tmax, npost in zip( [1, run_params["tmax_per_iter"]], [1, run_params["nposterior"]] ): - signal.alarm(tmax) # max time spent on one object in sec try: noiseless_theta = hatp_x_y.sample( @@ -319,7 +318,7 @@ def sbi_missingband(obs, run_params, sbi_params, seconditer=False): x=torch.as_tensor(x.astype(np.float32)).to(device), show_progress_bars=False, ) - print(f'success for tmax = {tmax}!') + print(f"success for tmax = {tmax}!") except TimeoutException: signal.alarm(0) do_continue = True diff --git a/app/host/plotting_utils.py b/app/host/plotting_utils.py index 0899c5e2..2afa8bc1 100644 --- a/app/host/plotting_utils.py +++ b/app/host/plotting_utils.py @@ -1,7 +1,7 @@ import math import os -from math import pi import time +from math import pi import numpy as np import pandas as pd @@ -59,7 +59,7 @@ def plot_image(image_data, figure): image_data = np.nan_to_num(image_data, nan=perc01) image_data = image_data + abs(np.amin(image_data)) + 0.1 - + scaled_image = scale_image(image_data) figure.image(image=[scaled_image]) @@ -233,7 +233,6 @@ def plot_sed(transient=None, sed_results_file=None, type=""): except AssertionError: obs = {"filters": [], "maggies": [], "maggies_unc": []} - def maggies_to_asinh(x): """asinh magnitudes""" a = 2.50 * np.log10(np.e) diff --git a/app/host/prospector.py b/app/host/prospector.py index 87cf8bc6..e163d660 100644 --- a/app/host/prospector.py +++ b/app/host/prospector.py @@ -38,6 +38,7 @@ all_filters = [filt for filt in Filter.objects.all().select_related()] trans_curves = [f.transmission_curve() for f in all_filters] + # add redshift scaling to agebins, such that # t_max = t_univ def zred_to_agebins(zred=None, **extras): @@ -91,11 +92,15 @@ def build_obs(transient, aperture_type, use_mag_offset=True): """ - photometry = AperturePhotometry.objects.filter( - transient=transient, aperture__type__exact=aperture_type - ).filter(Q(is_validated="true") | Q(is_validated="contamination warning")).prefetch_related() - filter_names = photometry.values_list('filter__name',flat=True) - + photometry = ( + AperturePhotometry.objects.filter( + transient=transient, aperture__type__exact=aperture_type + ) + .filter(Q(is_validated="true") | Q(is_validated="contamination warning")) + .prefetch_related() + ) + filter_names = photometry.values_list("filter__name", flat=True) + if not photometry.exists(): raise ValueError(f"No host photometry of type {aperture_type}") @@ -111,7 +116,7 @@ def build_obs(transient, aperture_type, use_mag_offset=True): filters, flux_maggies, flux_maggies_error = [], [], [] - for filter,trans_curve in zip(all_filters,trans_curves): + for filter, trans_curve in zip(all_filters, trans_curves): try: if filter.name in filter_names: datapoint = photometry.get(filter=filter) @@ -167,7 +172,6 @@ def build_obs(transient, aperture_type, use_mag_offset=True): mJy_to_maggies(fluxerr_mwcorr * 10 ** (-0.4 * mag_offset)) ) - obs_data = dict( wavelength=None, spectrum=None, @@ -178,7 +182,6 @@ def build_obs(transient, aperture_type, use_mag_offset=True): filters=filters, ) - return fix_obs(obs_data) diff --git a/app/host/transient_tasks.py b/app/host/transient_tasks.py index f241f58e..e13d127f 100644 --- a/app/host/transient_tasks.py +++ b/app/host/transient_tasks.py @@ -38,9 +38,11 @@ def _prerequisites(self): """ Only prerequisite is that the host match task is not processed. """ - return {"Host match": "not processed", - "Cutout download": "processed", - "Transient MWEBV": "processed"} + return { + "Host match": "not processed", + "Cutout download": "processed", + "Transient MWEBV": "processed", + } @property def task_name(self): @@ -94,8 +96,10 @@ def _prerequisites(self): """ Only prerequisite is that the transient MWEBV task is not processed. """ - return {"Transient MWEBV": "not processed", - "Transient information": "processed"} + return { + "Transient MWEBV": "not processed", + "Transient information": "processed", + } @property def task_name(self): @@ -615,8 +619,10 @@ class TransientInformation(TransientTaskRunner): """Task Runner to gather information about the Transient""" def _prerequisites(self): - return {"Transient information": "not processed", - "Cutout download": "processed"} + return { + "Transient information": "not processed", + "Cutout download": "processed", + } @property def task_name(self): diff --git a/app/host/views.py b/app/host/views.py index caa3b5a0..6c7d2a97 100644 --- a/app/host/views.py +++ b/app/host/views.py @@ -1,3 +1,5 @@ +import time + import django_filters from django.contrib.auth.decorators import login_required from django.contrib.auth.decorators import user_passes_test @@ -34,7 +36,7 @@ from host.transient_name_server import get_transients_from_tns_by_name from revproxy.views import ProxyView from silk.profiling.profiler import silk_profile -import time + def filter_transient_categories(qs, value, task_register=None): if task_register is None: @@ -84,9 +86,9 @@ def filter_transient_categories(qs, value, task_register=None): elif value == "Finished Transients": qs = qs.filter( ~Q( - pk__in=task_register.filter( - ~Q(status__message="processed") - ).values("transient") + pk__in=task_register.filter(~Q(status__message="processed")).values( + "transient" + ) ) ) @@ -218,7 +220,6 @@ def analytics(request): def results(request, slug): - transients = Transient.objects.all() transient = transients.get(name__exact=slug) @@ -416,8 +417,10 @@ def acknowledgements(request): def home(request): analytics_results = {} - task_register_qs = TaskRegister.objects.filter(status__message='processed').prefetch_related() - for aggregate, qs_value in zip( + task_register_qs = TaskRegister.objects.filter( + status__message="processed" + ).prefetch_related() + for aggregate, qs_value in zip( [ "Basic Information", "Host Identification", @@ -432,7 +435,9 @@ def home(request): ], ): analytics_results[aggregate] = len( - filter_transient_categories(Transient.objects.all(), qs_value, task_register=task_register_qs) + filter_transient_categories( + Transient.objects.all(), qs_value, task_register=task_register_qs + ) ) # transients = TaskRegisterSnapshot.objects.filter(