diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000..8ab857a
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1 @@
+* text eol=input
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 159ff9b..b0d3977 100644
--- a/.gitignore
+++ b/.gitignore
@@ -162,4 +162,7 @@ slurm*.out
 *.npz

 # Temp Scripts
-*.ipynb
\ No newline at end of file
+*.ipynb
+
+# Temp shell scripts
+temp*.sh
\ No newline at end of file
diff --git a/configs/datamodule/herwig_event_multihadron.yaml b/configs/datamodule/herwig_event_multihadron.yaml
index c4f0cc7..93e7ba5 100644
--- a/configs/datamodule/herwig_event_multihadron.yaml
+++ b/configs/datamodule/herwig_event_multihadron.yaml
@@ -1,19 +1,27 @@
-_target_: hadml.datamodules.gan_datamodule.EventGANDataModule
-batch_size: 1
-num_workers: 1
-pin_memory: False
-train_val_test_split: [5, 2, 3]
-frac_data_used:
-cond_dataset:
-  _target_: hadml.datamodules.components.herwig.HerwigEventMultiHadronDataset
-  root: ${paths.data_dir}Herwig/
-  raw_file_list:
-    - "AllClusters_paper4.dat"
-  processed_file_name: "herwig_multihadron_graph_cond_data.pt"
+_target_: hadml.datamodules.gan_datamodule.MultiHadronEventGANDataModule

-obs_dataset:
-  _target_: hadml.datamodules.components.herwig.HerwigEventMultiHadronDataset
-  root: ${paths.data_dir}Herwig/
-  raw_file_list:
-    - "AllClusters_paper4.dat"
-  processed_file_name: "herwig_multihadron_graph_obs_data_variation.pt"
+
+# Basic parameters
+batch_size: 32
+num_workers: 128 # Set this to the number of logical CPU cores on your machine.
+                 # Example 1: an Intel Core i5 with
+                 #            1 socket x 4 physical cores x 2 threads = 8
+                 # Example 2: a Perlmutter (NERSC) GPU node with
+                 #            1 x AMD EPYC™ 7763 (CPU) = 64 cores x 2 threads = 128
+train_val_test_split: [0.7, 0.15, 0.15]
+
+# Raw file paths (the GAN datamodule passes these settings on to the appropriate parser)
+raw_file_list:
+  - "AllClusters_pions_272K.dat"
+  # - "AllClusters_1M.dat"
+
+processed_filename: "pions_only_events_272K.npy"
+# processed_filename: "multihadron_events_1M.npy"
+
+dist_plots_filename: "dist_plots_pions_only_events_272K.pdf"
+# dist_plots_filename: "dist_plots_multihadron_events_1M.pdf"
+
+pid_map_file: "pid_to_idx_pions_only_events_272K.pkl"
+# pid_map_file: "pid_to_idx_pions_multihadron_events_1M.pkl"
+
+# Set to True to run data preparation without initialising the PyTorch Lightning model
+initialise_data_preparation: True
\ No newline at end of file
diff --git a/configs/experiment/herwig_event_multihadron.yaml b/configs/experiment/herwig_event_multihadron.yaml
index bd89e87..efb97e4 100644
--- a/configs/experiment/herwig_event_multihadron.yaml
+++ b/configs/experiment/herwig_event_multihadron.yaml
@@ -1,79 +1,37 @@
 # @package _global_

-# start the environment
-# conda-start torch
+# ==================================================================================================
+# To execute this experiment run:
+# $ python src/train.py experiment=herwig_event_multihadron

-# to execute this experiment run:
-# python src/train.py experiment=herwig_event
-## to add a logger
-# python src/train.py experiment=herwig_all_hadron logger=wandb
-
-## with training techniques
-# python src/train.py experiment=herwig_all_hadron logger=wandb +trainer.gradient_clip_val=0.5
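+# Individual parameters can also be overridden from the command line, e.g. (illustrative values):
+# $ python src/train.py experiment=herwig_event_multihadron datamodule.batch_size=64 +trainer.gradient_clip_val=0.5
+# You may also add a logger (e.g.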
Weight & Biases): +# $ python src/train.py experiment=herwig_event_multihadron logger=wandb +# ================================================================================================== +# Default configuration sources defaults: - override /datamodule: herwig_event_multihadron.yaml - - override /model: cond_event_gan.yaml - - override /callbacks: default.yaml - - override /trainer: gpu.yaml + - override /model: multihadron_event_gan.yaml + +# Hydra logs +task_name: herwigMultiHadronEvents +tags: ["herwig", "MultiHadronEvents"] -# all parameters below will be merged with parameters from default configurations set above -# this allows you to overwrite only specified parameters -task_name: herwigEvent -tags: ["herwig", "Events"] +# Training and testing invocation +train: True +test: False +# Logger logger: wandb: - project: "herwigEventsMultiHadron" - tags: ["herwig", "Events"] - name: fit_to_nominal - -seed: 12345 + # Specify the name of your W&B project + project: "herwigMultiHadronEvents" + tags: ["herwig", "MultiHadronEvents"] +# Trainer trainer: - max_epochs: 6000 - log_every_n_steps: 1 - -callbacks: - model_checkpoint: - monitor: "val/min_avg_wd" - mode: "min" - save_top_k: 5 - save_last: True - -# ## override /datamodule: -datamodule: - batch_size: 10000 - num_workers: 8 - pin_memory: True - # train_val_test_split: [200000, 30000, 30000] - train_val_test_split: [6, 2, 2] - - -# ## override /model: -model: - noise_dim: 10 - generator: - hidden_dims: [256, 256] - - discriminator: - _target_: hadml.models.components.deep_set.DeepSetModule - - optimizer_generator: - lr: 0.000001 - - optimizer_discriminator: - lr: 0.0001 - - num_critics: 1 - num_gen: 5 - -# cond_dataset: -# raw_file_list: -# - "ClusterTo2Pi0_nominal.dat" -# processed_file_name: "herwig_graph_cond_data.pt" - -# obs_dataset: -# raw_file_list: -# - "ClusterTo2Pi0_nominal_2.dat" -# processed_file_name: "herwig_graph_obs_data.pt" + min_epochs: 1 + max_epochs: 1 + num_sanity_val_steps: -1 + accelerator: gpu + devices: 1 \ No newline at end of file diff --git a/configs/model/multihadron_event_gan.yaml b/configs/model/multihadron_event_gan.yaml new file mode 100644 index 0000000..d4c4c12 --- /dev/null +++ b/configs/model/multihadron_event_gan.yaml @@ -0,0 +1,22 @@ +_target_: hadml.models.cgan.multihadron_event_gan.MultiHadronEventGANModule +noise_dim: 4 +loss_type: "wasserstein" + +generator: + _target_: hadml.models.components.encoder_transformer.Generator + pid_map_filepath: "${datamodule.data_dir}/processed/${datamodule.pid_map_file}" + noise_dim: ${model.noise_dim} + +discriminator: + _target_: hadml.models.components.encoder_transformer.Discriminator + pid_map_filepath: "${datamodule.data_dir}/processed/${datamodule.pid_map_file}" + +optimizer_generator: + _target_: torch.optim.Adam + _partial_: true + lr: 0.01 + +optimizer_discriminator: + _target_: torch.optim.Adam + _partial_: true + lr: 0.02 \ No newline at end of file diff --git a/configs/trainer/default.yaml b/configs/trainer/default.yaml index f3a9d96..efad430 100644 --- a/configs/trainer/default.yaml +++ b/configs/trainer/default.yaml @@ -14,7 +14,7 @@ devices: 1 # perform a validation loop every N training epochs check_val_every_n_epoch: 1 -# set True to to ensure deterministic results +# set True to ensure deterministic results # makes training slower but gives more reproducibility than just setting seeds deterministic: False diff --git a/data/Herwig/pid_to_idx.pkl b/data/Herwig/pid_to_idx.pkl new file mode 100644 index 0000000..0f2cb39 Binary files /dev/null 
and b/data/Herwig/pid_to_idx.pkl differ diff --git a/data/Herwig/pids_to_ix.pkl b/data/Herwig/pids_to_ix.pkl deleted file mode 100644 index 9588162..0000000 Binary files a/data/Herwig/pids_to_ix.pkl and /dev/null differ diff --git a/hadml/datamodules/components/herwig.py b/hadml/datamodules/components/herwig.py index 6569ed7..da0e07e 100644 --- a/hadml/datamodules/components/herwig.py +++ b/hadml/datamodules/components/herwig.py @@ -5,11 +5,11 @@ from typing import Dict, Optional, Tuple import glob import math - +import matplotlib.pyplot as plt import numpy as np import pandas as pd - import torch + from pytorch_lightning import LightningDataModule from pytorch_lightning.core.mixins import HyperparametersMixin @@ -25,10 +25,10 @@ from torch_geometric.data import Data from torch_geometric.data import InMemoryDataset - from torch.utils.data import Dataset as TorchDataset +from torch.utils.data import Dataset -pid_map_fname = "pids_to_ix.pkl" +pid_map_fname = "pid_to_idx.pkl" class Herwig(LightningDataModule): @@ -660,114 +660,49 @@ def get_angles(four_vector): return data -class HerwigEventMultiHadronDataset(InMemoryDataset): - def __init__( - self, - root, - transform=None, - pre_transform=None, - pre_filter=None, - raw_file_list=None, - processed_file_name="herwig_graph_data.pt", - ): +class HerwigMultiHadronEventDataset(Dataset): + """ This class takes preprocessed multihadron events (i.e. clusters and multiple hadrons having + different types transformed to indices) and provide the DataLoader class object with prepared + generator and discriminator sentences. """ - self.raw_file_list = [] - for pattern in raw_file_list: - self.raw_file_list += glob.glob(os.path.join(root, "raw", pattern)) - self.raw_file_list = [ - os.path.basename(raw_file) for raw_file in self.raw_file_list - ] - self.processed_file_name = processed_file_name - - if root: - pids_map_path = os.path.join(root, pid_map_fname) - if os.path.exists(pids_map_path): - print("Loading existing pids map: ", pids_map_path) - self.pids_to_ix = pickle.load(open(pids_map_path, "rb")) - else: - raise RuntimeError("No pids map found at", pids_map_path) - - super().__init__(root, transform, pre_transform, pre_filter) - - self.data, self.slices = torch.load(self.processed_paths[0]) - - @property - def raw_file_names(self): - if self.raw_file_list is not None: - return self.raw_file_list - return ["ClusterTo2Pi0_new.dat"] - - @property - def processed_file_names(self): - return [self.processed_file_name] - - def download(self): - pass - - def process(self): - all_data = [] - for raw_path in self.raw_paths: - with open(raw_path) as f: - data_list = [self._create_data(line) for line in f if line.strip()] - - if self.pre_filter is not None: - data_list = [data for data in data_list if self.pre_filter(data)] - - if self.pre_transform is not None: - data_list = [self.pre_transform(data) for data in data_list] - - all_data += data_list - - data, slices = self.collate(all_data) - torch.save((data, slices), self.processed_paths[0]) + def __init__(self, n_had_types, clusters, hadrons_with_types, max_n_hadrons): + super().__init__() - def _create_data(self, line): - items = line.split("|")[:-1] - clusters = [pd.Series(c.split(";")[:-1]).str.split(',', expand=True) for c in items] - - q1s, q2s, cs, hadrons, cluster_labels, cs_for_hadrons = [], [], [], [], [], [] - for i, cluster in enumerate(clusters): - # select the two quarks and the cluster /heavy cluster from the cluster - q1 = cluster.iloc[0].to_numpy()[[1,3,4,5,6]].astype(float).reshape(1, -1) 
- q2 = cluster.iloc[1].to_numpy()[[1,3,4,5,6]].astype(float).reshape(1, -1) - c = cluster.iloc[2].to_numpy()[[1,3,4,5,6]].astype(float).reshape(1, -1) - q1s.append(q1) - q2s.append(q2) - cs.append(c) - # select the final states from the cluster - hadron = cluster[cluster[2] == '[ ]'].to_numpy()[:, [1,3,4,5,6]].astype(float) - hadrons.append(hadron) - # assign cluster label to all hadrons - cluster_labels += [i] * len(hadron) - c_for_hadrons = c.repeat(len(hadron), axis=0) - cs_for_hadrons.append(c_for_hadrons) - # concatenate all clusters - q1s = np.concatenate(q1s) - q2s = np.concatenate(q2s) - cs = np.concatenate(cs) - hadrons = np.concatenate(hadrons) - cs_for_hadrons = np.concatenate(cs_for_hadrons) - cond_kin = np.concatenate([cs[:, [1,2,3,4]], q1s[:, [1,2,3,4]], q2s[:, [1,2,3,4]]], axis=1) - had_kin = np.concatenate([cs_for_hadrons[:, [1,2,3,4]], hadrons[:, [1,2,3,4]]], axis=1) - cond_kin_rest_frame = boost(cond_kin) - had_kin_rest_frame = torch.from_numpy(boost(had_kin)[:, 4:]) - had_kin = torch.from_numpy(had_kin[:, 4:]) - - q_phi, q_theta = get_angles(cond_kin_rest_frame[:, 4:8]) - q_momenta = np.stack([q_phi, q_theta], axis=1) - cond_info = np.concatenate([cs[:, [1,2,3,4]], q1s[:, :1], q2s[:, :1], q_momenta], axis=1) - cond_info = torch.from_numpy(cond_info.astype(np.float32)) + self.n_had_types = n_had_types + self.hadrons_with_types = hadrons_with_types + self.clusters = clusters + self.n_clusters = len(clusters) + self.max_n_hadrons = max_n_hadrons + # Hadron padding token (zeros + special hadron type: the first one-hot position) + self.hadron_padding_token = torch.zeros( + 1, hadrons_with_types[0].get_kinematics_dims() + n_had_types) + self.hadron_padding_token[0, hadrons_with_types[0].get_kinematics_dims()] = 1.0 - # convert particle IDs to indices - # then these indices can be embedded in N dim. space - had_type_indices = torch.from_numpy(np.vectorize(self.pids_to_ix.get)(hadrons[:, [0]].astype(np.int16))) - data = Data( - x=cond_info.float(), - had_kin_rest_frame=had_kin_rest_frame.float(), - had_kin=had_kin.float(), - had_type_indices=had_type_indices, - cluster_labels=torch.tensor(cluster_labels).int(), - edge_index=None, - ) - return data + def __len__(self): + return self.n_clusters + + + def __getitem__(self, index): + """ This method is mainly responsible for tokenisation and building sentences containing + "cluster/hadron padding tokens". It returns a pair of prepared sentences. """ + + # Preparing the input sentence. Tokens then need to be concatenated with noise. + # [[cluster_kin][cluster_kin]...[padding_token]] + # max_n_hadrons = max number of hadrons produced by the heaviest cluster + gen_input = torch.stack([self.clusters[index] for _ in range(self.max_n_hadrons)]) + + # Preparing the output sentence. The hadron type is a one-hot vector. 
+ # [[hadron_kin, hadron_type][hadron_kin, hadron_type]...[padding_token]] + # max_n_hadrons = max number of hadrons produced by the heaviest cluster + hadrons = self.hadrons_with_types[index].kinematics + hadron_types = torch.nn.functional.one_hot( + self.hadrons_with_types[index].types, self.n_had_types) + disc_input = torch.cat([hadrons, hadron_types], dim=1) + n_hadrons = len(hadrons) + if self.max_n_hadrons - n_hadrons > 0: + hadron_padding_tokens = torch.cat([self.hadron_padding_token + for _ in range(self.max_n_hadrons - n_hadrons)]) + disc_input = torch.cat([disc_input, hadron_padding_tokens]) + + return gen_input.to(torch.float32), disc_input.to(torch.float32) \ No newline at end of file diff --git a/hadml/datamodules/components/herwig_multihadron_parser.py b/hadml/datamodules/components/herwig_multihadron_parser.py new file mode 100644 index 0000000..4c3588f --- /dev/null +++ b/hadml/datamodules/components/herwig_multihadron_parser.py @@ -0,0 +1,227 @@ +""" + +You may test the Parser class by adding the code below to your main script: + +**************************************************************************************************** +from hadml.datamodules.parsers.herwig_multihadron_parser import HerwigMultiHadronEventParser + +parser = HerwigMultiHadronEventParser( + data_dir="data/Herwig", + raw_file_list=["AllClusters_1K.dat"], + processed_filename="herwig_multihadron_events_1K.npy", + pid_map_file="pid_to_idx.pkl", + debug=True) + +parser.parse_data() +**************************************************************************************************** + +""" + +from collections import Counter +import numpy as np +import pandas as pd +import pickle, os +from hadml.datamodules.components.utils import (boost, get_angles) + + +class HerwigMultiHadronEventParser(): + """ Parser for reading and processing raw data generated by Herwig: + quarks -> heaviest cluster -> ... [light clusters] ... -> multiple hadrons. """ + + def __init__(self, + data_dir, # Data directory path. + raw_file_list, # List of raw data filenames. + processed_filename, # Processed data filename. + pid_map_file, # PID-to-MostCommonID map filename. + debug=False # Number of events being processed is + # printed when True. + ): + processed_dir = os.path.join(data_dir, "processed") + if not os.path.exists(processed_dir): + os.makedirs(processed_dir) + self.raw_file_list = [os.path.join(data_dir, "raw", f) for f in raw_file_list] + self.processed_filepath = os.path.join(data_dir, "processed", processed_filename) + self.pids_to_idx_path = os.path.join(data_dir, "processed", pid_map_file) + self.debug = debug + + def parse_data(self): + """ Parse data from raw files, prepare the PID-to-MostCommonID map + and saving processed files as self.processed_filename. 
""" + + # Looking for an existing processed data file + print('\n', ' '*21, self.__class__.__name__, '\n', '-'*70, sep='') + if os.path.exists(self.processed_filepath): + print("Found processed data in:\n ", self.processed_filepath) + print('-'*70) + return + + # Loading/creating the PID-to-MostCommonID map + if os.path.exists(self.pids_to_idx_path): + print("Loading the existing PID-to-MostCommonID map:", self.pids_to_idx_path) + with open(self.pids_to_idx_path, "rb") as f: + self.pids_to_idx = pickle.load(f) + else: + self.pids_to_idx = self._create_pids_to_idx_dict(self.debug) + + # Parsing events + all_data = [] + print("Parsing all events...") + for raw_path in self.raw_file_list: + with open(raw_path) as f: + print(f" --> Parsing data from file {raw_path}") + data_list = [] + for event_index, event_line in enumerate(f): + if event_line.strip(): + if self.debug: + print(f" processing event #{event_index}", end='\r') + data_list.append(self._parse_raw_event(event_line)) + if self.debug: + print() + all_data += data_list + + # Saving the processed data as NumPy binary files + processed_data = { + "cluster_kin" : [], + "had_kin" : [], + "had_kin_rest_frame" : [], + "had_type_indices" : [], + "cluster_labels" : [], + "n_had_type_indices": len(self.pids_to_idx) + } + + for d in all_data: + processed_data["cluster_kin"].append(d["cluster_kin"]) + processed_data["had_kin"].append(d["had_kin"]) + processed_data["had_kin_rest_frame"].append(d["had_kin_rest_frame"]) + processed_data["had_type_indices"].append(d["had_type_indices"]) + processed_data["cluster_labels"].append(d["cluster_labels"]) + + with open(self.processed_filepath, "wb") as f: + np.save(f, processed_data) + print("Processed data saved in:\n ", self.processed_filepath, '\n', '-'*70, sep='') + + def _parse_particles(self, event_line): + """ Parse all the particles (including quarks, heavy/light cluster and hadrons from + a single event provided as a raw event line). """ + cluster_decays = event_line.split("|")[:-1] + particles = [pd.Series(c.split(";")[:-1]).str.split(',', expand=True) + for c in cluster_decays] + return particles + + def _parse_raw_event(self, event_line): + """ Parse data presented as an event prepared in a specific format. 
""" + + # Parsing all particles + particles = self._parse_particles(event_line) + q1s, q2s, cs, hadrons, cluster_labels, cs_for_hadrons = [], [], [], [], [], [] + + # Processing quarks, clusters and hadrons + for i, particle in enumerate(particles): + + # Selecting the two quarks and the cluster/heavy cluster from the cluster + q1 = particle.iloc[0].to_numpy()[[1,3,4,5,6]].astype(float).reshape(1, -1) + q2 = particle.iloc[1].to_numpy()[[1,3,4,5,6]].astype(float).reshape(1, -1) + c = particle.iloc[2].to_numpy()[[1,3,4,5,6]].astype(float).reshape(1, -1) + q1s.append(q1) + q2s.append(q2) + cs.append(c) + + # Selecting the final states from the cluster + hadron = particle[particle[2] == '[ ]'].to_numpy()[:, [1,3,4,5,6]].astype(float) + hadrons.append(hadron) + + # Assigning cluster labels to hadrons + cluster_labels += [i] * len(hadron) + c_for_hadrons = c.repeat(len(hadron), axis=0) + cs_for_hadrons.append(c_for_hadrons) + + # Concatenating clusters + q1s = np.concatenate(q1s) + q2s = np.concatenate(q2s) + cs = np.concatenate(cs) + + # Hadrons [PID, E, px, py, pz] + hadrons = np.concatenate(hadrons) + + # Heaviest clusters [Cluster ID, E, px, py, pz] + cs_for_hadrons = np.concatenate(cs_for_hadrons) + + # Heaviest clusters + 2 quarks + # [c_E, c_px, c_py, c_pz, q1_E, q1_px, q1_py, q1_pz, q2_E, q2_px, q2_py, q2_pz] + cond_kin = np.concatenate([cs[:, [1,2,3,4]], q1s[:, [1,2,3,4]], q2s[:, [1,2,3,4]]], axis=1) + + # Heaviest cluster + hadron + # [c_E, c_px, c_py, c_pz, h_E, h_px, h_py, h_pz] + had_kin = np.concatenate([cs_for_hadrons[:, [1,2,3,4]], hadrons[:, [1,2,3,4]]], axis=1) + + # Heaviest cluster + 2 quarks in the rest frame (rf) of the former + # [c_E, c_px, c_py, c_pz, q1_Erf, q1_pxrf, q1_pyrf, q1_pzrf, + # q2_Erf, q2_pxrf, q2_pyrf, q2_pzrf] + cond_kin_rest_frame = boost(cond_kin) + + # Hadrons in the rest frame (rf) of the heaviest cluster + # [Erf, pxrf, pyrf, pzrf] + had_kin_rest_frame = boost(had_kin)[:, 4:] + + # Hadrons [E, px, py, pz] + had_kin = had_kin[:, 4:] + + # Computing angles for the two quarks via cond_kin_rest_frame, + # i.e. 
[cluster + q1/q2 in the cluster rest frame] + q_phi, q_theta = get_angles(cond_kin_rest_frame[:, 4:8]) + + # Computing quark momenta based on the angles + q_momenta = np.stack([q_phi, q_theta], axis=1) + + # Preparing X (cluster + quark types + quark momenta in the cluster rest frame, crf) + # [c_E, c_px, c_py, c_pz, q1_type, q2_type, q1_crf_momentum, q2_crf_momentum] + cond_info = np.concatenate([cs[:, [1,2,3,4]], q1s[:, :1], q2s[:, :1], q_momenta], axis=1) + cond_info = cond_info.astype(np.float32) + + # Mapping particle IDs to indices using the prepared PID-to-ID dictionary + try: + had_type_indices = np.vectorize(self.pids_to_idx.get)(hadrons[:, [0]].astype(np.int32)) + except Exception as X: + # Debugging what event makes the parser stop working + print("Line = ", event_line) + print("Exception: ", X) + + return { + "cluster_kin" : cond_info, + "had_kin" : had_kin, + "had_kin_rest_frame": had_kin_rest_frame, + "had_type_indices" : had_type_indices, + "cluster_labels" : cluster_labels + } + + def _create_pids_to_idx_dict(self, debug=False): + """ Create a new PID-to-MostCommonID map/dictionary """ + + print("Creating a new PID-to-MostCommonID map/dictionary...") + hadron_types = [] + + for raw_path in self.raw_file_list: + with open(raw_path) as f: + print(f" --> Analysing file {raw_path}") + for event_index, event_line in enumerate(f): + if event_line.strip(): + if debug: + print(f" processing event #{event_index}", end='\r') + particles = self._parse_particles(event_line) + for particle in particles: + PIDs = particle[particle[2] == '[ ]'].to_numpy()[:, [1]].astype(float) + for id in PIDs: + hadron_types.append(id[0]) + if debug: + print() + + count = Counter(hadron_types) + hadron_pids = list(map(lambda x: x[0], count.most_common())) + pids_to_idx = {pids: i for i, pids in enumerate(hadron_pids)} + + with open(self.pids_to_idx_path, "wb") as f: + pickle.dump(pids_to_idx, f) + print("The PID-to-MostCommonID map has been successfully saved in:", + "\n ", self.pids_to_idx_path) + + return pids_to_idx diff --git a/hadml/datamodules/components/utils.py b/hadml/datamodules/components/utils.py index e9e4873..8d98e26 100644 --- a/hadml/datamodules/components/utils.py +++ b/hadml/datamodules/components/utils.py @@ -2,7 +2,7 @@ import pickle from functools import reduce from typing import Tuple, List, Optional - +import warnings import numpy as np import pandas as pd @@ -70,7 +70,13 @@ def create_boost_fn(cluster_4vec: np.ndarray): velocity = p0 / gamma.reshape(-1, 1) / mass.reshape(-1, 1) del mass, p0 v_mag = np.sqrt((velocity**2).sum(axis=1)) - n = velocity / v_mag.reshape(-1, 1) + + # The RuntimeWarning related to NaN values in the "velocity" or "v_mag" arrays can be safely + # ignored, as the line n[np.isnan(n)] below replaces those vectors with zero vectors. + # Ignoring the warning prevents overhead caused by frequent printing to the standard output. 
+ with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=RuntimeWarning) + n = velocity / v_mag.reshape(-1, 1) n[np.isnan(n)] = 0 del velocity diff --git a/hadml/datamodules/gan_datamodule.py b/hadml/datamodules/gan_datamodule.py index 03b03bd..756d2fc 100644 --- a/hadml/datamodules/gan_datamodule.py +++ b/hadml/datamodules/gan_datamodule.py @@ -1,14 +1,15 @@ from typing import Any, Dict, Optional, Tuple, Protocol - -import torch +import torch, os, numpy as np from pytorch_lightning import LightningDataModule from pytorch_lightning.trainer.supporters import CombinedLoader from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split - from torch_geometric.loader import DataLoader as GeometricDataLoader from torch_geometric.data.dataset import Dataset as GeometricDataset - from hadml.datamodules.components.utils import process_data_split, get_num_asked_events +from hadml.datamodules.components.herwig_multihadron_parser import HerwigMultiHadronEventParser +from hadml.datamodules.components.herwig import HerwigMultiHadronEventDataset +import matplotlib.pyplot as plt +from dataclasses import dataclass class GANDataProtocol(Protocol): @@ -251,3 +252,224 @@ def test_dataloader(self): ), } return CombinedLoader(loaders, mode="max_size_cycle") + + +class MultiHadronEventGANDataModule(LightningDataModule): + def __init__( + self, + data_dir="data/Herwig", + raw_file_list=["AllClusters_10K.dat"], + processed_filename="herwig_multihadron_events_10K.npy", + dist_plots_filename="distribution_plots_10K.pdf", + pid_map_file="pid_to_idx.pkl", + train_val_test_split: Tuple[float, float, float] = (0.7, 0.15, 0.15), + batch_size=32, + num_workers=8, + initialise_data_preparation=False, + debug=True + ): + super().__init__() + + self.data_dir = data_dir + processed_path = os.path.join(os.path.normpath(data_dir), "processed") + if not os.path.exists(processed_path): + os.makedirs(processed_path) + self.processed_filename = os.path.join(processed_path, processed_filename) + + dist_plots_path = os.path.join(os.path.normpath(self.data_dir), "plots") + if not os.path.exists(dist_plots_path): + os.makedirs(dist_plots_path) + self.dist_plots_path = os.path.join(dist_plots_path, dist_plots_filename) + + self.train_val_test_split = train_val_test_split + self.batch_size = batch_size + self.num_workers = num_workers + self.data_train: Optional[Dataset] = None + self.data_val: Optional[Dataset] = None + self.data_test: Optional[Dataset] = None + + parser = HerwigMultiHadronEventParser( + data_dir=data_dir, + raw_file_list=raw_file_list, + processed_filename=processed_filename, + pid_map_file=pid_map_file, + debug=debug + ) + parser.parse_data() + + print('\n', ' '*20, self.__class__.__name__,'\n', '-'*70, sep='') + + if initialise_data_preparation: + self.prepare_data() + self.setup() + one_train_batch = next(iter(self.train_dataloader())) + gen_input, disc_input = one_train_batch + print(f"\nGenerator input batch: {len(gen_input)} ([{gen_input.size()}])") + print(f"Discriminator input batch: {len(disc_input)} ([{disc_input.size()}])") + print("\nGenerator input sample:\n", gen_input[0]) + print("\nDiscriminator input sample:\n", disc_input[0]) + + def prepare_data(self): + # Load data prepared by the parser + with open(self.processed_filename, "rb") as f: + data = np.load(f, allow_pickle=True) + + cluster_kin = data.item()["cluster_kin"] + hadron_kin = data.item()["had_kin"] + had_type_indices = data.item()["had_type_indices"] + cluster_labels = data.item()["cluster_labels"] + 
+        n_events = len(cluster_kin)
+        self.n_had_types = data.item()["n_had_type_indices"] + 1  # 1 extra type for a stop/padding token
+        hadron_kin_rest_frame = data.item()["had_kin_rest_frame"]
+
+        # Assigning hadrons to clusters
+        self.clusters, self.hadrons_with_types, n = self._get_hadrons_and_clusters(
+            n_events, cluster_kin, cluster_labels, hadron_kin_rest_frame, had_type_indices)
+        n_clusters_extracted_from_events = n
+        n_hadrons_per_cluster = [len(hadron_seq.types) for hadron_seq in self.hadrons_with_types]
+        self.max_n_hadrons = max(n_hadrons_per_cluster)
+        n_hadrons_per_event = [len(d) for d in hadron_kin]
+
+        # Preparing distribution plots
+        if not os.path.exists(self.dist_plots_path):
+            hadron_energy = np.concatenate([d for d in hadron_kin])[:, 0]
+            cluster_energy = np.concatenate([d for d in cluster_kin])[:, 0]
+            self._plot_dist(
+                filepath=self.dist_plots_path,
+                data=[[n_hadrons_per_event, n_hadrons_per_cluster], [hadron_energy, cluster_energy]],
+                xlabels=[["Number of hadrons", "Number of hadrons"],
+                         ["Energy [GeV]", "Energy [GeV]"]],
+                ylabels=[["Events", "Clusters"],
+                         ["Hadrons", "Clusters"]],
+                legend_labels=[["Hadron Multiplicity\nDistribution per Event",
+                                "Hadron Multiplicity\nDistribution per Cluster"],
+                               ["Hadron Energy Distribution", "Cluster Energy Distribution"]])
+
+        # Printing general statistics about events, clusters and hadrons
+        print("Initial number of events:", n_events)
+        print("Total number of clusters:", n_clusters_extracted_from_events)
+        print("Total number of hadron types (with a stop token):", self.n_had_types)
+        print("Largest number of hadrons per cluster:", self.max_n_hadrons)
+        print("Largest number of hadrons per event:", max(n_hadrons_per_event))
+
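+    # A sketch of the per-event layout consumed by the helper below (field names as produced
+    # by the parser; shapes illustrative):
+    #   cluster_kin[i]      -> one row of conditioning features per cluster in event i
+    #   cluster_labels[i]   -> one cluster index per hadron in event i
+    #   hadron_kin[i]       -> kinematics of all hadrons in event i
+    #   had_type_indices[i] -> one type index per hadron in event i
+    # The helper flattens these per-event lists into a single list of clusters and a parallel
+    # list of HadronsWithTypes objects, using cluster_labels to build the association masks.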
+    def _get_hadrons_and_clusters(self, n_events, cluster_kin, cluster_labels, hadron_kin,
+                                  had_type_indices):
+        clusters, hadrons_with_types = [], []
+        n_clusters_extracted_from_events = 0
+
+        for i in range(n_events):
+            cluster_in_event = []   # [cluster ... cluster]
+            hadrons_in_event = []   # [HadronsWithTypes ... HadronsWithTypes]
+
+            for j, cluster in enumerate(cluster_kin[i]):
+                # Extracting all clusters from a single event
+                cluster_idx_mask = [cl == j for cl in cluster_labels[i]]
+                cluster = torch.from_numpy(cluster)
+                cluster_in_event.append(cluster)
+                n_clusters_extracted_from_events += 1
+                # Extracting all hadrons from a single event
+                hadrons = torch.tensor(hadron_kin[i][cluster_idx_mask])
+                # Shifting the hadron type for an additional type assigned to a stop token (= 0)
+                hadron_types = torch.tensor(had_type_indices[i][cluster_idx_mask]).squeeze(1) + 1
+                hadrons_in_event.append(HadronsWithTypes(hadrons, hadron_types))
+
+            clusters += cluster_in_event
+            hadrons_with_types += hadrons_in_event
+
+        return clusters, hadrons_with_types, n_clusters_extracted_from_events
+
+    def setup(self, stage: Optional[str] = None):
+        if not self.data_train and not self.data_val and not self.data_test:
+            # Passing clusters and hadrons to a dataset responsible for tokenisation
+            dataset = HerwigMultiHadronEventDataset(
+                self.n_had_types,
+                self.clusters,
+                self.hadrons_with_types,
+                self.max_n_hadrons)
+
+            # Creating the training, validation and test datasets
+            self.data_train, self.data_val, self.data_test = random_split(
+                dataset=dataset,
+                lengths=self.train_val_test_split,
+                generator=torch.Generator().manual_seed(42),
+            )
+
+            print(f"Number of training examples: {len(self.data_train)}")
+            print(f"Number of validation examples: {len(self.data_val)}")
+            print(f"Number of test examples: {len(self.data_test)}")
+            print('-'*70)
+
+    def train_dataloader(self):
+        return DataLoader(
+            dataset=self.data_train,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            shuffle=True,
+            generator=torch.Generator().manual_seed(42)
+        )
+
+    def val_dataloader(self):
+        return DataLoader(
+            dataset=self.data_val,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            generator=torch.Generator().manual_seed(42)
+        )
+
+    def test_dataloader(self):
+        return DataLoader(
+            dataset=self.data_test,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            generator=torch.Generator().manual_seed(42)
+        )
+
+    def _plot_dist(self, filepath, data, xlabels, ylabels=None, legend_labels=None):
+        """ Draw a 2 x 2 grid of distribution diagrams for the given data sets """
+        plt.clf()
+        _, ax = plt.subplots(2, 2, figsize=(13, 8))
+        colors = [["black", "black"], ["orange", "orange"]]
+
+        for r in range(len(data)):
+            for c in range(len(data[r])):
+                samples = data[r][c]
+                # Setting the appropriate bin range
+                if legend_labels[r][c].startswith("Hadron Multiplicity"):
+                    sample_range = [1, max(samples)]
+                    bins = np.linspace(
+                        start=sample_range[0] - 0.5,
+                        stop=sample_range[1] + 0.5,
+                        num=sample_range[1] - sample_range[0] + 2,
+                        retstep=0.5)[0]
+                else:
+                    bins = "scott"
+
+                # Preparing a chart
+                ax[r][c].hist(samples, bins=bins, color=colors[r][c], alpha=0.7,
+                              label=legend_labels[r][c])
+                ax[r][c].set_xlabel(xlabels[r][c])
+                if ylabels[r][c] is not None:
+                    ax[r][c].set_ylabel(ylabels[r][c])
+                ax[r][c].legend(loc='upper right')
+
+                # Setting the ticks along OX if needed
+                if legend_labels[r][c].startswith("Hadron Multiplicity"):
+                    density = 2
+                    xticks = np.arange(start=sample_range[0] - 1, stop=sample_range[1] + 1,
+                                       step=density)[1:]
+                    ax[r][c].set_xticks(xticks)
+
+        plt.tight_layout()
+        plt.savefig(filepath)
+        print("Distribution diagrams have been saved in\n  ", filepath)
+
+
+@dataclass
+class HadronsWithTypes:
+    """Class holding hadrons and their types/indices,
+    assuming all originate from the same heavy cluster."""
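+    # Illustrative example: three hadrons from one heavy cluster could be stored as
+    # HadronsWithTypes(kinematics=torch.zeros(3, 4), types=torch.tensor([5, 2, 9])).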
+    kinematics: torch.Tensor    # [h_kin ... h_kin]
+    types: torch.Tensor         # [h_id ... h_id]
+
+    def get_kinematics_dims(self) -> int:
+        return len(self.kinematics[0])
\ No newline at end of file
diff --git a/hadml/metrics/compare_fn.py b/hadml/metrics/compare_fn.py
index 9fc201c..fab287a 100644
--- a/hadml/metrics/compare_fn.py
+++ b/hadml/metrics/compare_fn.py
@@ -1,14 +1,13 @@
 import math
 import os
 from typing import List, Tuple, Optional, Any, Dict
-
 from matplotlib import ticker
 from pytorch_lightning.core.mixins import HyperparametersMixin
-
 import numpy as np
 import matplotlib.pyplot as plt
-
 from .image_converter import fig_to_array
+from torch import Tensor
+

 def create_plots(nrows, ncols):
     fig, axs = plt.subplots(
@@ -262,4 +261,4 @@ def __call__(
         out_images["particle kinematics"] = fig_to_array(fig)

         plt.close("all")
-        return out_images
+        return out_images
\ No newline at end of file
diff --git a/hadml/metrics/image_converter.py b/hadml/metrics/image_converter.py
index a4f04bd..ba5d1de 100644
--- a/hadml/metrics/image_converter.py
+++ b/hadml/metrics/image_converter.py
@@ -3,9 +3,10 @@
 from matplotlib.figure import Figure


-def fig_to_array(fig: Figure) -> np.ndarray:
+def fig_to_array(fig: Figure, tight_layout=True) -> np.ndarray:
     """Convert a matplotlib figure to a numpy array."""
-    fig.tight_layout(pad=0)
+    if tight_layout:
+        fig.tight_layout(pad=0)
     fig.canvas.draw()
     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
     return data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
diff --git a/hadml/models/cgan/multihadron_event_gan.py b/hadml/models/cgan/multihadron_event_gan.py
new file mode 100644
index 0000000..d6fa144
--- /dev/null
+++ b/hadml/models/cgan/multihadron_event_gan.py
@@ -0,0 +1,305 @@
+import torch
+from torch.optim import Optimizer
+from torchmetrics import MeanMetric
+from pytorch_lightning import LightningModule
+from hadml.metrics.media_logger import log_images
+from hadml.metrics.image_converter import fig_to_array
+from collections import Counter
+import numpy as np, matplotlib.pyplot as plt
+
+
+class MultiHadronEventGANModule(LightningModule):
+    def __init__(
+        self,
+        datamodule: torch.nn.Module,
+        generator: torch.nn.Module,
+        discriminator: torch.nn.Module,
+        optimizer_generator: Optimizer,
+        optimizer_discriminator: Optimizer,
+        noise_dim: int,
+        loss_type: str
+    ):
+        super().__init__()
+        self.save_hyperparameters(ignore=["generator", "discriminator"])
+        self.generator = generator
+        self.discriminator = discriminator
+        self.train_gen_loss = MeanMetric()
+        self.train_disc_loss = MeanMetric()
+
+    def forward(self, clusters):
+        # The generator is stored as a submodule (it is excluded from the hyperparameters)
+        # and expects the noise and the conditioning clusters as separate arguments,
+        # mirroring the call in training_step()
+        noise = torch.randn(*clusters.size()[:2], self.hparams.noise_dim, device=clusters.device)
+        generated_hadrons = self.generator(noise, clusters)
+        return generated_hadrons
+
+    def setup(self, stage=None):
+        pass
+
+    def training_step(self, batch, batch_idx, optimizer_idx):
+        gen_input, real_hadrons = batch
+        noise = torch.randn(*gen_input.size()[:2], self.hparams.noise_dim)
+        fake_hadrons = self.generator(noise.to(gen_input.device), gen_input)
+        score_for_fake = self.discriminator(fake_hadrons)
+        if optimizer_idx == 0:
+            # Training the generator
+            generator_loss = self._generator_loss(score_for_fake)
+            self.train_gen_loss(generator_loss)
+            self.log("generator_loss", generator_loss, prog_bar=True)
+            loss = generator_loss
+        else:
+            # Training the discriminator
+            score_for_real = self.discriminator(real_hadrons)
+            discriminator_loss = self._discriminator_loss(score_for_real, score_for_fake)
+            # TODO: grad_penalty
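+            # (For the "wasserstein" loss this would typically be a WGAN-GP term, i.e.
+            # lambda * ((grad of D(x_hat) w.r.t. x_hat).norm(2) - 1) ** 2 averaged over
+            # interpolates x_hat between real and fake hadron sentences.)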
+            # ...
+            self.log("discriminator_loss", discriminator_loss, prog_bar=True)
+            loss = discriminator_loss
+        return {"loss": loss}
+
+    def _generator_loss(self, score):
+        loss_type = self.hparams.loss_type
+        if loss_type == "wasserstein":
+            loss_gen = -score.mean(0).view(1)
+        elif loss_type == "bce":
+            loss_gen = torch.nn.functional.binary_cross_entropy_with_logits(
+                score, torch.ones_like(score))
+        elif loss_type == "ls":
+            loss_gen = 0.5 * ((score - 1) ** 2).mean(0).view(1)
+        else:
+            raise ValueError(f"Unknown loss type: {loss_type}")
+        return loss_gen
+
+    def _discriminator_loss(self, score_for_real, score_for_fake):
+        loss_type = self.hparams.loss_type
+        if loss_type == "wasserstein":
+            loss_disc = score_for_fake.mean(0).view(1) - score_for_real.mean(0).view(1)
+        elif loss_type == "bce":
+            loss_disc = torch.nn.functional.binary_cross_entropy_with_logits(
+                score_for_real, torch.ones_like(score_for_real)) + \
+                torch.nn.functional.binary_cross_entropy_with_logits(
+                    score_for_fake, torch.zeros_like(score_for_fake))
+        elif loss_type == "ls":
+            loss_disc = 0.5 * ((score_for_real - 1)**2).mean(0).view(1) + \
+                0.5 * (score_for_fake**2).mean(0).view(1)
+        else:
+            raise ValueError(f"Unknown loss type: {loss_type}")
+        return loss_disc
+
+    def validation_step(self, batch, batch_idx):
+        gen_input, disc_input = batch
+        device = gen_input.device
+        if self.trainer.state.stage == "validate":
+            batch_size = len(disc_input)
+            n_types = len(disc_input[0, 0, 4:])
+            max_n_hadrons = len(disc_input[0])
+            true_types = torch.stack([torch.argmax(d[:, 4:], dim=1) for d in disc_input])
+            # ==========================================================================================
+            # Simulating model predictions on validation data by adding some noise to the true values
+            # ==========================================================================================
+            distortion_noise = torch.randint(0, n_types, (batch_size, max_n_hadrons))
+            distorted_types = (true_types + distortion_noise.to(device)) % n_types
+            distorted_disc_input = torch.stack([
+                torch.column_stack([
+                    d[:, :4] * torch.rand(1).to(device),    # scaling kinematics (uniform noise)
+                    torch.nn.functional.one_hot(d_type,     # randomly shifted types
+                                                num_classes=n_types)
+                ]) for d, d_type in zip(disc_input, distorted_types)
+            ])
+            # ==========================================================================================
+            return {"distorted_disc_input": distorted_disc_input, "disc_input": disc_input}
+
+        elif self.trainer.state.stage == "sanity_check":
+            return {"gen_input": gen_input[:, 0, :]}
+
+    def test_step(self, batch, batch_idx):
+        pass
+
+    def configure_optimizers(self):
+        # PyTorch Lightning maps the returned order to optimizer_idx in training_step():
+        # index 0 trains the generator, index 1 the discriminator
+        generator_opt = self.hparams.optimizer_generator(params=self.generator.parameters())
+        discriminator_opt = self.hparams.optimizer_discriminator(params=self.discriminator.parameters())
+        return generator_opt, discriminator_opt
+
+    def _generate_noise(self, n_tokens):
+        return torch.randn(n_tokens, self.hparams.noise_dim)
+
+    def _compare(self, predictions, truths):
+        images = self._prepare_plots(predictions, truths)
+        # Attributes self.logger and self.logger.experiment are defined by the logger passed
+        # to the trainer which in turn uses an object of this model class:
+        if self.logger and self.logger.experiment is not None:
+            log_images(logger=self.logger, key="MultiHadronEvent GAN",
+                       images=list(images.values()), caption=list(images.keys()))
+
+    def validation_epoch_end(self, validation_step_outputs):
+        if self.trainer.state.stage == "validate":
+            # Handling the validation output list
+            # Shape of validation_step_outputs: [n_batches, dict_key, batch_size, max_n_hadrons, features]
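+            # The list comprehensions below flatten this in two stages: first across batches
+            # to [n_hadron_sets, max_n_hadrons, features], then across hadron sets to
+            # [total_n_hadrons, features] before stacking.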
+            sentence_stats = {}
+
+            preds = [d["distorted_disc_input"] for d in validation_step_outputs]
+            # [n_hadron_sets, max_n_hadrons, features]
+            preds = [d for pred in preds for d in pred]
+            sentence_stats["pred_n_hads_per_cluster"] = [len(d[d[:, 4] == 0.0]) for d in preds]
+            sentence_stats["pred_n_pad_hads_per_cluster"] = [len(preds[0]) - n for n in
+                                                             sentence_stats["pred_n_hads_per_cluster"]]
+            # [total_n_hadrons, features]
+            preds = [d for pred in preds for d in pred]
+            preds = torch.stack(preds)
+
+            truths = [d["disc_input"] for d in validation_step_outputs]
+            # [n_hadron_sets, max_n_hadrons, features]
+            truths = [d for truth in truths for d in truth]
+            sentence_stats["true_n_hads_per_cluster"] = [len(d[d[:, 4] == 0.0]) for d in truths]
+            sentence_stats["true_n_pad_hads_per_cluster"] = [len(truths[0]) - n for n in
+                                                             sentence_stats["true_n_hads_per_cluster"]]
+            # [total_n_hadrons, features]
+            truths = [d for truth in truths for d in truth]
+            truths = torch.stack(truths)
+
+            # Preparing diagrams
+            images = self._prepare_plots(predictions=preds.cpu(), truths=truths.cpu(),
+                                         sentence_stats=sentence_stats)
+
+        elif self.trainer.state.stage == "sanity_check":
+            gen_input = [d["gen_input"] for d in validation_step_outputs]
+            gen_input = [d for gen_in in gen_input for d in gen_in]    # [total_n_clusters, features]
+            gen_input = torch.stack(gen_input)
+
+            # Preparing diagrams
+            images = self._prepare_plots(clusters=gen_input.cpu())
+
+        # Sending the diagrams to the logger
+        if self.logger and self.logger.experiment is not None:
+            log_images(
+                self.logger,
+                "MultiHadronEvent GAN",
+                images=list(images.values()),
+                caption=list(images.keys()),
+            )
+
+    def _prepare_plots(self, predictions=None, truths=None, sentence_stats=None, clusters=None):
+        """ Prepare histograms and other charts using the data received from validation_epoch_end().
+        Diagrams for the sanity check (clusters) are prepared once only before training. All the
+        other ones (hadrons) are drawn each time validation_epoch_end() is called. """
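+        # Layout assumed below: "predictions" and "truths" are [total_n_hadrons, 4 + n_types]
+        # tensors with kinematics in columns 0-3 and a one-hot hadron type in the remainder;
+        # "clusters" (sanity check) holds the conditioning inputs instead.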
""" + diagrams = {} + + if predictions is not None and truths is not None: + preds_kin, preds_types = predictions[:, :4], torch.argmax(predictions[:, 4:], dim=1) - 1 + truths_kin, truths_types = truths[:, :4], torch.argmax(truths[:, 4:], dim=1) - 1 + + # Hadron type histogram + sample_range = [1, preds_types.max()] + bins = np.linspace( + start=sample_range[0] - 0.5, + stop=sample_range[1] + 0.5, + num=sample_range[1] - sample_range[0] + 2, + retstep=0.5)[0] + n_types = truths_types.max() + 1 + density = n_types // 25 if n_types // 25 > 0 else 1 + fig = plt.figure(figsize=(9, 6)) + plt.title("Hadron Type Distribution") + plt.hist(truths_types, bins=bins, color="black", histtype="step", label="True") + plt.hist(preds_types, bins=bins, color="#AABA9E", label="Generated") + plt.ylabel("Hadrons") + plt.xlabel("Hadron Most Common ID\n(mapped from PIDs)", labelpad=20) + xticks = np.arange(start=sample_range[0] - 1, stop=sample_range[1] + 1, step=density)[1:] + plt.xticks(xticks, rotation=90) + plt.legend() + plt.tight_layout() + diagrams["hadron_type_hist"] = fig_to_array(fig, tight_layout=False) + + # Hadron energy and momentum histogram + fig, axs = plt.subplots(2, 2, figsize=(12, 9)) + fig.subplots_adjust(wspace=0.2, hspace=0.35) + axs[0][0].set_title("Hadron Energy Distribution") + labels = ["Generated", "True"] + (_, bins, _) = axs[0][0].hist(truths_kin[:, 0], bins="scott", color="black", + label=labels[1], histtype="step") + axs[0][0].hist(preds_kin[:, 0], bins=bins, color="#AEC5EB", label=labels[0]) + axs[0][0].set_xlabel("Energy (Cluster Rest Frame)") + axis = ['x', 'y', 'z'] + for row in range(0, 2): + for col in range(0, 2): + if row == 0 and col == 0: + continue + feature = row + col + 1 + axs[row][col].set_xlabel(f"Momentum ({axis[feature - 1].capitalize()})") + axs[row][col].title.set_text("Hadron Momentum Distribution") + (_, bins, _) = axs[row][col].hist(truths_kin[:, feature], bins="auto", + color="black", label=labels[1], histtype="step") + axs[row][col].hist(preds_kin[:, feature], bins=bins, color="#F9DEC9", + label=labels[0]) + for row in range(0, 2): + for col in range(0, 2): + axs[row][col].set_ylabel("Hadrons (Log Scale)") + axs[row][col].set_yscale("log") + axs[row][col].legend(loc='upper right') + fig.suptitle("Hadron Kinematics Distribution") + diagrams["hadron_kinematics_hist"] = fig_to_array(fig, tight_layout=False) + + # Hadron and padding token multiplicity + n_max_hads = sentence_stats["true_n_hads_per_cluster"][0] + \ + sentence_stats["true_n_pad_hads_per_cluster"][0] + fig, axs = plt.subplots(1, 2, figsize=(10, 5)) + bins = np.linspace(start=-0.5, stop=n_max_hads+0.5, num=n_max_hads+2, retstep=0.5)[0] + datatype = ["true", "pred"] + labels = ["True", "Generated"] + colours = ["black", "#DE3C4B"] + for col in range(0, 2): + for i in range(0, 2): + if col == 0: + axs[col].hist(sentence_stats[f"{datatype[i]}_n_hads_per_cluster"], + bins=bins, color=colours[i], label=labels[i], rwidth=0.9) + axs[col].set_xlabel("Number of hadrons") + else: + axs[col].hist(sentence_stats[f"{datatype[i]}_n_pad_hads_per_cluster"], + bins=bins, color=colours[i], label=labels[i], rwidth=0.9) + axs[col].set_xlabel("Number of padding tokens") + axs[col].legend() + axs[col].set_ylabel("Sentences") + fig.suptitle("Sentence Statistics") + plt.tight_layout() + diagrams["sentence_statistics_hist"] = fig_to_array(fig, tight_layout=False) + + elif clusters is not None: + diagrams = {} + kinematics = clusters[:, :4] + quark_types = clusters[:, 4:6] + quark_angles = clusters[:, 6:] + + # Cluster 
+            # Cluster momentum histograms (column 0 of the kinematics is the energy;
+            # columns 1-3 are the momentum components)
+            fig, axs = plt.subplots(1, 3, figsize=(15, 5))
+            axis = ['x', 'y', 'z']
+            for col in range(0, 3):
+                axs[col].hist(kinematics[:, col + 1], bins="auto", color="#AB92BF")
+                axs[col].set_xlabel(f"Momentum ({axis[col].capitalize()})")
+                axs[col].set_ylabel("Clusters")
+            fig.suptitle("Cluster Momentum Distribution")
+            plt.tight_layout()
+            diagrams["cluster_kinematics_hist"] = fig_to_array(fig, tight_layout=False)
+
+            # Quark types and angles
+            count = Counter(quark_types.flatten().tolist())
+            quark_pids = list(map(lambda x: x[0], count.most_common()))
+            pids_to_idx = {pids: i for i, pids in enumerate(quark_pids)}
+            n_idx = len(pids_to_idx)
+            bins = np.linspace(start=-0.5, stop=n_idx+0.5, num=n_idx+2, retstep=0.5)[0]
+            fig, axs = plt.subplots(2, 2, figsize=(9, 9))
+            angles = ["phi", "theta"]
+            for row in range(0, 2):
+                for col in range(0, 2):
+                    if row == 0:
+                        quark_idx = [pids_to_idx[t.item()] for t in quark_types[:, col]]
+                        axs[row][col].hist(quark_idx, bins=bins, rwidth=0.8, color="#4C1E4F")
+                        axs[row][col].set_xlabel("PID")
+                        axs[row][col].title.set_text("Type")
+                        axs[row][col].set_xticks([int(pid) for pid in pids_to_idx.values()])
+                        axs[row][col].set_xticklabels([int(pid) for pid in pids_to_idx.keys()],
+                                                      rotation=90)
+                    else:
+                        axs[row][col].hist(quark_angles[:, col], bins="scott", rwidth=0.8,
+                                           color="#B5A886")
+                        axs[row][col].set_xlabel("Angle")
+                        axs[row][col].title.set_text(f"Kinematics ({angles[col]})")
+                    axs[row][col].set_ylabel("Quarks")
+            fig.suptitle("Quark Type and Momentum Distribution")
+            plt.tight_layout()
+            diagrams["quarks_features_hist"] = fig_to_array(fig, tight_layout=False)
+
+        return diagrams
\ No newline at end of file
diff --git a/hadml/models/components/encoder_transformer.py b/hadml/models/components/encoder_transformer.py
new file mode 100644
index 0000000..a138850
--- /dev/null
+++ b/hadml/models/components/encoder_transformer.py
@@ -0,0 +1,66 @@
+import torch, os, pickle
+
+
+class Generator(torch.nn.Module):
+    """ Generator implemented as an encoder-only transformer model """
+
+    def __init__(
+        self,
+        noise_dim=4,             # Arbitrary number (noise dimensionality)
+        cluster_kins_dim=8,      # Cluster four-momentum, two quark types, phi, theta
+        n_quarks=2,              # Number of quark types in cluster_kins_dim
+        quark_types=17,          # Quark PIDs (+/-) 1-8 are mapped to indices 1-16 (negative PIDs
+                                 # to |PID|, positive ones to PID + 8), so the embedding table
+                                 # needs 17 rows (index 0 stays unused)
+        hadron_kins_dim=4,       # Hadron four-momentum
+        embedding_dim=128,       # Arbitrary number (but the same for the discriminator)
+        quark_embedding_dim=4,   # Arbitrary number (quark embedding dimensionality)
+        n_heads=8,               # Encoder architecture hyperparameter
+        pid_map_filepath=None    # For getting information about the number of hadron most common IDs
+    ):
+        super().__init__()
+        with open(os.path.normpath(pid_map_filepath), "rb") as f:
+            n_hadron_types = len(pickle.load(f)) + 1
+        self.quark_type_embedding_layer = torch.nn.Embedding(quark_types, quark_embedding_dim)
+        self.input_embedding_layer = torch.nn.Linear(
+            noise_dim + cluster_kins_dim - n_quarks + n_quarks * quark_embedding_dim, embedding_dim
+        )
+        self.output_embedding_layer = torch.nn.Linear(embedding_dim, hadron_kins_dim + n_hadron_types)
+        # batch_first=True because the inputs come from the DataLoader as [batch, sequence, features]
+        encoder_layer = torch.nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=n_heads,
+                                                         batch_first=True)
+        self.transformer_encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=6)
+
+    def forward(self, noise, cluster_kins):
+        quark_types = cluster_kins[:, :, 4:6].to(torch.int32)
+        quark_types = torch.where(quark_types < 0, torch.abs(quark_types), quark_types + 8)
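+        # quark_types: [batch, sequence, 2] integer indices; the embedding below produces
+        # [batch, sequence, 2, quark_embedding_dim], which the following reshape flattens over
+        # the last two dimensions before concatenation with the remaining cluster features.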
+        embedded_quark_types = self.quark_type_embedding_layer(quark_types)
+        embedded_quark_types = embedded_quark_types.reshape(*embedded_quark_types.size()[:2], -1)
+        cluster_kins = torch.concatenate((cluster_kins[:, :, :4], embedded_quark_types,
+                                          cluster_kins[:, :, 6:]), dim=2)
+        clusters_and_noise = torch.concatenate((cluster_kins, noise), dim=2)
+        embedded_input = self.input_embedding_layer(clusters_and_noise)
+        embedded_output = self.transformer_encoder(embedded_input)
+        hadrons = self.output_embedding_layer(embedded_output)
+        return hadrons
+
+
+class Discriminator(torch.nn.Module):
+    """ Discriminator implemented as an encoder-only transformer model """
+
+    def __init__(
+        self,
+        hadron_kins_dim=4,      # Hadron four-momentum
+        embedding_dim=128,      # Arbitrary number (but the same as in the generator)
+        n_heads=8,              # Encoder architecture hyperparameter
+        pid_map_filepath=None   # For getting information about the number of hadron most common IDs
+    ):
+        super().__init__()
+        with open(os.path.normpath(pid_map_filepath), "rb") as f:
+            n_hadron_types = len(pickle.load(f)) + 1
+        self.input_embedding_layer = torch.nn.Linear(hadron_kins_dim + n_hadron_types, embedding_dim)
+        self.output_embedding_layer = torch.nn.Linear(embedding_dim, 1)
+        # batch_first=True because the inputs come from the DataLoader as [batch, sequence, features]
+        encoder_layer = torch.nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=n_heads,
+                                                         batch_first=True)
+        self.transformer_encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=6)
+
+    def forward(self, hadrons):
+        embedded_input = self.input_embedding_layer(hadrons)
+        embedded_output = self.transformer_encoder(embedded_input)
+        real_or_fake_response = self.output_embedding_layer(embedded_output)
+        # Averaging the per-token scores yields one real/fake score per hadron sentence
+        return real_or_fake_response.mean(dim=1)
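+
+# A minimal smoke test (hypothetical file path and shapes; requires an existing PID map pickle):
+#
+#   gen = Generator(pid_map_filepath="data/Herwig/processed/pid_to_idx.pkl")
+#   disc = Discriminator(pid_map_filepath="data/Herwig/processed/pid_to_idx.pkl")
+#   noise = torch.randn(2, 5, 4)       # [batch, max_n_hadrons, noise_dim]
+#   clusters = torch.randn(2, 5, 8)    # [batch, max_n_hadrons, cluster_kins_dim]
+#   clusters[:, :, 4:6] = 1.0          # quark type slots must hold valid PIDs
+#   hadrons = gen(noise, clusters)     # [batch, max_n_hadrons, 4 + n_hadron_types]
+#   scores = disc(hadrons)             # [batch, 1]
\ No newline at end of file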