From 46ef59697e95bdbce9257310fcef05f829a0b190 Mon Sep 17 00:00:00 2001 From: Rantaharju Jarno Date: Tue, 5 Dec 2023 21:02:18 +0200 Subject: [PATCH] Include column names in filter_location parameters --- niimpy/preprocessing/location.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/niimpy/preprocessing/location.py b/niimpy/preprocessing/location.py index 592beb39..b0964a94 100644 --- a/niimpy/preprocessing/location.py +++ b/niimpy/preprocessing/location.py @@ -79,7 +79,12 @@ def distance_matrix(lats, lons): def filter_location(location, remove_disabled=True, remove_zeros=True, - remove_network=True): + remove_network=True, + latitude_column = "double_latitude", + longitude_column = "double_longitude", + label_column = "label", + provider_column = "provider", + ): """Remove low-quality or weird location samples Parameters @@ -103,17 +108,17 @@ def filter_location(location, """ if remove_disabled: - assert 'label' in location - location = location[location['label'] != 'disabled'] + assert label_column in location + location = location[location[label_column] != 'disabled'] if remove_zeros: - index = (location["double_latitude"] ** 2 + - location["double_longitude"] ** 2) > 0.001 + index = (location[latitude_column] ** 2 + + location[longitude_column] ** 2) > 0.001 location = location[index] if remove_network: - assert 'provider' in location - location = location[location['provider'] == 'gps'] + assert provider_column in location + location = location[location[provider_column] == 'gps'] return location