-
Notifications
You must be signed in to change notification settings - Fork 1
/
tracking_pycaret.py
221 lines (175 loc) · 6.9 KB
/
tracking_pycaret.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
"""
About
Use MLflow and CrateDB to track the metrics, parameters, and outcomes of a
machine-learning experiment built with PyCaret. It uses the "Real-world sales
forecasting benchmark data" dataset from 4TU.ResearchData.
- https://github.com/crate/mlflow-cratedb
- https://mlflow.org/docs/latest/tracking.html
Usage
Before running the program, define the `MLFLOW_TRACKING_URI` environment
variable to record events and metrics either directly in the database,
or by submitting them to an MLflow Tracking Server.
# Use CrateDB database directly
export MLFLOW_TRACKING_URI="crate://crate@localhost/?schema=mlflow"
# Use MLflow Tracking Server
export MLFLOW_TRACKING_URI=http://127.0.0.1:5000
Optionally, set the `CRATEDB_HTTP_URL` environment variable to define the
CrateDB HTTP URL used to load and read the sales dataset. The default value is `http://crate@localhost:4200`.
# Use custom CrateDB HTTP URL to connect to a CrateDB Cloud instance
# Replace <username>, <password>, and <instance> with the values
# from your CrateDB Cloud instance
export CRATEDB_HTTP_URL="https://<username>:<password>@<instance>.aks1.westeurope.azure.cratedb.net:4200"
Resources
- https://mlflow.org/
- https://github.com/crate/crate
- https://github.com/pycaret/pycaret
- https://data.4tu.nl/articles/dataset/Real-world_sales_forecasting_benchmark_data_-_Extended_version/14406134/1
"""
import os
import time
import numpy as np
import pandas as pd
from crate import client
from mlflow import get_tracking_uri
from mlflow.models import infer_signature
from mlflow.sklearn import log_model
from pycaret.time_series import blend_models, compare_models, finalize_model, save_model, setup, tune_model
import mlflow_cratedb # noqa: F401
def connect_database():
    """
    Open a new connection to CrateDB and return it to the caller.

    The HTTP endpoint is read from the ``CRATEDB_HTTP_URL`` environment
    variable, falling back to ``http://crate@localhost:4200`` when unset.
    """
    fallback_url = "http://crate@localhost:4200"
    http_url = os.environ.get("CRATEDB_HTTP_URL", fallback_url)
    return client.connect(http_url)
def table_exists(table_name: str) -> bool:
    """
    Check if database table exists.

    Looks up ``table_name`` in ``information_schema.tables`` restricted to
    the connection's current schema.

    :param table_name: Name of the table to look for.
    :return: ``True`` if the lookup matched at least one row, else ``False``.
    """
    # Bind the name as a query parameter instead of interpolating it into the
    # SQL text, so an arbitrary name cannot alter the statement.
    sql = (
        "SELECT table_name FROM information_schema.tables "
        "WHERE table_name = ? AND table_schema = CURRENT_SCHEMA"
    )
    conn = connect_database()
    try:
        cursor = conn.cursor()
        try:
            cursor.execute(sql, (table_name,))
            return cursor.rowcount > 0
        finally:
            cursor.close()
    finally:
        # Release the connection even when the query raises.
        conn.close()
def import_data(data_table_name: str, chunk_size: int = 1000):
    """
    Download Real-world sales forecasting benchmark data, and load into database.

    Reads the target CSV (with quantities) and the related CSV (with unit
    prices) from 4TU.ResearchData and joins them on ``item``, ``org`` and
    ``date``. It then derives ``total_sales = unit_price * quantity``,
    creates ``data_table_name`` if it does not exist, and inserts the rows
    in batches.

    :param data_table_name: Name of the CrateDB table to create and fill.
        It is interpolated into the SQL as-is, so it must be a trusted
        identifier.
    :param chunk_size: Maximum number of rows sent per ``executemany`` call.
    :raises ValueError: If ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

    target_data = pd.read_csv(
        "https://data.4tu.nl/file/539debdb-a325-412d-b024-593f70cba15b/a801f5d4-5dfe-412a-ace2-a64f93ad0010"
    )
    related_data = pd.read_csv(
        "https://data.4tu.nl/file/539debdb-a325-412d-b024-593f70cba15b/f2bd27bd-deeb-4933-bed7-29325ee05c2e",
        header=None,
    )
    related_data.columns = ["item", "org", "date", "unit_price"]
    data = target_data.merge(related_data, on=["item", "org", "date"])
    data["total_sales"] = data["unit_price"] * data["quantity"]
    data["date"] = pd.to_datetime(data["date"])

    # The INSERT below binds values by position, so pin the frame's column
    # order to the statement's column list instead of relying on CSV layout.
    data = data[["item", "org", "date", "quantity", "unit_price", "total_sales"]]

    # Insert the data into CrateDB
    with connect_database() as conn:
        cursor = conn.cursor()
        try:
            # Create the table if it doesn't exist
            cursor.execute(
                f"""CREATE TABLE IF NOT EXISTS {data_table_name}
("item" TEXT,
"org" TEXT,
"date" TIMESTAMP,
"quantity" BIGINT,
"unit_price" DOUBLE PRECISION,
"total_sales" DOUBLE PRECISION)"""
            )
            # Insert the data in fixed-size batches for better insert performance.
            # Slicing by offset also handles fewer than ``chunk_size`` rows;
            # the former ``np.array_split(data, int(len(data) / chunk_size))``
            # asked for zero sections in that case and raised ValueError.
            for start in range(0, len(data), chunk_size):
                chunk = data.iloc[start : start + chunk_size]
                cursor.executemany(
                    f"""INSERT INTO {data_table_name}
(item, org, date, quantity, unit_price, total_sales)
VALUES (?, ?, ?, ?, ?, ?)""",  # noqa: S608
                    list(chunk.itertuples(index=False, name=None)),
                )
        finally:
            cursor.close()
def refresh_table(table_name: str):
    """
    Refresh ``table_name`` so that rows written earlier are visible to reads.

    Runs CrateDB's ``REFRESH TABLE`` statement over a new connection, which
    flushes/synchronizes pending write operations.
    https://cratedb.com/docs/crate/reference/en/latest/sql/statements/refresh.html

    :param table_name: Table to refresh. It is interpolated into the SQL
        as-is, so it must be a trusted identifier.
    """
    statement = f"REFRESH TABLE {table_name}"
    with connect_database() as connection:
        cur = connection.cursor()
        cur.execute(statement)
        cur.close()
def read_data(table_name: str) -> pd.DataFrame:
    """
    Fetch monthly sales totals from ``table_name`` as a pandas DataFrame.

    The database sums ``total_sales`` per calendar month
    (``DATE_TRUNC('month', date)``). The ``month`` column is then converted
    from epoch milliseconds to datetimes, and the frame is sorted by it.

    :param table_name: Table to aggregate. It is interpolated into the SQL
        as-is, so it must be a trusted identifier.
    :return: DataFrame with ``month`` and ``total_sales`` columns.
    """
    query = f"""
SELECT
DATE_TRUNC('month', date) as month,
SUM(total_sales) AS total_sales
FROM {table_name}
GROUP BY month
ORDER BY month
"""
    with connect_database() as connection:
        monthly_sales = pd.read_sql(query, connection)

    monthly_sales["month"] = pd.to_datetime(monthly_sales["month"], unit="ms")
    # Sorting in place avoids allocating a second copy of the frame.
    monthly_sales.sort_values(by=["month"], inplace=True)  # noqa: PD002
    return monthly_sales
def run_experiment(data: pd.DataFrame):
    """
    Run experiment on DataFrame, using PyCaret. Track it using MLflow.

    PyCaret performs the MLflow experiment tracking itself
    (``log_experiment=True``). The function proceeds as follows:

    1. Tune and blend the three best models by MASE.
    2. Finalize the blend and save it to ``model/crate-salesforecast``.
    3. Unless the MLflow tracking URI uses the ``file://`` scheme, log the
       model to MLflow and register it under a timestamped name.

    :param data: DataFrame with a ``month`` index column and a
        ``total_sales`` target column.
    """
    # Create a blend of the 3 models that perform best on the MASE metric.
    setup(data, fh=15, target="total_sales", index="month", log_experiment=True, verbose=False)
    best3 = compare_models(sort="MASE", n_select=3)
    tuned_models = [tune_model(model) for model in best3]
    blended = blend_models(estimator_list=tuned_models, optimize="MASE")
    best_model = finalize_model(blended)

    # Save the model to disk. exist_ok replaces the racy exists-then-create check.
    os.makedirs("model", exist_ok=True)
    save_model(best_model, "model/crate-salesforecast")

    # Register the model with MLflow for any tracking URI that is not a local
    # ``file://`` store. This covers a tracking server as well as a direct
    # database URI such as ``crate://...``.
    if not get_tracking_uri().startswith("file://"):
        # Timestamp makes the registered model name unique per run.
        timestamp = int(time.time())
        y_pred = best_model.predict()
        signature = infer_signature(None, y_pred)
        log_model(
            sk_model=best_model,
            artifact_path="crate-salesforecast",
            signature=signature,
            registered_model_name=f"crate-salesforecast-model-{timestamp}",
        )
    else:
        print(  # noqa: T201
            "MLFLOW_TRACKING_URI is not set to a tracking server, so the model will not be registered with mlflow"
        )
def main():
    """
    Entry point: ensure the sales dataset is in CrateDB, then run the experiment.
    """
    # Table that holds the raw sales records.
    table = "sales_data_for_forecast"

    # Download and load the dataset only when the table is not there yet.
    needs_import = not table_exists(table)
    if needs_import:
        import_data(table)
        # Make the freshly inserted rows visible before reading them back.
        refresh_table(table)

    # Aggregate the data into a DataFrame and hand it to the experiment.
    run_experiment(read_data(table))


if __name__ == "__main__":
    main()