-
Notifications
You must be signed in to change notification settings - Fork 29
/
app.py
89 lines (76 loc) · 2.79 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# local
import content
# 3rd party
import streamlit as st
import numpy as np
from dtreeviz.trees import *
from sklearn.datasets import load_boston
from sklearn.tree import DecisionTreeRegressor
# built-in
import base64
@st.cache()
def load_data():
    """Load the Boston housing dataset used to train the demo tree.

    Returns:
        A ``(X, y, feature_names)`` tuple taken from
        ``sklearn.datasets.load_boston()``: the feature matrix, the target
        multiplied by 10,000, and the feature names.
    """
    dataset = load_boston()
    # The scaled target is what the app later formats as a "$" amount.
    # NOTE(review): if the raw target is in $1000s, x1_000 would yield dollars;
    # confirm the x10_000 factor is intentional.
    prices = dataset.target * 10_000
    return dataset.data, prices, dataset.feature_names
@st.cache()
def fit_dtree(X, y):
    """Fit and return a ``DecisionTreeRegressor`` limited to depth 2.

    Training happens inside the app, and the result is cached by Streamlit.
    With bigger data, or a more complex model, you'd probably want to train
    offline and load the fitted model into the app instead.
    """
    # ``fit`` returns the estimator itself, so construction and training chain.
    return DecisionTreeRegressor(max_depth=2).fit(X, y)
def render_svg(svg):
    """Display an SVG string on the Streamlit page.

    The SVG text is base64-encoded into a ``data:`` URI and wrapped in an
    ``<img>`` tag, which is written with HTML rendering enabled.

    Adapted from https://gist.github.com/treuille/8b9cbfec270f7cda44c5fc398361b3b1
    """
    svg_bytes = svg.encode('utf-8')
    encoded = base64.b64encode(svg_bytes).decode("utf-8")
    img_tag = f'<img src="data:image/svg+xml;base64,{encoded}"/>'
    st.write(img_tag, unsafe_allow_html=True)
if __name__ == "__main__":
    # --- sidebar: platform logos ---
    st.sidebar.title("Built on:")
    # st.sidebar.image("images/streamlit.png")
    st.sidebar.image("images/aws.png")
    st.sidebar.image("images/docker.png")

    # --- explanatory text (strings come from the local `content` module) ---
    st.title("Explainable ML")
    st.markdown(content.intro)
    st.markdown(content.model_explanation)
    st.markdown(content.slider_explanation)

    # --- data and model ---
    X, y, feature_names = load_data()
    dtree = fit_dtree(X, y)

    # --- sliders for a new prediction ---
    rm = st.slider("RM: average number of rooms per dwelling.",
                   min_value=3.6,
                   max_value=8.7,
                   value=6.0,
                   step=.1)
    lstat = st.slider("LSTAT: percentage of the population denoted as lower status.",
                      min_value=2.0,
                      max_value=37.0,
                      value=14.0,
                      step=.1)

    # Build one observation with every feature at 0 except RM and LSTAT.
    # Column positions are looked up by name instead of being hard-coded.
    # NOTE(review): zeroing the other features presumably relies on the
    # depth-2 tree splitting only on RM/LSTAT -- re-check if that changes.
    names = list(feature_names)
    new_observation = np.zeros(len(names))
    new_observation[names.index("RM")] = rm
    new_observation[names.index("LSTAT")] = lstat

    # --- visualize the prediction path ---
    viz = dtreeviz(dtree,
                   X,
                   y,
                   target_name='price',
                   orientation='LR',  # left-right orientation
                   feature_names=feature_names,
                   X=new_observation)  # single observation to trace through the tree
    svg_path = "images/prediction_path.svg"
    viz.save(svg_path)
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default text encoding.
    with open(svg_path, "r", encoding="utf-8") as f:
        svg = f.read()
    render_svg(svg)

    prediction = dtree.predict(new_observation.reshape(1, -1)).item()
    st.markdown(f"""According to the model, a house with {round(rm, 1)} rooms located in a neighborhood that is {lstat/100:.1%} lower status
should be valued at approximately ${prediction:,.0f}.""")