| title | hide_title | status |
| --- | --- | --- |
| Text Explainers | true | stable |
In this example, we use the LIME and Kernel SHAP explainers to explain a text classification model.

First, we import the packages and define a UDF and a plotting function that we will need later.
```python
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.functions import vector_to_array
from synapse.ml.explainers import *
from synapse.ml.featurize.text import TextFeaturizer
from synapse.ml.core.platform import *

# UDF that extracts element i of a vector column as a float
vec_access = udf(lambda v, i: float(v[i]), FloatType())
```
Load the training data, and convert the rating to a binary label.
```python
data = (
    spark.read.parquet(
        "wasbs://[email protected]/BookReviewsFromAmazon10K.parquet"
    )
    .withColumn("label", (col("rating") > 3).cast(LongType()))
    .select("label", "text")
    .cache()
)

display(data)
```
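If you want to sanity-check the label conversion, a quick aggregation (an optional sketch, not part of the original walkthrough) shows the resulting class balance:

```python
# Count rows per binary label to see how balanced the classes are
display(data.groupBy("label").count())
```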
We train a text classification model and randomly sample 10 rows to explain.
```python
train, test = data.randomSplit([0.60, 0.40])

pipeline = Pipeline(
    stages=[
        TextFeaturizer(
            inputCol="text",
            outputCol="features",
            useStopWordsRemover=True,
            useIDF=True,
            minDocFreq=20,
            numFeatures=1 << 16,
        ),
        LogisticRegression(
            maxIter=100, regParam=0.005, labelCol="label", featuresCol="features"
        ),
    ]
)

model = pipeline.fit(train)
prediction = model.transform(test)
explain_instances = prediction.orderBy(rand()).limit(10)
```
```python
def plotConfusionMatrix(df, label, prediction, classLabels):
    from synapse.ml.plot import confusionMatrix
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(4.5, 4.5))
    confusionMatrix(df, label, prediction, classLabels)
    if running_on_synapse():
        plt.show()
    else:
        display(fig)


plotConfusionMatrix(model.transform(test), "label", "prediction", [0, 1])
```
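As an additional sanity check, you can score the model on the test split with Spark ML's built-in evaluator. This is an optional sketch, not part of the original walkthrough:

```python
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# Area under the ROC curve on the held-out test split
evaluator = BinaryClassificationEvaluator(
    labelCol="label", rawPredictionCol="rawPrediction", metricName="areaUnderROC"
)
print(f"Test AUC: {evaluator.evaluate(prediction):.3f}")
```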
First, we use the LIME text explainer to explain the model's predicted probability for each of the sampled observations.
```python
lime = TextLIME(
    model=model,
    outputCol="weights",
    inputCol="text",
    targetCol="probability",
    targetClasses=[1],
    tokensCol="tokens",
    samplingFraction=0.7,
    numSamples=2000,
)
```
```python
lime_results = (
    lime.transform(explain_instances)
    .select("tokens", "weights", "r2", "probability", "text")
    .withColumn("probability", vec_access("probability", lit(1)))
    # "weights" holds one weight vector per target class; take the vector for class 1
    .withColumn("weights", vector_to_array(col("weights").getItem(0)))
    .withColumn("r2", vec_access("r2", lit(0)))
    .withColumn("tokens_weights", arrays_zip("tokens", "weights"))
)

display(
    lime_results.select("probability", "r2", "tokens_weights", "text").orderBy(
        col("probability").desc()
    )
)
```
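To read the explanation for a single row, you can sort the zipped tokens by the magnitude of their weights. This is a hedged sketch, not part of the original walkthrough:

```python
# Show the 10 tokens with the largest absolute LIME weight for the first row
row = lime_results.select("tokens_weights").first()
top = sorted(row["tokens_weights"], key=lambda tw: -abs(tw["weights"]))[:10]
for tw in top:
    print(f"{tw['tokens']}: {tw['weights']:+.4f}")
```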
Next, we use the Kernel SHAP text explainer to explain the model's predicted probability for the same observations.

Notice that we drop the base value from the SHAP output before displaying the SHAP values; the base value is the model output for an empty string.
```python
shap = TextSHAP(
    model=model,
    outputCol="shaps",
    inputCol="text",
    targetCol="probability",
    targetClasses=[1],
    tokensCol="tokens",
    numSamples=5000,
)
```
```python
shap_results = (
    shap.transform(explain_instances)
    .select("tokens", "shaps", "r2", "probability", "text")
    .withColumn("probability", vec_access("probability", lit(1)))
    .withColumn("shaps", vector_to_array(col("shaps").getItem(0)))
    # Drop the first element (the base value), keeping one SHAP value per token
    .withColumn("shaps", slice(col("shaps"), lit(2), size(col("shaps"))))
    .withColumn("r2", vec_access("r2", lit(0)))
    .withColumn("tokens_shaps", arrays_zip("tokens", "shaps"))
)

display(
    shap_results.select("probability", "r2", "tokens_shaps", "text").orderBy(
        col("probability").desc()
    )
)
```
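To see where the dropped base value comes from, you can score an empty string with the same model; the predicted probability for class 1 should match the base value. This is a hedged sketch, not part of the original walkthrough:

```python
# The SHAP base value is the model output for an empty string
empty_df = spark.createDataFrame([("",)], ["text"])
display(
    model.transform(empty_df).select(
        vec_access("probability", lit(1)).alias("base_value")
    )
)
```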