Augmenting Interpretable Models with LLMs during Training
This repo contains code to reproduce the experiments in the Aug-imodels paper (Nature Communications, 2023). For a simple scikit-learn-style interface to Aug-imodels, use the imodelsX library. Below is a quickstart example.
Installation: pip install imodelsx
from imodelsx import AugLinearClassifier, AugTreeClassifier, AugLinearRegressor, AugTreeRegressor
import datasets
import numpy as np
# set up data
dset = datasets.load_dataset('rotten_tomatoes')['train']
dset = dset.select(np.random.choice(len(dset), size=300, replace=False))
dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
dset_val = dset_val.select(np.random.choice(len(dset_val), size=300, replace=False))
# fit model
m = AugLinearClassifier(
    checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',
    ngrams=2,  # use bigrams
)
m.fit(dset['text'], dset['label'])
# predict
preds = m.predict(dset_val['text'])
print('acc_val', np.mean(preds == np.array(dset_val['label'])))  # convert labels to an array for an elementwise comparison
# interpret
print('Total ngram coefficients: ', len(m.coefs_dict_))
print('Most positive ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1], reverse=True)[:8]:
    print('\t', k, round(v, 2))
print('Most negative ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1])[:8]:
    print('\t', k, round(v, 2))
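The two ranking loops in the quickstart could be folded into one small helper. The following is a minimal sketch, assuming `coefs_dict_` maps ngram strings to float coefficients, as the quickstart suggests. The `top_ngrams` function and the toy dictionary are hypothetical and are not part of imodelsX.

```python
from typing import Dict, List, Tuple


def top_ngrams(
    coefs: Dict[str, float], k: int = 8, most_positive: bool = True
) -> List[Tuple[str, float]]:
    """Return the k ngrams with the largest (or smallest) coefficients."""
    # Sort descending for the most positive ngrams, ascending for the most negative
    ranked = sorted(coefs.items(), key=lambda item: item[1], reverse=most_positive)
    return ranked[:k]


# Toy stand-in for m.coefs_dict_ (values are made up for illustration only)
toy_coefs = {
    'great film': 1.2,
    'not good': -0.9,
    'the': 0.01,
    'boring': -1.4,
    'moving': 0.8,
}

print('Most positive ngrams')
for ngram, coef in top_ngrams(toy_coefs, k=2):
    print('\t', ngram, round(coef, 2))

print('Most negative ngrams')
for ngram, coef in top_ngrams(toy_coefs, k=2, most_positive=False):
    print('\t', ngram, round(coef, 2))
```

With a fitted model, the same call would presumably be `top_ngrams(m.coefs_dict_)`.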
Reference:
@misc{ch2022augmenting,
    title={Augmenting Interpretable Models with LLMs during Training},
    author={Chandan Singh and Armin Askari and Rich Caruana and Jianfeng Gao},
    year={2022},
    eprint={2209.11799},
    archivePrefix={arXiv},
    primaryClass={cs.AI}
}