Skip to content

Commit

Permalink
Handle missing values in Higgs dataset.
Browse files Browse the repository at this point in the history
Fixes issue #5428.

PiperOrigin-RevId: 639031084
  • Loading branch information
marcenacp authored and The TensorFlow Datasets Authors committed May 31, 2024
1 parent 6bd6e8b commit ec4c6d8
Showing 1 changed file with 10 additions and 33 deletions.
43 changes: 10 additions & 33 deletions tensorflow_datasets/structured/higgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import csv
import logging

from etils import epath
import numpy as np
Expand Down Expand Up @@ -130,39 +131,15 @@ def _generate_examples(self, file_path):
The features, per row.
"""

fieldnames = [
'class_label',
'lepton_pT',
'lepton_eta',
'lepton_phi',
'missing_energy_magnitude',
'missing_energy_phi',
'jet_1_pt',
'jet_1_eta',
'jet_1_phi',
'jet_1_b-tag',
'jet_2_pt',
'jet_2_eta',
'jet_2_phi',
'jet_2_b-tag',
'jet_3_pt',
'jet_3_eta',
'jet_3_phi',
'jet_3_b-tag',
'jet_4_pt',
'jet_4_eta',
'jet_4_phi',
'jet_4_b-tag',
'm_jj',
'm_jjj',
'm_lv',
'm_jlv',
'm_bb',
'm_wbb',
'm_wwbb',
]

features = self.info.features
num_missing_values = 0
with epath.Path(file_path).open() as csvfile:
reader = csv.DictReader(csvfile, fieldnames=fieldnames)
reader = csv.DictReader(csvfile, fieldnames=features.keys())
for i, row in enumerate(reader):
for key, value in row.items():
if value == '': # pylint: disable=g-explicit-bool-comparison
logging.warning('Skipping row after missing value for key=%s', key)
continue
yield i, row
plural = 's' if num_missing_values else ''
logging.warning('Found %d missing value%s.', num_missing_values, plural)

0 comments on commit ec4c6d8

Please sign in to comment.