Interpretable and efficient predictors using pre-trained language models. Scikit-learn compatible.

microsoft/augmented-interpretable-models

Augmenting Interpretable Models with LLMs during Training

This repo contains code to reproduce the experiments in the Aug-imodels paper (Nature Communications, 2023). For a simple scikit-learn interface to use Aug-imodels, use the imodelsX library. Below is a quickstart example.

Installation: pip install imodelsx

from imodelsx import AugLinearClassifier, AugTreeClassifier, AugLinearRegressor, AugTreeRegressor
import datasets
import numpy as np

# set up data: subsample 300 examples per split to keep the example fast
rng = np.random.default_rng(0)  # fixed seed so the subsample is reproducible
dset = datasets.load_dataset('rotten_tomatoes')['train']
dset = dset.select(rng.choice(len(dset), size=300, replace=False))
dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
dset_val = dset_val.select(rng.choice(len(dset_val), size=300, replace=False))

# fit model
m = AugLinearClassifier(
    checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',
    ngrams=2, # use bigrams
)
m.fit(dset['text'], dset['label'])

# predict (convert both sides to arrays so the comparison is elementwise)
preds = m.predict(dset_val['text'])
print('acc_val', np.mean(np.array(preds) == np.array(dset_val['label'])))

# interpret
print('Total ngram coefficients: ', len(m.coefs_dict_))
print('Most positive ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1], reverse=True)[:8]:
    print('\t', k, round(v, 2))
print('Most negative ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1])[:8]:
    print('\t', k, round(v, 2))
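
The interpretation step above only relies on `coefs_dict_` behaving like a plain mapping from ngram to coefficient, which is how the quickstart uses it. As a minimal sketch, the same ranking can be wrapped in a small reusable helper. Note that `top_ngrams` and the toy coefficients below are hypothetical, invented for illustration; they are not part of imodelsX and not model output.

```python
import heapq


def top_ngrams(coefs, k=8, most_positive=True):
    """Return the k (ngram, coefficient) pairs with the largest
    (or smallest, if most_positive=False) coefficients, best first."""
    pick = heapq.nlargest if most_positive else heapq.nsmallest
    return pick(k, coefs.items(), key=lambda item: item[1])


# Toy coefficients, made up for illustration (stand-in for m.coefs_dict_)
toy_coefs = {'great film': 1.2, 'dull': -0.9, 'moving': 0.7, 'waste of': -1.4, 'the': 0.0}

print('Most positive ngrams')
for ngram, coef in top_ngrams(toy_coefs, k=2):
    print('\t', ngram, round(coef, 2))
print('Most negative ngrams')
for ngram, coef in top_ngrams(toy_coefs, k=2, most_positive=False):
    print('\t', ngram, round(coef, 2))
```

With a fitted model, passing `m.coefs_dict_` in place of `toy_coefs` should give the same listing as the loops in the quickstart, assuming the attribute is a regular dict of floats. Using `heapq` avoids sorting the full dictionary when only a few ngrams are needed.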

Reference:

@misc{ch2022augmenting,
    title={Augmenting Interpretable Models with LLMs during Training},
    author={Chandan Singh and Armin Askari and Rich Caruana and Jianfeng Gao},
    year={2022},
    eprint={2209.11799},
    archivePrefix={arXiv},
    primaryClass={cs.AI}
}