diff --git a/tensorflow_datasets/structured/higgs.py b/tensorflow_datasets/structured/higgs.py index 444c83ace8e..9c4a4d02203 100644 --- a/tensorflow_datasets/structured/higgs.py +++ b/tensorflow_datasets/structured/higgs.py @@ -18,6 +18,7 @@ from __future__ import annotations import csv +import logging from etils import epath import numpy as np @@ -130,39 +131,15 @@ def _generate_examples(self, file_path): The features, per row. """ - fieldnames = [ - 'class_label', - 'lepton_pT', - 'lepton_eta', - 'lepton_phi', - 'missing_energy_magnitude', - 'missing_energy_phi', - 'jet_1_pt', - 'jet_1_eta', - 'jet_1_phi', - 'jet_1_b-tag', - 'jet_2_pt', - 'jet_2_eta', - 'jet_2_phi', - 'jet_2_b-tag', - 'jet_3_pt', - 'jet_3_eta', - 'jet_3_phi', - 'jet_3_b-tag', - 'jet_4_pt', - 'jet_4_eta', - 'jet_4_phi', - 'jet_4_b-tag', - 'm_jj', - 'm_jjj', - 'm_lv', - 'm_jlv', - 'm_bb', - 'm_wbb', - 'm_wwbb', - ] - + features = self.info.features + num_missing_values = 0 with epath.Path(file_path).open() as csvfile: - reader = csv.DictReader(csvfile, fieldnames=fieldnames) + reader = csv.DictReader(csvfile, fieldnames=features.keys()) for i, row in enumerate(reader): + for key, value in row.items(): + if value == '': # pylint: disable=g-explicit-bool-comparison + logging.warning('Skipping row after missing value for key=%s', key) + continue yield i, row + plural = 's' if num_missing_values else '' + logging.warning('Found %d missing value%s.', num_missing_values, plural)