-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathdata_process.py
139 lines (111 loc) · 5.25 KB
/
data_process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
from typing import Dict, Iterable, List, Sequence, Tuple

import numpy as np
import pandas as pd
# Category vocabulary for every categorical column. A value's position in its
# list is the class index used when encoding the column (one-hot slot, or the
# raw index for "label"). "Retired" under "workclass" and "Retired", "Student"
# and "None" under "occupation" are among the values that
# convert_csv_to_arr writes when filling missing ("?") entries.
attr_classes = {
    "workclass": [
        'Private', 'Self-emp-not-inc', 'Self-emp-inc', 'Federal-gov',
        'Local-gov', 'State-gov', 'Without-pay', 'Never-worked', 'Retired',
    ],
    "education": [
        'Bachelors', 'Some-college', '11th', 'HS-grad', 'Prof-school',
        'Assoc-acdm', 'Assoc-voc', '9th', '7th-8th', '12th', 'Masters',
        '1st-4th', '10th', 'Doctorate', '5th-6th', 'Preschool',
    ],
    "marital-status": [
        'Married-civ-spouse', 'Divorced', 'Never-married', 'Separated',
        'Widowed', 'Married-spouse-absent', 'Married-AF-spouse',
    ],
    "occupation": [
        'Tech-support', 'Craft-repair', 'Other-service', 'Sales',
        'Exec-managerial', 'Prof-specialty', 'Handlers-cleaners',
        'Machine-op-inspct', 'Adm-clerical', 'Farming-fishing',
        'Transport-moving', 'Priv-house-serv', 'Protective-serv',
        'Armed-Forces', 'Retired', 'Student', 'None',
    ],
    "relationship": [
        'Wife', 'Own-child', 'Husband', 'Not-in-family', 'Other-relative',
        'Unmarried',
    ],
    "race": [
        'White', 'Asian-Pac-Islander', 'Amer-Indian-Eskimo', 'Other', 'Black',
    ],
    "sex": [
        'Female', 'Male',
    ],
    "native-country": [
        'United-States', 'Cambodia', 'England', 'Puerto-Rico', 'Canada',
        'Germany', 'Outlying-US(Guam-USVI-etc)', 'India', 'Japan', 'Greece',
        'South', 'China', 'Cuba', 'Iran', 'Honduras', 'Philippines', 'Italy',
        'Poland', 'Jamaica', 'Vietnam', 'Mexico', 'Portugal', 'Ireland',
        'France', 'Dominican-Republic', 'Laos', 'Ecuador', 'Taiwan', 'Haiti',
        'Columbia', 'Hungary', 'Guatemala', 'Nicaragua', 'Scotland',
        'Thailand', 'Yugoslavia', 'El-Salvador', 'Trinadad&Tobago', 'Peru',
        'Hong', 'Holand-Netherlands',
    ],
    "label": [
        '<=50K', '>50K',
    ],
}
def int_to_one_hot(val: int, total: int) -> List[int]:
    """Encode class index ``val`` as a one-hot list of length ``total``.

    A negative ``val`` (e.g. a -1 missing-value marker) produces an all-zero
    list. A ``val`` of ``total`` or more raises IndexError.
    """
    one_hot = [0 for _ in range(total)]
    if val < 0:
        # no class to mark: leave every slot at zero
        return one_hot
    one_hot[val] = 1
    return one_hot
def one_hot_to_int(arr: Sequence[int]) -> int:
    """Return the position of the single 1 in a one-hot vector.

    Accepts any 1-D sequence whose entries compare equal to 0/1, including a
    plain list or a 1-D NumPy array such as a row slice of the float32 matrix
    built by ``df_to_arr``.

    Raises:
        AssertionError: if the entries do not sum to exactly 1. Raised
            explicitly (not via ``assert``) so the check survives ``python -O``.
    """
    # ndarray has no .index(); materialising a list works for both input kinds.
    values = list(arr)
    if sum(values) != 1:
        raise AssertionError("arr should be an one hot array")
    return values.index(1)
def df_to_arr(df: pd.DataFrame) -> np.ndarray:
    """Encode a DataFrame as a 2-D float32 matrix, one row per DataFrame row.

    Columns are emitted in ``df``'s column order:

    * a column named in ``attr_classes`` is categorical. "label" becomes a
      single column holding the class index; every other categorical column
      becomes a one-hot block of width ``len(attr_classes[col])``. The marker
      "?" maps to index -1, i.e. -1.0 in the label column or an all-zero
      one-hot block;
    * any other column is copied through as-is (cast to float32).

    An empty DataFrame yields a ``(0, n_features)`` array so callers can still
    slice columns.

    Raises:
        ValueError: if a categorical cell is neither "?" nor listed in
            ``attr_classes`` for its column.
    """
    columns = df.columns.tolist()

    # Per-column value -> class index table (O(1) lookups instead of
    # list.index scans). setdefault keeps the first occurrence, like list.index.
    lookups = {}
    n_features = 0
    for col in columns:
        if col not in attr_classes:
            n_features += 1
            continue
        lookup = {}
        for idx, name in enumerate(attr_classes[col]):
            lookup.setdefault(name, idx)
        lookups[col] = lookup
        n_features += 1 if col == "label" else len(attr_classes[col])

    # Pull every column out once; the old df.iloc[i][col] built a whole row
    # Series for every single cell.
    column_values = [series.tolist() for _, series in df.items()]

    arr = []
    for row_values in zip(*column_values):
        arr_row = []
        for col, raw_val in zip(columns, row_values):
            lookup = lookups.get(col)
            if lookup is None:
                arr_row.append(raw_val)
                continue
            if raw_val == "?":
                val = -1
            else:
                try:
                    val = lookup[raw_val]
                except (KeyError, TypeError):
                    raise ValueError(
                        "{!r} is not a known category of column {!r}".format(raw_val, col)
                    ) from None
            if col == "label":
                arr_row.append(val)
            else:
                arr_row.extend(int_to_one_hot(val, len(attr_classes[col])))
        arr.append(arr_row)

    if not arr:
        # np.array([]) would be 1-D; keep the documented 2-D shape.
        return np.zeros((len(df), n_features), dtype=np.float32)
    return np.array(arr, dtype=np.float32)
def load_csv(filename: str) -> pd.DataFrame:
    """Read an adult CSV into a DataFrame with the 15 column names below.

    The file's first line is treated as a header row and replaced by these
    names (``header=0``). Fields are split on a comma followed by optional
    whitespace, so both ``"a, b"`` and ``"a,b"`` parse.
    """
    names = ["age", "workclass", "fnlwgt", "education", "education-num", "marital-status", "occupation",
             "relationship", "race", "sex", "capital-gain", "capital-loss", "hours-per-week", "native-country",
             "label"]
    # Regex separators are only supported by the python engine; requesting it
    # explicitly avoids pandas' "falling back to the 'python' engine" warning.
    return pd.read_csv(filename, header=0, names=names, sep=r",\s*", engine="python")
def get_mean_std(
    df: pd.DataFrame,
    columns: Iterable[str] = ("age", "capital-gain", "capital-loss", "hours-per-week"),
) -> Dict[str, Dict[str, float]]:
    """Compute the mean and standard deviation of selected numeric columns.

    Args:
        df: source frame; every name in ``columns`` must be a numeric column.
        columns: columns to summarise. Defaults to the four continuous adult
            columns (age, capital-gain, capital-loss, hours-per-week).

    Returns:
        ``{col: {"mean": ..., "std": ...}}`` in ``columns`` order. ``std`` is
        pandas' default sample standard deviation (``ddof=1``).
    """
    return {col: {"mean": df[col].mean(), "std": df[col].std()} for col in columns}
def convert_csv_to_arr(filename: str, mean_std: Dict[str, Dict[str, float]]) -> np.ndarray:
    """Load an adult CSV, clean and normalize it, and encode it with ``df_to_arr``.

    Steps:
      1. drop the "fnlwgt" and "education-num" columns;
      2. fill some missing ("?") values: rows with workclass "Never-worked"
         get occupation "None"; rows with occupation "?" are relabelled as
         ("Never-worked", "Student") when age < 24 or ("Retired", "Retired")
         when age > 60. Remaining "?" cells are left for ``df_to_arr``;
      3. z-score every column named in ``mean_std`` with the supplied
         statistics (the ``__main__`` block passes training-set statistics for
         both splits). A column whose std is 0 or not finite is only
         mean-centred, so it cannot turn into inf/NaN.

    Args:
        filename: path passed to ``load_csv``.
        mean_std: ``{col: {"mean": float, "std": float}}``.

    Returns:
        The float32 matrix produced by ``df_to_arr``.
    """
    df = load_csv(filename)
    # Not used as features.
    df = df.drop(columns=["fnlwgt", "education-num"])

    # Order matters: the Never-worked rule runs before the age rules, so rows
    # that the age < 24 rule re-labels as Never-worked keep "Student".
    # Age is compared before normalization below.
    df.loc[df["workclass"] == "Never-worked", "occupation"] = "None"
    young = (df["age"] < 24) & (df["occupation"] == "?")
    df.loc[young, ["workclass", "occupation"]] = ["Never-worked", "Student"]
    old = (df["age"] > 60) & (df["occupation"] == "?")
    df.loc[old, ["workclass", "occupation"]] = ["Retired", "Retired"]

    for col, stats in mean_std.items():
        mean = stats["mean"]
        std = stats["std"]
        if not np.isfinite(std) or std == 0:
            # A constant column (std 0) or one too short for a sample std
            # (NaN) would otherwise become all inf/NaN.
            std = 1.0
        df[col] = (df[col] - mean) / std

    return df_to_arr(df)
def split_feature(arr: np.ndarray, a_feature_size: int = 50) -> Tuple[np.ndarray, ...]:
    """Split an encoded matrix into part-A features, part-B features and labels.

    The last column of ``arr`` is treated as the label. The remaining feature
    columns are cut at ``a_feature_size``. The default 50 is the encoded width
    of the leading columns in the ``convert_csv_to_arr`` layout: age (1) +
    workclass (9) + education (16) + marital-status (7) + occupation (17).

    Args:
        arr: 2-D array whose last column is the label.
        a_feature_size: number of leading feature columns that go to part A.

    Returns:
        ``(a_feature, b_feature, label)``. ``label`` keeps a trailing axis of
        size 1 (shape ``(rows, 1)``).
    """
    features, label = arr[:, :-1], arr[:, -1:]
    return features[:, :a_feature_size], features[:, a_feature_size:], label
if __name__ == '__main__':
    # Output layout for each split ("train", "test"):
    #   dataset/adult.<split>.npz  - full encoded matrix
    #   dataset/a/<split>.npz      - first feature slice from split_feature
    #   dataset/b/<split>.npz      - remaining feature columns
    #   dataset/c/<split>.npz      - label column
    # Create each folder unconditionally: the previous exists("dataset")
    # guard skipped a/, b/ and c/ whenever dataset/ already existed, and
    # np.savez then failed.
    for party in ("a", "b", "c"):
        os.makedirs(os.path.join("dataset", party), exist_ok=True)

    # Normalization statistics come from the training split only and are
    # reused for the test split (they are not written to disk).
    mean_std = get_mean_std(load_csv("adult.train.csv"))

    for split in ("train", "test"):
        arr = convert_csv_to_arr("adult.{}.csv".format(split), mean_std)
        np.savez("dataset/adult.{}.npz".format(split), arr)
        a, b, label = split_feature(arr)
        for party, part in (("a", a), ("b", b), ("c", label)):
            np.savez("dataset/{}/{}.npz".format(party, split), part)