-
Notifications
You must be signed in to change notification settings - Fork 2
/
split_data.py
30 lines (26 loc) · 1.01 KB
/
split_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import orchest
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import copy
# Pipeline step: split the incoming dataset into train and test sets and pass
# them, together with the target column name(s), on to the next step.

# Retrieve the data from the previous step.
# The payload named "data" is unpacked as a (DataFrame, target columns) pair.
# NOTE(review): presumably `targets` is a list of column names -- confirm
# against the upstream step.
inputs = orchest.get_inputs()
data, targets = inputs["data"]
# Wrap in a 1-tuple so a tuple-valued `targets` is printed as-is instead of
# being consumed as %-format arguments.
print('Input target = %s' % (targets,))

# Fail early with a clear message instead of an opaque IndexError below.
if len(targets) == 0:
    raise ValueError("No target columns were provided by the previous step.")

# Print messages are useful when you are keeping an eye on the logs of
# a pipeline step.
print("Splitting the data into train and test...")

# Infer the problem type from the target columns: if any target has a
# floating-point dtype (float16/32/64, not just float64) it is treated as
# continuous and the split is purely random. Otherwise the split is stratified
# on the target column(s) so label proportions carry over to both partitions.
is_regression = any(
    pd.api.types.is_float_dtype(data[column].dtype) for column in targets
)
if is_regression:
    train, test = train_test_split(data, test_size=0.2, random_state=9)
else:
    train, test = train_test_split(
        data, test_size=0.2, random_state=9, stratify=data[targets]
    )
print(train.shape, test.shape)

# The output target is a bare value when there is exactly one label;
# otherwise an independent copy of the full collection is passed on.
if len(targets) == 1:
    target = targets[0]
else:
    target = copy.deepcopy(targets)
print('Output target = %s' % (target,))

# Publish the result under the name "split_data" for downstream steps.
orchest.output((train, test, target), name="split_data")
print("Success!")