Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: new material comparison script #3389

Merged
356 changes: 121 additions & 235 deletions Examples/Scripts/MaterialMapping/material_comparison.py
Original file line number Diff line number Diff line change
@@ -1,239 +1,125 @@
import ROOT
import argparse
import matplotlib.pyplot as plt
import numpy as np
import uproot
import math
from collections import namedtuple

from ROOT import (
TCanvas,
TFile,
TTree,
gDirectory,
gStyle,
TH1D,
TH2F,
TProfile,
TRatioPlot,
FileRecord = namedtuple("FileRecord", ["name", "tree", "label", "color", "marker_size"])
PlotRecord = namedtuple(
"PlotRecord", ["x_axis", "y_axis", "x_range", "x_bins", "x_label", "saveAs"]
)


def TH1D_from_TProf(tprof):
h = TH1D(
tprof.GetName() + "_th1",
tprof.GetTitle(),
tprof.GetNbinsX(),
tprof.GetXaxis().GetXmin(),
tprof.GetXaxis().GetXmax(),
)
for i in range(tprof.GetNbinsX()):
if tprof.GetBinContent(i + 1) == 0.0:
h.SetBinContent(i + 1, 0.0)
h.SetBinError(i + 1, 1000.0)
continue
h.SetBinContent(i + 1, tprof.GetBinContent(i + 1))
h.SetBinError(i + 1, tprof.GetBinError(i + 1))
return h


if "__main__" == __name__:
p = argparse.ArgumentParser()

p.add_argument(
"-e", "--entries", type=int, default=100000, help="Number of events to process"
)
p.add_argument(
"-i",
"--input",
type=str,
nargs="+",
default="",
help="Input files with material tracks",
)
p.add_argument(
"-o",
"--output",
type=str,
default="",
help="Output file with produced material overview plots",
)
p.add_argument(
"-l",
"--labels",
type=str,
nargs="+",
default="",
help="The labels for the input files",
)
p.add_argument(
"-v",
"--variables",
type=str,
nargs="+",
default=["t_X0", "t_L0"],
help="The variables to be plotted",
)
p.add_argument(
"-m",
"--max",
type=float,
nargs="+",
default=[4.5, 1.5],
help="The variables to be plotted",
)
p.add_argument(
"-a",
"--axes",
type=str,
nargs="+",
default=["v_eta", "v_phi"],
help="The axes versus which the variables to be plotted",
)
p.add_argument(
"--eta",
type=float,
nargs=2,
default=[-4.0, 4.0],
help="Eta range for the plotting",
)
p.add_argument("--eta-bins", type=int, default=60, help="Eta bins for the plotting")
p.add_argument(
"--phi",
type=float,
nargs=2,
default=[-math.pi, math.pi],
help="Phi range for the plotting",
)
p.add_argument("--phi-bins", type=int, default=72, help="Phi bins for the plotting")
p.add_argument(
"-d",
"--detray-input",
type=str,
default="",
help="Optional detray input csv file",
)

args = p.parse_args()

ttrees = []

if len(args.input) != len(args.labels):
print("** ERROR ** The number of input files and labels must match")
exit(1)

# The histogram map
h_dict = {}
tfiles = []

ofile = ROOT.TFile.Open(args.output, "RECREATE") if args.output else None

input_files = args.input

# Detray csv file
if args.detray_input:
input_files.append(args.detray_input)

# Loop over the files and create the comparison histograms
for fi, ifile in enumerate(input_files):
# Special treament for detray input
if ifile == args.detray_input:
print("Reading detray input from file: ", args.detray_input)
ttree = TTree("t", "material read from detray")
ttree.ReadFile(args.detray_input, "v_eta/D:v_phi:t_X0:t_L0:t_X0:t_L0")
lbl = "detray"
else:
# Get the file
tfile = ROOT.TFile.Open(ifile)
tfiles.append(tfile)
ttree = tfile.Get("material-tracks")
ttrees.append(ttree)
# The label
lbl = args.labels[fi]

# Loop over the variables and axes
for iv, var in enumerate(args.variables):
for ax in args.axes:
if ax == "v_eta":
bins = args.eta_bins
low = args.eta[0]
high = args.eta[1]
cut = f"v_phi > {args.phi[0]} && v_phi < {args.phi[1]}"
elif ax == "v_phi":
bins = args.phi_bins
low = args.phi[0]
high = args.phi[1]
cut = f"v_eta > {args.eta[0]} && v_eta < {args.eta[1]}"
# Some naming magic
hcmmd = f"{var}:{ax}>>"
hname = f"{var}_vs_{ax}"
hfname = f"{lbl}_{var}_vs_{ax}"
hrange = f"({bins},{low},{high},100,0,{args.max[iv]})"
htitle = f"{var} vs {ax}"
ttree.Draw(hcmmd + hfname + hrange, cut, "")
h = ROOT.gDirectory.Get(hfname)
# Fill into comparison histogram
if h_dict.get(hname):
h_dict[hname].append(h)
else:
h_dict[hname] = [h]

# Write to file
if ofile:
ofile.cd()
h.Write()

colors = [
ROOT.kBlack,
ROOT.kRed + 2,
ROOT.kBlue - 4,
ROOT.kGreen + 1,
ROOT.kYellow - 2,
]
markers = [
ROOT.kFullCircle,
ROOT.kFullCircle,
ROOT.kFullTriangleUp,
ROOT.kFullTriangleDown,
ROOT.kFullDiamond,
]

# Now create the comparison histograms
c = ROOT.TCanvas("Comparison", "Comparison", 1200, 800)
c.Divide(2, 2)

# Remove the stat box
gStyle.SetOptStat(0)

# Memory garbage collection, thanks ROOT
hist_memory_pool = []

ic = 0
for hname in h_dict:
ic += 1
c.cd(ic)
h_list = h_dict[hname]
h_ref = None
h_prof_ref = None
for ih, h in enumerate(h_list):
h_prof = TH1D_from_TProf(h.ProfileX())
hist_memory_pool.append(h_prof)
h_prof.SetObjectStat(0)
h_prof.SetLineColor(colors[ih])
h_prof.SetMarkerColor(colors[ih])
h_prof.SetMarkerSize(0.5)
h_prof.SetMarkerStyle(markers[ih])
h_prof.GetYaxis().SetRangeUser(0.0, 1.3 * h_prof.GetMaximum())
if ih == 0:
h_ref = h
h_prof_ref = h_prof
else:
h_ratio = ROOT.TRatioPlot(h_prof_ref, h_prof)
h_ratio.SetGraphDrawOpt("pe")
h_ratio.SetSeparationMargin(0.005)
drawOption = "e,same" if ih > 1 else "e"
h_ratio.Draw(drawOption)
h_ratio.GetLowerRefGraph().SetLineColor(colors[ih])
h_ratio.GetLowerRefGraph().SetMarkerColor(colors[ih])
h_ratio.GetLowerRefGraph().SetMarkerSize(0.5)
h_ratio.GetLowerRefGraph().SetMarkerStyle(markers[ih])
c.Update()
hist_memory_pool.append(h_ratio)
h_prof.Draw(("same" if ih > 0 else ""))
c.SaveAs(f"{hname}.png")
# The file records
fileRecords = [
FileRecord("geant4_material_tracks.root", "material-tracks", "Geant4", "blue", 3),
FileRecord("acts_material_tracks.root", "material-tracks", "Acts", "orange", 4),
]


# The plot records
plotRecords = [
PlotRecord("v_eta", "t_X0", (-4.0, 4.0), 80, "η", "tX0_vs_eta.svg"),
PlotRecord("v_phi", "t_X0", (-math.pi, math.pi), 72, "φ", "tX0_vs_phi.svg"),
]

# Different plot records
for pr in plotRecords:

fig, axs = plt.subplots(2, 1, sharex=True)
fig.subplots_adjust(hspace=0.05)

# Prepare limit & ratios
y_lim = 0
y_ratio_values = []
y_ratio_errors = [0.0 for i in range(pr.x_bins)]

# Loop over the file records
for ifr, fr in enumerate(fileRecords):

# Load the three
tree = uproot.open(fr.name + ":" + fr.tree)

x_arr = tree[pr.x_axis].array(library="np")
y_arr = tree[pr.y_axis].array(library="np")
y_max = y_arr.max()
y_lim = y_max if y_max > y_lim else y_lim

# Generate the central bin values
x_step = (pr.x_range[1] - pr.x_range[0]) / pr.x_bins
x_vals = [pr.x_range[0] + (ix + 0.5) * x_step for ix in range(pr.x_bins)]

# Prepare the min /max
y_min_vals = [1000.0] * pr.x_bins
y_max_vals = [0.0] * pr.x_bins
y_vals_sorted = [np.array([])] * pr.x_bins

for iv in range(len(x_arr)):
x_b = int((x_arr[iv] - pr.x_range[0]) / x_step)
y_v = y_arr[iv]
# Take min / max
y_min_vals[x_b] = y_v if y_v < y_min_vals[x_b] else y_min_vals[x_b]
y_max_vals[x_b] = y_v if y_v > y_max_vals[x_b] else y_max_vals[x_b]
# Regulate the x value
y_vals_sorted[x_b] = np.append(y_vals_sorted[x_b], y_v)

axs[0].fill_between(
x=x_vals,
y1=y_min_vals,
y2=y_max_vals,
alpha=0.1,
label=fr.label + " spread",
color=fr.color,
)
axs[0].grid(axis="x")
y_vals_mean = [y_bin_vals.mean() for y_bin_vals in y_vals_sorted]

y_ratio_values += [y_vals_mean]
y_vals_mse = [
y_bin_vals.std() ** 2 / len(y_bin_vals) for y_bin_vals in y_vals_sorted
]

axs[0].errorbar(
x=x_vals,
y=y_vals_mean,
yerr=y_vals_mse,
markersize=fr.marker_size,
marker="o",
mfc=fr.color if ifr == 0 else "none",
linestyle="none",
label=fr.label + " mean",
color=fr.color,
)

if ifr > 0:
y_ratios = [
y_ratio_values[ifr][ib] / y_ratio_values[0][ib]
for ib in range(pr.x_bins)
]
axs[1].errorbar(
x=x_vals,
y=y_ratios,
yerr=y_ratio_errors,
markersize=fr.marker_size,
marker="o",
mfc="none",
linestyle="none",
color=fr.color,
label=fr.label,
)
axs[1].set_ylabel("Ratio to " + fileRecords[0].label)

# Some final cosmetics
axs[0].set_ylim(0.0, y_lim)
axs[0].grid(axis="x", linestyle="dotted")

axs[1].set_ylim(0.9, 1.1)
axs[1].grid(axis="x", linestyle="dotted")
axs[1].axhline(y=1.0, color="black", linestyle="-")

axs[0].legend(loc="upper center")
axs[1].legend(loc="upper center")

# Set the range of x-axis
plt.xlabel(pr.x_label)
plt.show()
fig.savefig(pr.saveAs)
Loading