-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtic_plot.py
2690 lines (2229 loc) · 101 KB
/
tic_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# -*- coding: utf-8 -*-
"""
Helpers to plot the lightcurve of a TESS subject, given a
LightCurveCollection
"""
# so that type hint Optional(LC_Ylim_Func_Type) can be used
# otherwise Python complains TypeError: Cannot instantiate typing.Optional
from __future__ import annotations
import inspect
import numbers
import warnings
import re
from types import SimpleNamespace
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter, AutoMinorLocator
import matplotlib.animation as animation
import numpy as np
import pandas as pd
from astropy.io import fits
from astropy import units as u
from astropy.table import Table
import IPython
from IPython.display import display, HTML, Audio
from ipywidgets import interactive, interactive_output, fixed
import ipywidgets as widgets
import lightkurve as lk
from lightkurve import LightCurve, LightCurveCollection, LightkurveWarning, FoldedLightCurve
from lightkurve.utils import TessQualityFlags
from lightkurve_ext import of_sectors
import lightkurve_ext as lke
import lightkurve_ext_tess as lket
# typing
from typing import Callable, Optional, Tuple
# Type alias for a callable that takes a LightCurve and returns a pair of floats.
# NOTE(review): judging by the name, the pair is presumably a (ymin, ymax) y-axis range;
# confirm against the functions that accept it (not visible in this part of the file).
LC_Ylim_Func_Type = Callable[[LightCurve], Tuple[float, float]]
def get_tic_meta_in_html(lc, a_subject_id=None, download_dir=None, **kwargs):
    """Delegate to ``lightkurve_ext_tess.get_tic_meta_in_html()``.

    All arguments are forwarded unchanged and the delegate's result is
    returned as-is.
    """
    # The implementation is resolved at call time: should the import fail
    # (e.g., a missing dependency), only this function is affected rather
    # than the rest of the codes.
    from lightkurve_ext_tess import get_tic_meta_in_html as _delegate

    forwarded = dict(kwargs, a_subject_id=a_subject_id, download_dir=download_dir)
    return _delegate(lc, **forwarded)
def beep():
    """Play a short beep, e.g., to signal that a long-running download is done.

    It only works inside an IPython / Jupyter environment, because the sound
    is emitted by displaying an auto-playing audio element.
    """
    # Step 1: inject (at most once per page) a CSS rule that shrinks the audio
    # control with id "beep" to 1x1 pixel, so that the control is practically hidden.
    hide_beep_control_script = """<script>
function tweakCSS() {
if (document.getElementById("hide-beep-css")) {
return;
}
document.head.insertAdjacentHTML('beforeend', `<style id="hide-beep-css" type="text/css">
#beep { /* hide the audio control for the beep, generated from tplt.beep() */
width: 1px;
height: 1px;
}
</style>`);
}
tweakCSS();
</script>
"""
    display(HTML(hide_beep_control_script))

    # Step 2: the actual beep.
    # A local file is embedded; a remote URL
    # (https://upload.wikimedia.org/wikipedia/commons/f/fb/NEC_PC-9801VX_ITF_beep_sound.ogg)
    # was tried before but somehow ran into ssl error.
    beep_url = "beep_sound.ogg"
    audio_options = dict(filename=beep_url, autoplay=True, embed=True)

    ipython_major_version = int(re.sub(r"[.].+", "", IPython.__version__))
    if ipython_major_version >= 7:
        # `element_id` is left out for older IPython (e.g., google colab) for compatibility
        audio_options["element_id"] = "beep"

    display(Audio(**audio_options))
def _normalize_to_percent_quiet(lc):
    """Return ``lc.normalize(unit="percent")``, silencing the "in relative units" warning.

    Some products already come in normalized flux (e.g., around 1); they are
    still normalized to percentage for consistency, hence the warning that
    lightkurve would emit for them is suppressed.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message=".*in relative units.*",
            category=LightkurveWarning,
        )
        normalized_lc = lc.normalize(unit="percent")
    return normalized_lc
# Plot the flux changes (not flux themselves) to get a sense of the rate of changes, not too helpful yet.
def plot_lcf_flux_delta(lcf, ax, xmin=None, xmax=None, moving_avg_window="30min"):
    """Plot how much the moving average of the flux changes over each rolling window.

    The lightcurve is normalized to percent, its flux is smoothed with a rolling
    mean over ``moving_avg_window``, and for every rolling window of the smoothed
    flux the difference (last value - first value) is plotted against time.

    Parameters
    ----------
    lcf : the lightcurve to plot
    ax : the `matplotlib.Axes` to plot on
    xmin, xmax : the x-axis limits applied to ``ax``
    moving_avg_window : the rolling window, as a pandas time offset string

    Returns
    -------
    ax : the `matplotlib.Axes` supplied
    """

    def window_end_minus_start(window_values):
        return window_values[-1] - window_values[0]

    lc = _normalize_to_percent_quiet(lcf)

    df = lc.to_pandas()
    # A datetime-like column is needed for time-offset based rolling windows.
    # The timestamps are only good as relative time. If they need to reflect
    # the actual time, BTJD has to be converted to timestamp, e.g.
    # pd.Timestamp(astropy.time.Time(x + 2457000, format='jd', scale='tdb').datetime.timestamp(), unit='s')
    df["time_ts"] = [pd.Timestamp(day_value, unit="D") for day_value in df.index]

    flux_windows = df.rolling(moving_avg_window, on="time_ts")["flux"]
    df["flux_mavg"] = flux_windows.mean()

    mavg_windows = df.rolling(moving_avg_window, on="time_ts")["flux_mavg"]
    df["flux_delta"] = mavg_windows.apply(window_end_minus_start, raw=True)

    ax.plot(
        lc.time.value,
        df["flux_delta"],
        c="blue",
        label=f"Flux delta ({moving_avg_window})",
    )
    ax.set_xlim(xmin, xmax)
    return ax
def lk_ax(*args, **kwargs):
    """Create a matplotlib figure in Lightkurve style, and return its Axes object (`gca()`).

    All arguments are passed to ``plt.figure()``.
    """
    with plt.style.context(lk.MPLSTYLE):
        figure = plt.figure(*args, **kwargs)
        # The Axes object MUST be obtained within the style context and returned:
        # if only the Figure object is returned, it will have Lightkurve's default figsize,
        # but the style won't be used in the actual plot.
        axes = figure.gca()
    return axes
def flux_near(lc, time):
    """Return the flux of the data point of ``lc`` whose time is the closest to ``time``.

    Returns None if either argument is None.
    """
    if lc is None or time is None:
        return None
    nearest_idx = np.abs(lc.time - time).argmin()
    return lc.flux[nearest_idx]
def flux_mavg_near(df, time):
    """Return the ``flux_mavg`` value of the row of ``df`` whose index is the closest to ``time``.

    Returns None if either argument is None.
    """
    if df is None or time is None:
        return None
    nearest_pos = np.abs(df.index.values - time).argmin()
    # The dataframe from lightkurve is indexed by time (rather than a regular 0-based index),
    # so the row must be fetched by position with df.iloc[], i.e., df['flux_mavg'][nearest_pos]
    # would not work.
    nearest_row = df.iloc[nearest_pos]
    return nearest_row["flux_mavg"]
def _to_unitless(n):
    """Return ``n.value`` if ``n`` has a ``value`` attribute, otherwise ``n`` itself."""
    return n.value if hasattr(n, "value") else n
def as_4decimal(float_num):
    """Round the given number(s) to 4 decimal places.

    - None is returned as None.
    - For a tuple or a list, a list of the rounded members is returned.
    - Otherwise, the rounded number is returned.

    Values with a ``value`` attribute are reduced to that value first.
    """

    def round_one(n):
        return float("{0:.4f}".format(_to_unitless(n)))

    if float_num is None:
        return None
    if isinstance(float_num, (tuple, list)):
        return [round_one(member) for member in float_num]
    return round_one(float_num)
def _flip_yaxis_for_mag(ax, lc, plot_kwargs):
    """Invert the y-axis of ``ax`` if the plotted column of ``lc`` is in magnitude.

    The plotted column is ``plot_kwargs["column"]``, defaulted to ``flux``.
    Returns ``ax``.
    """
    plotted_column = plot_kwargs.get("column", "flux")
    in_magnitude = lc[plotted_column].unit == u.mag
    if not in_magnitude:
        return ax

    # Invert only when the axis hasn't been inverted yet,
    # to support multiple scatter/plot/errorbar calls on the same ax object.
    y_first, y_second = ax.get_ylim()
    if y_second > y_first:
        ax.invert_yaxis()
    return ax
def _do_common_ax_tweaks(ax, lc, plot_kwargs):
    """Apply the axis tweaks shared by the scatter / errorbar / plot wrappers. Returns ``ax``.

    - folded lightcurve with time in ``jd`` format: relabel the x-axis
    - flux in magnitude: relabel the y-axis, and invert it if needed
    """
    is_folded = isinstance(lc, lk.FoldedLightCurve)
    if is_folded and getattr(lc.time, "format", None) == "jd":
        # the default label, Phase [JD], could be confusing
        ax.set_xlabel("Phase [Days]")

    flux_in_magnitude = lc.flux.unit == u.mag
    if flux_in_magnitude:
        ax.set_ylabel("Magnitude")

    return _flip_yaxis_for_mag(ax, lc, plot_kwargs)
def scatter(lc, **kwargs):
    """lc.scatter() with the proper support of plotting flux in magnitudes

    Besides the parameters of ``lc.scatter()``, ``figsize`` is accepted as a shortcut
    to create a plot of arbitrary size without callers creating the ax.
    It cannot be combined with ``ax``: a ValueError is raised if both are given.
    """
    requested_figsize = kwargs.pop("figsize", None)
    if requested_figsize is not None:
        caller_ax = kwargs.get("ax", None)
        if caller_ax is not None:
            raise ValueError("Only one of 'ax' and 'figsize' parameters can be specified.")
        kwargs["ax"] = lk_ax(figsize=(requested_figsize))

    plotted_ax = lc.scatter(**kwargs)
    return _do_common_ax_tweaks(plotted_ax, lc, kwargs)
def errorbar(lc, **kwargs):
    """lc.errorbar() with the proper support of plotting flux in magnitudes

    Differences from ``lc.errorbar()``:
    - ``marker`` is defaulted to ``"o"``: lightkurve's default style for errorbar hides
      the marker, while having a marker is typically useful.
    - ``figsize`` is accepted as a shortcut to create a plot of arbitrary size without
      callers creating the ax. It cannot be combined with ``ax``: a ValueError is raised
      if both are given.
    """
    # The default applies only if the key is absent,
    # so that callers can still explicitly set marker=None.
    kwargs.setdefault("marker", "o")

    requested_figsize = kwargs.pop("figsize", None)
    if requested_figsize is not None:
        caller_ax = kwargs.get("ax", None)
        if caller_ax is not None:
            raise ValueError("Only one of 'ax' and 'figsize' parameters can be specified.")
        kwargs["ax"] = lk_ax(figsize=(requested_figsize))

    plotted_ax = lc.errorbar(**kwargs)
    return _do_common_ax_tweaks(plotted_ax, lc, kwargs)
def plot(lc, **kwargs):
    """lc.plot() with the proper support of plotting flux in magnitudes

    Besides the parameters of ``lc.plot()``, ``figsize`` is accepted as a shortcut
    to create a plot of arbitrary size without callers creating the ax.
    It cannot be combined with ``ax``: a ValueError is raised if both are given.
    """
    requested_figsize = kwargs.pop("figsize", None)
    if requested_figsize is not None:
        caller_ax = kwargs.get("ax", None)
        if caller_ax is not None:
            raise ValueError("Only one of 'ax' and 'figsize' parameters can be specified.")
        kwargs["ax"] = lk_ax(figsize=(requested_figsize))

    plotted_ax = lc.plot(**kwargs)
    return _do_common_ax_tweaks(plotted_ax, lc, kwargs)
def add_flux_moving_average(lc, moving_avg_window):
    """Return a dataframe of the lightcurve with the moving average of the flux added.

    The dataframe has the time, flux and flux_err of ``lc``, plus:
    - ``time_ts``: the time relative to the first row, as `pd.Timestamp`
    - ``flux_mavg``: the centered rolling mean of the flux over ``moving_avg_window``
    """
    # include minimal columns: reduce df size and make it compatible with eleanor full FITS files
    minimal_lc = lc["time", "flux", "flux_err"]
    df = minimal_lc.to_pandas()

    # Notes on time_ts:
    # 1. The first time value is subtracted, because for some products, e.g., CDIPS,
    #    the time value itself is so large that creating pd.Timestamp with it causes Overflow error.
    # 2. As a result, time_ts is only good as relative time. If it needs to reflect the actual time,
    #    BTJD has to be converted to timestamp, e.g.
    #    pd.Timestamp(astropy.time.Time(x + 2457000, format='jd', scale='tdb').datetime.timestamp(), unit='s')
    first_time = df.index[0]
    df["time_ts"] = [pd.Timestamp(time_value - first_time, unit="D") for time_value in df.index]

    flux_windows = df.rolling(moving_avg_window, center=True, on="time_ts")["flux"]
    df["flux_mavg"] = flux_windows.mean()
    return df
def add_relative_time(lc, lcf):
    """Add column ``time_rel`` to ``lc``: its time relative to ``TSTART`` in the metadata of ``lcf``.

    Returns True if the column is added, False if ``lcf`` has no ``TSTART`` (``lc`` is left untouched).
    """
    observation_start = lcf.meta.get("TSTART")
    if observation_start is not None:
        lc["time_rel"] = lc.time - observation_start
        return True
    return False
def mask_gap(x, y, min_x_diff):
    """
    Help to plot graphs with gaps in the data, so that straight line won't be drawn to fill the gap.

    A point is masked when it is more than ``min_x_diff`` away from the point before it.
    The very first point has no predecessor, and is therefore never masked.

    Parameters
    ----------
    x : 1-D array-like of the x values, or an object that exposes them in its ``value``
        attribute (e.g., an astropy Time object)
    y : array-like of the y values, of the same length as ``x``
    min_x_diff : the smallest difference between two adjacent x values to be considered a gap

    Returns
    -------
    A masked y that can be passed to pyplot.plot() that can show the gap.
    """
    # handle case that x is an astropy Time object, rather than simple float array
    x = np.asarray(getattr(x, "value", x))
    # Pair the first point with itself (difference 0), so that it never counts as a gap.
    # (Prepending a fixed number, such as -min_x_diff, would mask the first point
    # whenever x[0] > 0.) x[:1] also keeps an empty x working.
    x_diff = np.diff(x, prepend=x[:1])
    return np.ma.masked_where(x_diff > min_x_diff, y)
def normalize_percent(lc):
    """
    A syntactic sugar, in place of a lambda, to normalize a lightcurve as percentage.

    Useful as the function to be supplied when calling ``lc.fold()``, ``tpf.interact()``, etc.
    """
    normalized_lc = lc.normalize(unit="percent")
    return normalized_lc
def _add_flux_origin_to_ylabel(ax, lc):
    """Show the flux column plotted in the y-axis label, if it is not the author's standard one.

    If ``FLUX_ORIGIN`` in the metadata of ``lc`` differs from the standard flux column
    of its ``AUTHOR``, the text "Flux" in the y-axis label is replaced with the
    flux origin in italic. Otherwise, the label is left as-is.
    """
    # The standard flux column of each author, e.g., QLP uses sap_flux as the standard.
    standard_flux_col_of_author = {
        "SPOC": "pdcsap_flux",
        "TESS-SPOC": "pdcsap_flux",
        "QLP": "sap_flux",
        "TASOC": "flux_raw",
        "CDIPS": "irm1",
        "PATHOS": "psf_flux_cor",
    }

    flux_origin = lc.meta.get("FLUX_ORIGIN", None)
    if flux_origin is None:
        return

    author = lc.meta.get("AUTHOR", None)
    if flux_origin == standard_flux_col_of_author.get(author, None):
        return

    current_label = ax.yaxis.get_label_text()
    # the flux origin as a latex italic expression
    italic_flux_origin = r"$\it{" + lc.flux_origin.replace("_", r"\_") + "}$"
    ax.yaxis.set_label_text(current_label.replace("Flux", italic_flux_origin))
# Single-entry cache of the lightcurve most recently prepared by plot_n_annotate_lcf()
# (flux column selected, optionally normalized), to speed up plots repeatedly over the same lcf.
_cache_plot_n_annotate_lcf = dict(lcf=None, flux_col=None, normalize=None, lc=None)


def plot_n_annotate_lcf(
    lcf,
    ax,
    flux_col="flux",
    xmin=None,
    xmax=None,
    truncate_extra_buffer_time=0.0,
    t0=None,
    t_start=None,
    t_end=None,
    moving_avg_window="30min",
    t0mark_ymax=0.3,
    mark_momentum_dumps=True,
    set_title=True,
    show_r_obj_estimate=True,
    title_fontsize=18,
    t0_label_suffix=None,
    normalize=True,
    lc_tweak_fn=None,
    ax_tweak_fn=None,
    plot_fn_name="scatter",
    plot_kwargs=None,
    legend_kwargs=None,
):
    """Plot a lightcurve (or a time range of it) on ``ax``, and annotate the plot.

    Parameters
    ----------
    lcf : the lightcurve to plot. If it is None, a warning is printed and nothing is plotted.
    ax : the `matplotlib.Axes` to plot on
    flux_col : the flux column to plot, passed to ``lke.select_flux()``
    xmin, xmax : the time range to plot. If not specified, they are defaulted to
        ``t_start - 0.5`` / ``t_end + 0.5`` when ``t_start`` / ``t_end`` are specified.
    truncate_extra_buffer_time : the lightcurve is truncated to around ``xmin`` / ``xmax``;
        this is the extra amount of time kept on each side.
    t0 : if specified, the time is marked with a dashed vertical line, and the flux around it
        is reported in the title.
    t_start, t_end : if specified, each is marked with a vertical line.
    moving_avg_window : the window of the moving average to plot, as a pandas time offset string.
        None to not plot the moving average.
    t0mark_ymax : ``ymax`` of the vertical line that marks ``t0``
    mark_momentum_dumps : if True, momentum dumps are marked (see ``plot_momentum_dumps()``)
    set_title : if True, a title is set on ``ax``
    show_r_obj_estimate : if True, the object radius estimated from the dip is included in the title
    title_fontsize : the font size of the title
    t0_label_suffix : an optional second line for the legend label of the ``t0`` mark
    normalize : if True, the flux is normalized to percent
    lc_tweak_fn : an optional function, called with the (truncated) lightcurve; the lightcurve
        it returns is the one plotted.
    ax_tweak_fn : an optional function, called with ``ax`` at the very end
    plot_fn_name : the name of the plot function in this module to use:
        the ``scatter`` / ``plot`` / ``errorbar`` wrappers
    plot_kwargs : optional keyword arguments for the plot function. The dict is not modified.
    legend_kwargs : optional keyword arguments for ``ax.legend()``

    Returns
    -------
    ax : the `matplotlib.Axes` of the plot. None if ``lcf`` is None.
    """
    if lcf is None:
        print("Warning: lcf is None. Plot skipped")
        return

    # None (the default) means no extra arguments.
    # For plot_kwargs, a copy is always used, as it will be modified below: it avoids
    # unintentional side effects, in case caller uses the same instance repeatedly.
    plot_kwargs = dict() if plot_kwargs is None else plot_kwargs.copy()
    legend_kwargs = dict() if legend_kwargs is None else legend_kwargs

    # cache lc to speed up plots repeatedly over the same lcf
    global _cache_plot_n_annotate_lcf
    if (
        lcf is _cache_plot_n_annotate_lcf["lcf"]
        and flux_col == _cache_plot_n_annotate_lcf["flux_col"]
        and normalize == _cache_plot_n_annotate_lcf["normalize"]
    ):
        lc = _cache_plot_n_annotate_lcf["lc"]
    else:
        lc = lke.select_flux(lcf, flux_col)
        if normalize:
            lc = _normalize_to_percent_quiet(lc)
        _cache_plot_n_annotate_lcf["lcf"] = lcf
        _cache_plot_n_annotate_lcf["flux_col"] = flux_col
        _cache_plot_n_annotate_lcf["normalize"] = normalize
        _cache_plot_n_annotate_lcf["lc"] = lc

    if xmin is None and t_start is not None:
        xmin = t_start - 0.5
    if xmax is None and t_end is not None:
        xmax = t_end + 0.5

    # truncate the LC approximately around xmin/xmax, so that
    # - the Y-scale will then automatically scaled to the specified time range, rather than over entire lightcurve
    # - make plotting faster (fewer data points)
    # - Some extra buffer is added so that if users want to tweak xmin/xmax afterwards
    #   the data right outside xmin/xmax range would still be available for plots.
    if xmin is not None or xmax is not None:
        trunc_min = xmin - truncate_extra_buffer_time if xmin is not None else None
        trunc_max = xmax + truncate_extra_buffer_time if xmax is not None else None
        lc = lc.truncate(trunc_min, trunc_max)

    if lc_tweak_fn is not None:
        lc = lc_tweak_fn(lc)

    lcfh = lcf.meta

    # Basic scatter of the observation
    plot_kwargs["ax"] = ax
    if plot_fn_name == "scatter" and lke.estimate_cadence(lc, unit=u.s) > 300 * u.s:
        # long cadence has more sparse data, use a larger "x" to represent them
        # "x" is also useful to distinguish it from moving average,
        # which will likely overlap with the points given the sparse data
        if plot_kwargs.get("s") is None:
            plot_kwargs["s"] = 36
        if plot_kwargs.get("marker") is None:
            plot_kwargs["marker"] = "x"
    plot_fn = globals()[plot_fn_name]  # the scatter / plot / errorbar wrapper in this module
    ax = plot_fn(lc, **plot_kwargs)

    if len(lc) < 1:
        # .get(): the metadata might have no SECTOR; the warning should still be printed
        print(
            (
                "Warning: specified (xmin, xmax) is out of the range of the lightcurve "
                f"{lc.label} sector {lcfh.get('SECTOR')}. Nothing to plot"
            )
        )
        return ax

    # convert to dataframe to add moving average
    if moving_avg_window is not None:
        df = add_flux_moving_average(lc, moving_avg_window)
        # mask_gap: if there is a gap larger than 2 hours,
        # show the gap rather than trying to fill the gap with a straight line.
        ax.plot(
            lc.time.value,
            mask_gap(lc.time, df["flux_mavg"], 2 / 24),
            c="#3AF",
            label=f"Moving average ({moving_avg_window})",
        )
    else:
        df = add_flux_moving_average(lc, "10min")  # still needed for some subsequent calc, but don't plot it

    # annotate the graph
    if t_start is not None:
        ax.axvline(t_start)

    if t_end is not None:
        ax.axvline(t_end)

    if t0 is not None:
        t_lc_start = lcf.meta.get("TSTART", None)
        t0_rel_text = ""
        if t_lc_start is not None:
            t0_rel = t0 - t_lc_start
            t0_rel_text = f" ({t0_rel:.3f})"
        label_vline = f"t0 ~= {t0:.3f}{t0_rel_text}"
        if t0_label_suffix is not None:
            label_vline = f"{label_vline}\n{t0_label_suffix}"
        ax.axvline(
            t0,
            ymin=0,
            ymax=t0mark_ymax,
            color="black",
            linewidth=3,
            linestyle="--",
            label=label_vline,
        )

    if mark_momentum_dumps:
        plot_momentum_dumps(lc, ax)

    if set_title:
        title_text = lc.label
        if len(lcfh.get("SECTORS", [])) > 1:
            sector_text = lke.abbrev_sector_list(lcfh.get("SECTORS", []))
        else:
            sector_text = lcfh.get("SECTOR", None)
        if sector_text is not None:
            title_text += f", sector {sector_text}"
        author = lc.meta.get("AUTHOR", None)
        if author is not None and author != "SPOC":
            title_text += f", by {author}"
        if t0 is not None:
            transit_duration_msg = ""
            if t_start is not None and t_end is not None:
                transit_duration_msg = f"\ntransit duration ~= {as_4decimal(24 * (t_end - t_start))}h"
            flux_t0 = flux_mavg_near(df, t0)
            if flux_t0 is not None:
                # The dip is measured against the higher flux at t_start / t_end.
                # t_start and / or t_end might not be specified (flux_mavg_near() returns None for them),
                # use whatever is available.
                flux_edges = [
                    flux for flux in (flux_mavg_near(df, t_start), flux_mavg_near(df, t_end)) if flux is not None
                ]
                if len(flux_edges) > 0:
                    flux_begin = max(flux_edges)
                    flux_dip = flux_begin - flux_t0
                    r_obj_msg = ""
                    # convert flux_dip in percent to fractions
                    r_obj = lke.estimate_object_radius_in_r_jupiter(lc, flux_dip / 100)
                    if show_r_obj_estimate and r_obj is not None:
                        r_obj_msg = f", R_p ~= {r_obj:0.2f} R_j"
                    title_text += (
                        f" \nflux@$t_0$ ~= {as_4decimal(flux_t0)}%, "
                        f"dip ~= {as_4decimal(flux_dip)}%{r_obj_msg}{transit_duration_msg}"
                    )
                else:
                    # neither t_start nor t_end is specified: the dip cannot be estimated
                    title_text += f" \nflux@$t_0$ ~= {as_4decimal(flux_t0)}%{transit_duration_msg}"
        ax.set_title(title_text, {"fontsize": title_fontsize})
    ax.legend(**legend_kwargs)
    ax.set_xlim(xmin, xmax)
    _add_flux_origin_to_ylabel(ax, lc)
    ax.xaxis.label.set_size(18)
    ax.yaxis.label.set_size(18)

    # to avoid occasional formatting in scientific notations
    ax.xaxis.set_major_formatter(FormatStrFormatter("%.2f"))
    ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))

    ax.xaxis.set_minor_locator(AutoMinorLocator())
    ax.tick_params(axis="x", which="minor", length=4)
    ax.yaxis.set_minor_locator(AutoMinorLocator())
    ax.tick_params(axis="y", which="minor", length=4)

    if ax_tweak_fn is not None:
        ax_tweak_fn(ax)

    return ax
def _get_plot_transit_x_range(t0, duration, surround_time):
    """Return the ``(xmin, xmax)`` of a plot centered at ``t0``, that spans ``duration + surround_time``."""
    x_lower = t0 - (duration + surround_time) / 2
    x_upper = t0 + (duration + surround_time) / 2
    return x_lower, x_upper
def plot_transit(lcf, ax, t0, duration, surround_time, **kwargs):
    """Plot the lightcurve zoomed in to a transit, using ``plot_n_annotate_lcf()``.

    The plot is centered at ``t0`` and spans ``duration + surround_time``.
    If ``duration`` is positive, ``t0`` and the start / end of the transit
    (``t0 -/+ duration / 2``) are marked; otherwise, none of them is marked.
    Additional keyword arguments are passed to ``plot_n_annotate_lcf()``, and its
    result is returned.
    """
    xmin, xmax = _get_plot_transit_x_range(t0, duration, surround_time)

    if duration > 0:
        transit_marks = dict(t0=t0, t_start=t0 - duration / 2, t_end=t0 + duration / 2)
    else:
        transit_marks = dict(t0=None, t_start=None, t_end=None)

    return plot_n_annotate_lcf(
        lcf,
        ax=ax,
        xmin=xmin,
        xmax=xmax,
        **transit_marks,
        **kwargs,
    )
def plot_transits(lcf_coll, transit_specs, ax_fn=lambda: lk_ax(), **kwargs):
    """Helper to plot transits zoomed-in.

    For each spec in ``transit_specs``, and each lightcurve of the spec's sector,
    one plot is created for each step in the spec's ``steps_to_show``, centered at
    ``epoch + period * step``.

    Parameters
    ----------
    lcf_coll : the collection of lightcurves
    transit_specs : the specs of the transits. The keys used for each spec:
        ``sector``, ``epoch``, ``duration_hr``, ``period``, ``steps_to_show``,
        and optionally ``surround_time``, ``label``.
    ax_fn : the function to create the ax for each plot
    kwargs : passed to ``plot_transit()``

    Returns
    -------
    axs : the list of the ax objects of the plots
    """
    flux_col = kwargs.get("flux_col", "flux")
    is_standard_flux = isinstance(flux_col, str) and flux_col.lower() in ["flux", "pdcsap_flux"]
    if not is_standard_flux:
        display(HTML(f"""<span style="background-color: yellow"> Note: </span> Not standard flux is plotted: {flux_col}"""))

    axs = []
    for spec in transit_specs:
        for lcf in of_sectors(lcf_coll, spec["sector"]):  # in case we have multiple lcf per sector
            # process the supplied spec and apply defaults
            epoch = spec["epoch"]
            duration = spec["duration_hr"] / 24
            period = spec["period"]
            steps_to_show = spec["steps_to_show"]
            surround_time = spec.get("surround_time", 1.5)  # a hardcoded last resort default
            # the generic labels are not shown
            label_suffix = spec.get("label") if spec.get("label", "") not in ["", "dip", "dips"] else None

            # TODO: warn if period is 0, but steps to show is not [0]
            for step in steps_to_show:
                step_t0 = epoch + period * step
                step_ax = plot_transit(
                    lcf, ax_fn(), step_t0, duration, surround_time, t0_label_suffix=label_suffix, **kwargs
                )
                axs.append(step_ax)
    return axs
def print_data_range(lcf_coll):
    """Display the data range for the given LightCurveCollection as HTML.

    A summary of the sectors is shown, followed by, for each lightcurve:
    * sector start/stop time (``TSTART`` / ``TSTOP`` in the metadata)
    * first / last observation time
    * camera and ccd used
    """
    html_parts = [
        '<pre style="line-height: 1.1;">\n',
        "<summary>Sectors: " + lke.abbrev_sector_list(lcf_coll) + f" ({len(lcf_coll)})" + "\n",
        "Observation period range / data range:" + "\n",
        "<details>",
    ]
    for lc in lcf_coll:
        meta = lc.meta
        html_parts.append(f" Sector {meta.get('SECTOR')}: {meta.get('TSTART')} - {meta.get('TSTOP')}" + "\n")
        html_parts.append(
            f" (cam.ccd {meta.get('CAMERA')}.{meta.get('CCD')}) {lc.time.min()} - {lc.time.max()}" + "\n"
        )
    html_parts.append("</details></summary></pre>")

    display(HTML("".join(html_parts)))
def get_momentum_dump_times(lcf):
    """Return the times of the momentum dumps within the time range of the data of ``lcf``.

    The times are read from ``momentum_dumps`` of the metadata of ``lcf``. If they are
    not there, they are obtained from ``lket.MomentumDumpsAccessor`` and saved to
    the metadata.
    """
    all_dump_times = lcf.meta.get("momentum_dumps", None)
    if all_dump_times is None:
        all_dump_times = lket.MomentumDumpsAccessor.get_in_range(lcf)
        lcf.meta["momentum_dumps"] = all_dump_times

    # in case the lcf has been truncated, we preserve the truncation
    data_begin = lcf.time.min().value
    data_end = lcf.time.max().value
    within_data = (data_begin <= all_dump_times) & (all_dump_times <= data_end)
    return all_dump_times[within_data]
def vlines_y_in_axes_coord(ax, x, ymin, ymax, **kwargs):
    """Draw vertical lines with `Axes.vlines()`, with ymin/ymax in axes coordinates.

    It is a wrapper to workaround bug: https://github.com/matplotlib/matplotlib/issues/23171

    The ``transform`` parameter of `Axes.vlines()` is set by this function. A ValueError
    is raised if callers specify a different one. The remaining keyword arguments are
    passed to `Axes.vlines()`, and its result is returned.
    """
    # the y range before drawing, saved for the workaround
    saved_ybottom, saved_ytop = ax.get_ylim()

    axes_y_transform = ax.get_xaxis_transform()  # for ymin/ymax in axes coordinates
    caller_transform = kwargs.get("transform", None)
    if caller_transform is not None and caller_transform is not axes_y_transform:
        raise ValueError("_vlines_y_in_axes_coord() does not accept transform() parameter. It uses its own")
    kwargs["transform"] = axes_y_transform

    lines = ax.vlines(x, ymin, ymax, **kwargs)

    # Applying workaround: the bug scaled the ax's y-axis incorrectly, we compensate it by
    # rescaling it using the saved y range.
    # edge case: if ymax > 1, the top of the y range needs to be scaled up as well
    ytop_to_apply = saved_ytop * ymax if ymax > 1 else saved_ytop
    ax.set_ylim(saved_ybottom, ytop_to_apply)
    return lines
def plot_momentum_dumps(lcf, ax, use_relative_time=False, mark_height_scale=0.15, color="red"):
    """Mark momentum dumps on the given plot, as vertical lines at the bottom.

    Nothing is marked if the time of ``lcf`` is not in ``btjd`` format, or there is
    no momentum dump in the time range of ``lcf``.

    Parameters
    ----------
    lcf : the lightcurve plotted
    ax : the `matplotlib.Axes` of the plot
    use_relative_time : if True, the marks are placed at the time relative to ``TSTART``
        in the metadata of ``lcf``
    mark_height_scale : the height of the marks, in axes coordinates
    color : the color of the marks

    Returns
    -------
    ax : the `matplotlib.Axes` supplied
    """
    # The momentum dump is for TESS data, in btjd
    time_is_btjd = lcf.time.format == "btjd"
    if not time_is_btjd:
        return ax

    dump_times = get_momentum_dump_times(lcf)
    if len(dump_times) < 1:
        return ax

    # case have data to plot
    if use_relative_time:
        dump_times = dump_times - lcf.meta.get("TSTART")

    mark_style = dict(color=color, linewidth=1, linestyle="-.", label="Momentum dumps")
    vlines_y_in_axes_coord(ax, dump_times, ymin=0, ymax=mark_height_scale, **mark_style)
    return ax
# Do the actual plots
def plot_all(
    lcf_coll,
    flux_col="flux",
    moving_avg_window=None,
    normalize=True,
    lc_tweak_fn=None,
    ax_fn=None,
    use_relative_time=False,
    mark_quality_issues=True,
    mark_momentum_dumps=True,
    set_title=True,
    ax_tweak_fn=None,
    plot_fn_name="scatter",
    plot_kwargs=None,
):
    """Plot the given LightCurveFile collection, one graph for each LightCurve

    Parameters
    ----------
    lcf_coll : the collection of lightcurves to plot
    flux_col : the flux column to plot, passed to ``lke.select_flux()``
    moving_avg_window : the window of the moving average to plot, as a pandas time offset string.
        None (the default) to not plot the moving average.
    normalize : if True, the flux is normalized to percent
    lc_tweak_fn : an optional function, called with each lightcurve; the lightcurve it
        returns is the one plotted.
    ax_fn : an optional function to create the ax of each graph. ``lk_ax()`` is used by default.
    use_relative_time : if True, time relative to ``TSTART`` of the lightcurve is plotted.
        It is decided for each lightcurve individually: those without ``TSTART`` in their metadata
        are plotted in their original time.
    mark_quality_issues : if True, the times with potential quality issues, as identified by
        ``lke.create_quality_issues_mask()``, are marked.
    mark_momentum_dumps : if True, momentum dumps are marked (see ``plot_momentum_dumps()``)
    set_title : if True, a title is set on each graph
    ax_tweak_fn : an optional function, called with the ax of each graph
    plot_fn_name : the name of the plot function in this module to use:
        the ``scatter`` / ``plot`` / ``errorbar`` wrappers
    plot_kwargs : optional keyword arguments for the plot function. The dict is not modified.
        If it is None, defaults that depend on the cadence are used for scatter plots.

    Returns
    -------
    axs : the list of plots in `matplotlib.Axes`
    """
    # Design note - alternatives considered:
    # choice 1: use the built-in plot method
    #    ax_all = plt.figure(figsize=(30, 15)).gca()
    #    lcf_coll.PDCSAP_FLUX.plot(ax=ax_all) # Or lcf_coll.SAP_FLUX.plot()
    # choice 2: stitch lightcurves of the collection together, and then use more flexible methods, e.g., scatter
    #    Note: pass lambda x: x to stitch() so that the code won't normalize the flux value sector by sector
    #    lc_all = lcf_coll.PDCSAP_FLUX.stitch(lambda x: x)
    #    lc_all.scatter(ax=ax_all, normalize=True)
    # choice 3: plot the lightcurve sector by sector: each sector has its own color
    #    for i in range(0, len(lcf_coll)):
    #        lcf_coll[i].PDCSAP_FLUX.scatter(ax=ax_all)
    #    ax_all.set_title((f"TIC {lcf_coll[0].PDCSAP_FLUX.label}, "
    #                      f"sectors {list(map(lambda lcf: lcf.meta.get('SECTOR'), lcf_coll))}"))
    #    return ax_all
    # choice 4 (the one implemented): plot the lightcurve sector by sector: each sector in its own graph
    axs = []
    for lcf in lcf_coll:
        if ax_fn is None:
            ax = lk_ax()
        else:
            ax = ax_fn()

        lc = lke.select_flux(lcf, flux_col)
        if normalize:
            lc = _normalize_to_percent_quiet(lc)
        else:
            # use a copy to avoid side effects, e.g., we modify its label below
            # (if normalized, it'd be a different LC object anyway)
            lc = lc.copy()

        if lc_tweak_fn is not None:
            lc = lc_tweak_fn(lc)

        # Temporarily change time to a relative one if specified.
        # It is tracked per lightcurve: if the current one has no observation start time
        # (so the relative time cannot be added), only the current one falls back to the original time.
        # The `use_relative_time` parameter is left untouched for the remaining lightcurves.
        cur_use_relative_time = bool(use_relative_time) and add_relative_time(lc, lcf)
        if cur_use_relative_time:
            lc["time_orig"] = lc.time
            lc.time = lc.time_rel

        # tweak label to include sector if any
        sector = lcf.meta.get("SECTOR", None)
        label_long = lc.label
        if sector is not None:
            lc.label += f", s.{sector}"
            label_long += f", sector {sector}"
        if lc.author is not None and lc.author != "SPOC":
            label_long += f", by {lc.author}"

        # Note: each LC has its own copy of plot_kwargs, so that they can be customized individually
        plot_kwargs_for_cur_lc = plot_kwargs
        if plot_kwargs_for_cur_lc is None:
            if plot_fn_name != "scatter":
                plot_kwargs_for_cur_lc = dict()
            # the defaults are for scatter only
            elif lke.estimate_cadence(lc, unit=u.s) > 300 * u.s:
                # long cadence has more sparse data, use a larger "x" to represent them
                # "x" is also useful to distinguish it from moving average,
                # which will likely overlap with the points given the sparse data
                plot_kwargs_for_cur_lc = dict(s=16, marker="x")
            else:
                # for typical short cadence data, make dots smaller than the default (s=4)
                # so that the output doesn't look overly dense
                # for the purpose of plotting, FFI-based data with short enough cadence (e.g., 200s)
                # are considered short
                plot_kwargs_for_cur_lc = dict(s=0.5)
        else:  # case user-supplied plot args
            plot_kwargs_for_cur_lc = plot_kwargs_for_cur_lc.copy()  # we'll modify it

        plot_kwargs_for_cur_lc["ax"] = ax  # add the ax we created
        plot_fn = globals()[plot_fn_name]  # the scatter / plot / errorbar wrapper in this module
        ax = plot_fn(lc, **plot_kwargs_for_cur_lc)

        # convert to dataframe to add moving average
        if moving_avg_window is not None:
            df = add_flux_moving_average(lc, moving_avg_window)
            # mask_gap: if there is a gap larger than 2 hours,
            # show the gap rather than trying to fill the gap with a straight line.
            ax.plot(
                lc.time.value,
                mask_gap(lc.time, df["flux_mavg"], 2 / 24),
                c="#3AF",
                lw=0.4,
                label=f"Moving average ({moving_avg_window})",
            )

        title_extras = ""
        if lc_tweak_fn is not None:
            title_extras = "\nLC tweaked, e.g., outliers removed"

        if set_title:
            ax.set_title(f"{label_long} {title_extras}")  # {"fontsize": 18}

        if cur_use_relative_time:
            ax.xaxis.set_label_text("Time - relative")
            # restore original time after plot is done
            lc.time = lc.time_orig
        else:
            t_start = lc.meta.get("TSTART")
            if t_start is not None:
                ax.xaxis.set_label_text(ax.xaxis.label.get_text() + f", TSTART={t_start:0.2f}")

        _add_flux_origin_to_ylabel(ax, lc)

        # to avoid occasional formatting in scientific notations
        ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))

        # minor tick, 1 day interval in practice
        ax.xaxis.set_minor_locator(AutoMinorLocator())
        ax.tick_params(axis="x", which="minor", length=4)
        # ax.xaxis.grid(True, which='minor') # too noisy to be there by default

        ax.xaxis.label.set_size(fontsize=18)
        ax.yaxis.label.set_size(fontsize=18)

        if ax_tweak_fn is not None:
            ax_tweak_fn(ax)

        # mark quality issue is applied after ax_tweak_fn, in case users use ax_tweak_fn and change the graph's ylim
        if mark_quality_issues:
            # the time where flux might have potential issues, using the suggested starting quality flag mask
            time = lc.time if not cur_use_relative_time else lc.time_rel
            time_w_quality_issues = time[lke.create_quality_issues_mask(lc)]
            if len(time_w_quality_issues) > 0:
                # add marks as vertical lines at bottom 10% of the plot
                vlines_y_in_axes_coord(
                    ax,
                    time_w_quality_issues.value,
                    ymin=0,
                    ymax=0.1,
                    color="red",
                    linewidth=1,
                    linestyle="--",
                    label="potential quality issue",
                )

        if mark_momentum_dumps:
            plot_momentum_dumps(lcf, ax, use_relative_time=cur_use_relative_time)

        ax.legend()
        axs.append(ax)

    return axs
# The lightcurve plotted by plot_lcf_interactive(): it is handed over to
# _update_plot_lcf_interactive() with this global, rather than as a widget parameter.
_lcf_4_plot_interactive = None


def _update_plot_lcf_interactive(figsize, flux_col, xrange, moving_avg_window, ymin, ymax, widget_out2):
    """Plot the lightcurve of plot_lcf_interactive() with the values from the interactive widgets.

    The lightcurve is read from the global ``_lcf_4_plot_interactive``. Besides the plot,
    the codes to reproduce the x / y range of the plot are printed to ``widget_out2``.

    A negative ``ymin`` / ``ymax`` means the default is used for it.
    """
    x_lower, x_upper = xrange[0], xrange[1]

    ax = lk_ax(figsize=figsize)
    plot_n_annotate_lcf(
        _lcf_4_plot_interactive,
        ax,
        flux_col=flux_col,
        xmin=x_lower,
        xmax=x_upper,
        moving_avg_window=moving_avg_window,
    )
    codes_text = f"ax.set_xlim({x_lower}, {x_upper})"

    ymin_to_use = None if ymin < 0 else ymin
    ymax_to_use = None if ymax < 0 else ymax
    y_range_is_default = ymin_to_use is None and ymax_to_use is None
    if not y_range_is_default:
        ax.set_ylim(ymin_to_use, ymax_to_use)
        codes_text += f"\n\nax.set_ylim({ymin_to_use}, {ymax_to_use})"

    widget_out2.clear_output()
    with widget_out2:
        print(codes_text)
    return None
def plot_lcf_interactive(lcf, figsize=(15, 8), flux_col="flux"):
    """Plot the lightcurve with interactive widgets to tweak the plot.

    The widgets let users change the time range, the moving average window,
    and the flux range. The codes to reproduce the range of the current plot
    are shown in a textual output below the plot.

    Parameters
    ----------
    lcf : the lightcurve to plot. Its metadata ``TSTART`` / ``TSTOP`` are used as
        the bounds of the time range.
    figsize : the size of the plot
    flux_col : the flux column to plot

    Returns
    -------
    w : the interactive widget created (it has been displayed)
    """
    desc_style = {"description_width": "25ch"}
    slider_style = {"description_width": "25ch"}
    slider_layout = {"width": "100ch"}
    t_start = lcf.meta.get("TSTART")
    t_stop = lcf.meta.get("TSTOP")
    # Add a second output for textual
    widget_out2 = widgets.Output()

    # pass lcf with a global rather than the slow fixed(lcf) with lkv2
    #
    # import warnings
    # with warnings.catch_warnings():
    #     # lkv2 workaround: to suppress astropy table warning, stating that the semantics of == will be changed in the future.
    #     warnings.filterwarnings("ignore", category=FutureWarning)
    #     fixed_lcf = fixed(lcf)
    global _lcf_4_plot_interactive
    _lcf_4_plot_interactive = lcf

    w = interactive(
        _update_plot_lcf_interactive,
        figsize=fixed(figsize),
        # lcf = fixed_lcf,
        flux_col=fixed(flux_col),
        xrange=widgets.FloatRangeSlider(
            min=t_start,
            max=t_stop,
            step=0.1,
            value=(t_start, t_stop),
            description="Time",
            continuous_update=False,
            readout_format=".1f",
            layout=slider_layout,
            style=slider_style,
        ),
        moving_avg_window=widgets.Dropdown(
            # (label, value) pairs: each value must be the window its label states
            options=[
                ("None", None),
                ("10 min", "10min"),
                ("20 min", "20min"),
                ("30 min", "30min"),
                ("1 hour", "1h"),
                ("2 hours", "2h"),
                ("4 hours", "4h"),
            ],
            value="30min",
            description="Moving average window",
            style=desc_style,
        ),
        ymin=widgets.FloatText(value=-1, description="Flux min, -1 for default", style=desc_style),
        ymax=widgets.FloatText(value=-1, description="Flux max, -1 for default", style=desc_style),
        widget_out2=fixed(widget_out2),
    )
    w.layout.border = "1px solid lightgray"
    w.layout.padding = "1em 0px"

    widget_out2.layout.padding = "1em"
    w.children = w.children + (widget_out2,)

    display(w)
    return w
def number_of_decimal_places(num, min=0):
    """Return the number of decimal places of the given number, but no less than ``min``.

    The number of decimal places is based on the string representation of ``num``.
    Scientific notation is supported, e.g., ``1e-05`` has 5 decimal places.

    Parameters
    ----------
    num : an integer, a real number, or a number-like string
    min : the minimum to return. (The name shadows the builtin, but it is kept as
        callers might pass it by keyword.)

    Returns
    -------
    - ``min`` for integers
    - for real numbers and strings, the number of decimal places, or ``min`` if it is larger

    Raises
    ------
    TypeError : if ``num`` is not a number nor a string
    """
    # local import: decimal is not needed elsewhere in the module
    from decimal import Decimal, InvalidOperation

    # numbers.Integral / Real, rather than int / float, so that the likes of numpy numbers are accepted
    if isinstance(num, numbers.Integral):
        return min
    if isinstance(num, (numbers.Real, str)):
        num_text = str(num)
        try:
            exponent = Decimal(num_text).as_tuple().exponent
        except InvalidOperation:
            # Not parsable as a decimal number: fall back to counting the characters after the last "."
            # https://stackoverflow.com/a/26231848
            return max(num_text[::-1].find("."), min)
        if not isinstance(exponent, int):
            # infinity or NaN (the exponent is a letter code): no decimal places
            return max(0, min)
        # The exponent is the power of 10 of the last digit, e.g., -2 for "1.25", -8 for "1.5e-07",
        # thus its negation is the number of decimal places. It is positive for the likes of "1e+20",
        # which have no decimal places.
        return max(-exponent, 0, min)
    raise TypeError(f"num must be a number or number-like string. Actual: {type(num)}")
def plot_transit_interactive(lcf, figsize=(15, 8), flux_col="flux", defaults=None):
# keep track some of the UI inputs to determine user's intention.
last_t0, last_step = None, None # to be inited right before creating interactive UI
codes_transit_spec = ""
def _update_plot_transit_interactive(
flux_col,
t0,
duration_hr,
period,
step,
surround_time,
moving_avg_window,
t0mark_ymax,
ymin,
ymax,
widget_out2,
widget_t0,
widget_step,
):
nonlocal last_t0, last_step
nonlocal codes_transit_spec
# for typical inline matplotlib backend, the figure needs to be recreated every time.
ax = lk_ax(figsize=figsize)
codes_text = "# Snippets to generate the plot"
moving_avg_window_for_codes = "None" if moving_avg_window is None else f"'{moving_avg_window}'"
if t0 < 0:
plot_n_annotate_lcf(lcf, ax, flux_col=flux_col, moving_avg_window=moving_avg_window)
codes_text += f"\nplot_n_annotate_lcf(lcf, ax, moving_avg_window={moving_avg_window_for_codes})"
else:
t0_to_use = t0 + step * period
# Possible auto adjustment of step
# case 1. handle case user changes t0,
# - prevent t0 goes out of range
# - reset the step to 0, otherwise
if t0 != last_t0:
# see if new t0 is in range
# - use t0 rather than t0_to_use, as we're going to reset step to 0
xmin, xmax = _get_plot_transit_x_range(t0, duration_hr / 24, surround_time)
times = lcf.time.value
no_data = len(times[(xmin <= times) & (times <= xmax)]) < 1
if no_data: # reset to last t0
t0 = last_t0
t0_to_use = t0 + step * period
widget_t0.value = t0
else: # use the new t0, reset step to 0.
step = 0
t0_to_use = t0
widget_step.value = step
# case 2. handle case user changes step, automatically adjust the step
# if the implied range has no LC data (data gap, before start, after end)
elif step != last_step:
xmin, xmax = _get_plot_transit_x_range(t0_to_use, duration_hr / 24, surround_time)
times = lcf.time.value
no_data = len(times[(xmin <= times) & (times <= xmax)]) < 1
if no_data:
if step > last_step: # pan forward
times_avail = times[times > xmax]
if len(times_avail) > 0:
xmin_avail = times_avail[0]
step = np.ceil((xmin_avail - t0) / period)
else: # reach the end of the LC, go no further
step = last_step
t0_to_use = t0 + step * period
widget_step.value = step
else: # pan backward
times_avail = times[times < xmin]
if len(times_avail) > 0:
xmax_avail = times_avail[-1]
step = np.floor((xmax_avail - t0) / period)
else: # reach the end of the LC, go no further
step = last_step
t0_to_use = t0 + step * period
widget_step.value = step
# okay, now record what the parameters about to plot
# after the auto adjustment is done
last_t0, last_step = t0, step
plot_transit(
lcf,
ax,
t0_to_use,
duration_hr / 24,
surround_time,
flux_col=flux_col,
moving_avg_window=moving_avg_window,
t0mark_ymax=t0mark_ymax,
# fix legend to upper left to avoid clashing with the notebook nav at the upper right
legend_kwargs=dict(loc="upper left"),
)
codes_transit_spec = f"""# transit_specs for calling plot_transits()
transit_specs = TransitTimeSpecList( # {lcf.meta.get("LABEL")}
dict(epoch={t0}, duration_hr={duration_hr}, period={period}, label="dip",
sector={lcf.meta.get('SECTOR')}, steps_to_show=[{step}],
),
defaults=dict(surround_time={surround_time})
)
"""
codes_text += f"""
# transit parameters - t0: BTJD {t0}, duration: {duration_hr} hours, period: {period} days
plot_transit(lcf, ax, {t0_to_use}, {duration_hr} / 24, {surround_time}, \
moving_avg_window={moving_avg_window_for_codes}, t0mark_ymax={t0mark_ymax})
{codes_transit_spec}
"""
ymin_to_use = ymin if ymin >= 0 else None
ymax_to_use = ymax if ymax >= 0 else None
if (ymin_to_use is not None) or (ymax_to_use is not None):