Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Take split_dataset out from fit #42

Merged
merged 2 commits into from
Mar 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,10 @@ def split_dataset(
return X_train, X_test, Y_train, Y_test


def fit(X: npt.NDArray[np.bool_],
Y: npt.NDArray[np.bool_],
def fit(X_train: npt.NDArray[np.bool_],
Y_train: npt.NDArray[np.bool_],
X_test: npt.NDArray[np.bool_],
Y_test: npt.NDArray[np.bool_],
features: typing.List[str],
iters: int,
weights_filename: str,
Expand All @@ -128,8 +130,10 @@ def fit(X: npt.NDArray[np.bool_],
"""Trains an AdaBoost classifier.

Args:
X (numpy.ndarray): Training entries.
Y (numpy.ndarray): Training labels.
X_train (numpy.ndarray): Training entries.
Y_train (numpy.ndarray): Training labels.
X_test (numpy.ndarray): Testing entries.
Y_test (numpy.ndarray): Testing labels.
features (List[str]): Features, which correspond to the columns of entries.
iters (int): The number of training iterations.
weights_filename (str): A file path to write the learned weights.
Expand All @@ -147,7 +151,16 @@ def fit(X: npt.NDArray[np.bool_],
print('Outputting learned weights to %s ...' % (weights_filename))

phis: typing.Dict[int, float] = dict()
X_train, X_test, Y_train, Y_test = split_dataset(X, Y)

assert (X_train.shape[1] == X_test.shape[1]
), 'Training and test entries should have the same number of features.'
assert (X_train.shape[1] - 1 == len(features)
), 'The training data should have the same number of features + BIAS.'
assert (X_train.shape[0] == Y_train.shape[0]
), 'Training entries and labels should have the same number of items.'
assert (X_test.shape[0] == Y_test.shape[0]
), 'Testing entries and labels should have the same number of items.'

N_train, M_train = X_train.shape
w = np.ones(N_train) / N_train

Expand Down Expand Up @@ -227,7 +240,9 @@ def main() -> None:
chunk_size = int(args.chunk_size) if args.chunk_size is not None else None

X, Y, features = preprocess(train_data_filename, feature_thres)
fit(X, Y, features, iterations, weights_filename, log_filename, chunk_size)
X_train, X_test, Y_train, Y_test = split_dataset(X, Y)
fit(X_train, Y_train, X_test, Y_test, features, iterations, weights_filename,
log_filename, chunk_size)

print('Training done. Export the model by passing %s to build_model.py' %
(weights_filename))
Expand Down
5 changes: 3 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@ console_scripts =
based_on_style = yapf

[flake8]
# E501: line too long
# E124: closing bracket does not match visual indentation
# E126: over-indentation
# E501: line too long
# BLK100: black formattable
ignore = E126,E501,BLK100
ignore = E124,E126,E501,BLK100
indent-size = 2

[mypy]
Expand Down
19 changes: 10 additions & 9 deletions tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def setUp(self) -> None:
with open(ENTRIES_FILE_PATH, 'w') as f:
f.write((
' 1\tA\tC\n' # the first column represents the label (-1 / 1).
'-1\tA\tB\n' # the rest cols represents the associated features.
'-1\tA\tB\n' # the remaining columns represent the associated features.
' 1\tA\tC\n'
'-1\tA\n'
' 1\tA\tC\n'))
Expand Down Expand Up @@ -110,13 +110,14 @@ def test_preprocess(self) -> None:
],
'X should represent the filtered entry features with a bias column.')

self.assertListEqual(Y.tolist(), [
True,
False,
True,
False,
True,
], 'Y should represent the entry labels even filtered.')
self.assertListEqual(
Y.tolist(), [
True,
False,
True,
False,
True,
], 'Y should represent the entry labels even some labels are filtered.')

def test_split_dataset(self) -> None:
N = 100
Expand Down Expand Up @@ -148,7 +149,7 @@ def test_fit(self) -> None:
])
features = ['a', 'b', 'c']
iters = 1
train.fit(X, Y, features, iters, WEIGHTS_FILE_PATH, LOG_FILE_PATH)
train.fit(X, Y, X, Y, features, iters, WEIGHTS_FILE_PATH, LOG_FILE_PATH)
with open(WEIGHTS_FILE_PATH) as f:
weights = f.read().splitlines()
top_feature = weights[0].split('\t')[0]
Expand Down