
ExplainableAI.jl



Explainable AI in Julia.

This package implements interpretability methods for black-box classifiers, with an emphasis on local explanations and attribution maps in input space. The only requirement for the model is that it is differentiable¹. It is similar to Captum and Zennit for PyTorch and to iNNvestigate for Keras models.

Installation

This package supports Julia ≥1.10. To install it, open the Julia REPL and run

julia> ]add ExplainableAI

Example

Let's explain why an image of a castle is classified as such by a vision model:

using ExplainableAI
using VisionHeatmaps         # visualization of explanations as heatmaps
using Zygote                 # load autodiff backend for gradient-based methods
using Flux, Metalhead        # pre-trained vision models in Flux
using DataAugmentation       # input preprocessing
using HTTP, FileIO, ImageIO  # load image from URL
using ImageInTerminal        # show heatmap in terminal

# Load & prepare model
model = VGG(16, pretrain=true)

# Load input
url = HTTP.URI("https://raw.githubusercontent.com/Julia-XAI/ExplainableAI.jl/gh-pages/assets/heatmaps/castle.jpg")
img = load(url) 

# Preprocess input
mean = (0.485f0, 0.456f0, 0.406f0)
std  = (0.229f0, 0.224f0, 0.225f0)
tfm = CenterResizeCrop((224, 224)) |> ImageToTensor() |> Normalize(mean, std)
input = apply(tfm, Image(img))               # apply DataAugmentation transform
input = reshape(input.data, 224, 224, 3, :)  # unpack data and add batch dimension

# Run XAI method
analyzer = SmoothGrad(model)
expl = analyze(input, analyzer)  # or: expl = analyzer(input)
heatmap(expl)                    # show heatmap using VisionHeatmaps.jl
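
The call returns an `Explanation` struct defined in XAIBase.jl. As a quick sketch (field names follow the XAIBase.jl documentation; `val` holds the raw attribution scores):

expl.val       # attribution scores, same spatial layout as the input
expl.output    # model output for the analyzed input
expl.analyzer  # Symbol identifying the method, e.g. :SmoothGrad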

By default, explanations are computed for the class with the highest activation. We can also compute explanations for a specific class, e.g. the one at output index 5:

analyze(input, analyzer, 5)  # for explanation 
heatmap(input, analyzer, 5)  # for heatmap
[Table of heatmaps for the classes "castle" and "street sign", produced by the analyzers InputTimesGradient, Gradient, SmoothGrad, and IntegratedGradients.]

Tip

The heatmaps shown above were created using a VGG-16 vision model from Metalhead.jl that was pre-trained on the ImageNet dataset.

Since ExplainableAI.jl can be used outside of Deep Learning models and Flux.jl, we have omitted specific models and inputs from the code snippet above. The full code used to generate the heatmaps can be found here.

Depending on the method, the applied heatmapping defaults differ: sensitivity-based methods (e.g. Gradient) default to a grayscale color scheme, whereas attribution-based methods (e.g. InputTimesGradient) default to a red-white-blue color scheme. Red color indicates regions of positive relevance towards the selected class, whereas regions in blue are of negative relevance. More information on heatmapping presets can be found in the Julia-XAI documentation.
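
As a minimal sketch of these defaults, reusing the model and input from the example above (the color scheme is picked automatically by VisionHeatmaps.jl based on the analyzer):

heatmap(analyze(input, Gradient(model)))            # sensitivity-based: grayscale by default
heatmap(analyze(input, InputTimesGradient(model)))  # attribution-based: red-white-blue by default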

Warning

ExplainableAI.jl used to contain Layer-wise Relevance Propagation (LRP). Since version v0.7.0, LRP is available as a separate package in the Julia-XAI ecosystem: RelevancePropagation.jl.

[Table of LRP heatmaps for the classes "castle" and "street sign", using the composites EpsilonPlus, EpsilonPlusFlat, EpsilonAlpha2Beta1, EpsilonAlpha2Beta1Flat, and EpsilonGammaBox, as well as the plain ZeroRule (discouraged).]
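
If you need LRP, a minimal sketch using RelevancePropagation.jl could look as follows. The composite presets are named after the table above; consult the RelevancePropagation.jl documentation for their exact constructor arguments:

using RelevancePropagation

composite = EpsilonPlus()          # composite preset, assumed here to need no arguments
analyzer  = LRP(model, composite)  # LRP analyzer wrapping the Flux model
heatmap(analyze(input, analyzer))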

Video Demonstration

Check out our talk at JuliaCon 2022 for a demonstration of the package.

Methods

Currently, the following analyzers are implemented:

  • Gradient
  • InputTimesGradient
  • SmoothGrad
  • IntegratedGradients
  • GradCAM
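
All of these follow the same construct-then-analyze pattern shown in the example above. As a sketch (GradCAM additionally expects the model's feature and adaptation layers, e.g. GradCAM(model.layers[1], model.layers[2]) for the Metalhead VGG; check the API reference for the exact arguments):

for analyzer in (Gradient(model), InputTimesGradient(model), SmoothGrad(model), IntegratedGradients(model))
    expl = analyze(input, analyzer)
    heatmap(expl)
end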

One of the design goals of the Julia-XAI ecosystem is extensibility. To implement an XAI method, take a look at the common interface defined in XAIBase.jl.
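
For illustration, here is a sketch of a custom analyzer under that interface. The abstract types and the `Explanation` constructor arguments are assumptions based on the XAIBase.jl documentation and may differ between versions:

using XAIBase

# A deliberately useless analyzer returning random attributions,
# just to show the shape of the interface.
struct RandomAnalyzer{M} <: AbstractXAIMethod
    model::M
end

function (method::RandomAnalyzer)(input, output_selector::AbstractOutputSelector)
    output = method.model(input)
    output_selection = output_selector(output)
    val = rand(size(input)...)  # the "explanation" in input space
    return Explanation(val, input, output, output_selection, :RandomAnalyzer, :sensitivity, nothing)
end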

Roadmap

We would like to support additional XAI methods in the future.

Contributions are welcome!

Acknowledgements

Adrian Hill acknowledges support by the Federal Ministry of Education and Research (BMBF) for the Berlin Institute for the Foundations of Learning and Data (BIFOLD) (01IS18037A).

Footnotes

  1. The automatic differentiation backend can be selected using ADTypes.jl.
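
     For example, with Zygote (how the backend is passed to the analyzer is version-dependent; the `backend` keyword below is an assumption, so check the documentation):

     using ADTypes: AutoZygote
     analyzer = Gradient(model; backend=AutoZygote())  # hypothetical keyword argument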