Code to run the experiments of the NeurIPS 2022 paper On the Symmetries of Deep Learning Models and their Internal Representations.
This repository is currently organized into a module model_symmetries with submodules stitching and alignment, corresponding to sections 4 and 5 of the paper (for the network dissection results of section 6 we used the implementation at https://github.com/CSAILVision/NetDissect-Lite).
In addition, there are some submodules containing code shared across stitching and alignment, namely:
- models.py, datasets.py, train.py and plotting.py (self-explanatory);
- zoo.py: utilities to train many models from independent random seeds;
- constants.py: specifies a directory in which to store data/models/results, by defining the variable data_dir.
The key classes for stitching layers and stitched models are in stitching.py. In particular, we direct attention towards the Birkhoff class, which implements our approach using projected gradient descent (PGD) on the Birkhoff polytope of doubly stochastic matrices.
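For readers unfamiliar with PGD on the Birkhoff polytope, here is a minimal NumPy sketch of the general idea; it is not the code in stitching.py. After each gradient step, the matrix is projected back onto the set of doubly stochastic matrices (nonnegative entries, all rows and columns summing to 1). The projection below uses Dykstra's alternating-projection algorithm, which is our assumption; the Birkhoff class may compute it differently. The helpers project_affine, project_birkhoff and pgd_step are hypothetical names.

```python
import numpy as np


def project_affine(Y):
    """Euclidean projection onto {X : every row and column sums to 1}."""
    n = Y.shape[0]
    r = Y.sum(axis=1) - 1.0  # row-sum residuals
    c = Y.sum(axis=0) - 1.0  # column-sum residuals
    s = r.sum()              # total residual (equals c.sum())
    return Y - r[:, None] / n - c[None, :] / n + s / n**2


def project_birkhoff(Y, n_iter=5000, tol=1e-10):
    """Approximate Euclidean projection onto the Birkhoff polytope.

    Hypothetical helper: Dykstra's algorithm alternating between the affine
    row/column-sum constraints and the nonnegative orthant.
    """
    X = np.array(Y, dtype=float)
    q = np.zeros_like(X)  # Dykstra correction term for the orthant step
    for _ in range(n_iter):
        Z = project_affine(X)
        X_new = np.maximum(Z + q, 0.0)
        q = Z + q - X_new
        converged = np.abs(X_new - X).max() < tol
        X = X_new
        if converged:
            break
    return X


def pgd_step(X, grad, lr):
    """One projected gradient descent step on the Birkhoff polytope (hypothetical)."""
    return project_birkhoff(X - lr * grad)


# Usage: minimize 0.5 * ||X - T||_F^2 for a permutation target T,
# starting from the uniform doubly stochastic matrix.
n = 4
T = np.eye(n)[[2, 0, 3, 1]]
X = np.full((n, n), 1.0 / n)
for _ in range(50):
    X = pgd_step(X, grad=X - T, lr=0.5)
```

Permutation matrices are the vertices of the Birkhoff polytope, which is why optimizing over this convex set is a natural relaxation of searching over permutations.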
train.py has more options than is typical, due to a few major implementation considerations:
- The need to make sure that, when stitching, we only update parameters of the stitching layer.
- The overhead of PGD and extra $\ell_2$ regularization.
- The necessity of a no-grad training epoch before validation.
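To make the first point concrete, here is a framework-free NumPy sketch of training only a stitching layer: the "body" of one network and the "head" of another stay frozen, and gradient descent updates only the stitching matrix S. Everything here (W_a, V_b, the synthetic data, the learning rate) is hypothetical and is not taken from train.py. In a framework such as PyTorch the same effect is usually obtained by disabling gradients on the pretrained parameters and passing only the stitching layer's parameters to the optimizer.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical frozen pieces of two trained networks:
# the "body" of network A (input -> features) and the "head" of network B.
W_a = rng.normal(size=(6, 8)) / np.sqrt(8)  # frozen
V_b = rng.normal(size=(4, 6)) / np.sqrt(6)  # frozen


def body_a(x):
    return np.maximum(W_a @ x, 0.0)  # ReLU features, shape (6, batch)


def head_b(z):
    return V_b @ z  # outputs, shape (4, batch)


# Synthetic data whose targets come from an unknown "ideal" stitching matrix.
X = rng.normal(size=(8, 256))
S_true = rng.normal(size=(6, 6))
Y = head_b(S_true @ body_a(X))


def loss(S):
    residual = head_b(S @ body_a(X)) - Y
    return 0.5 * np.mean(np.sum(residual**2, axis=0))


# Only the stitching layer S is trainable; W_a and V_b are never written to.
S = np.eye(6)
W_a_before, V_b_before = W_a.copy(), V_b.copy()
initial_loss = loss(S)
H = body_a(X)  # features of the frozen body can be computed once
lr = 0.05
for _ in range(200):
    residual = V_b @ (S @ H) - Y
    grad_S = V_b.T @ residual @ H.T / X.shape[1]  # gradient of loss w.r.t. S
    S -= lr * grad_S
final_loss = loss(S)
```

Because the frozen body's output does not change during stitching, its features can be cached, which is one reason stitching is much cheaper than training either network.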
The main experiment script is cifar10_stitching.py. This also has many options, due to the number of combinations of model and stitching layer types we consider.
In order to run the experiments stitching Compact Convolutional Transformers, you will need https://github.com/SHI-Labs/Compact-Transformers, which is included as a Git submodule of this repository at model_symmetries/ct. To initialize and update it, run
git submodule init && git submodule update
Core functions are located in alignment.py. The wreath_{procrustes,cka} (the group
plotting.py contains functions for displaying stitching penalties and dissimilarity metrics, which can be run in the notebook plotting.ipynb.
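The wreath_{procrustes,cka} names in alignment.py suggest variants of the orthogonal Procrustes distance and centered kernel alignment (CKA). As a point of reference only, here is a minimal NumPy sketch of the standard (non-wreath) linear CKA and orthogonal Procrustes distance between two representation matrices of shape (examples, features). The functions linear_cka and procrustes_distance are hypothetical and are not the implementations in alignment.py.

```python
import numpy as np


def _center(X):
    """Subtract the per-feature mean (X has shape (n_examples, n_features))."""
    return X - X.mean(axis=0, keepdims=True)


def linear_cka(X, Y):
    """Linear CKA between representations X (n, d1) and Y (n, d2)."""
    X, Y = _center(X), _center(Y)
    cross = np.linalg.norm(Y.T @ X, "fro") ** 2
    return cross / (np.linalg.norm(X.T @ X, "fro") * np.linalg.norm(Y.T @ Y, "fro"))


def procrustes_distance(X, Y):
    """Orthogonal Procrustes distance between X and Y of the same shape."""
    X, Y = _center(X), _center(Y)
    X = X / np.linalg.norm(X, "fro")
    Y = Y / np.linalg.norm(Y, "fro")
    # min over orthogonal R of ||X - Y R||_F^2 equals 2 - 2 * ||Y^T X||_* (nuclear norm)
    nuclear = np.linalg.norm(Y.T @ X, "nuc")
    return float(np.sqrt(max(2.0 - 2.0 * nuclear, 0.0)))


rng = np.random.default_rng(0)
A = rng.normal(size=(100, 10))
Q, _ = np.linalg.qr(rng.normal(size=(10, 10)))  # random orthogonal matrix
P = np.eye(10)[rng.permutation(10)]             # random permutation matrix
B = rng.normal(size=(100, 10))                  # unrelated representation
```

Both quantities are unchanged when one representation is multiplied on the right by an orthogonal matrix, and in particular by a permutation matrix, so linear_cka(A, A @ Q) is 1 and procrustes_distance(A, A @ P) is 0 up to floating-point error, while the unrelated B scores noticeably worse on both.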
We ran these experiments on a cluster managed by SLURM; files ending in .slurm are SLURM batch files. To distribute the many sweeps in these experiments across nodes of the cluster, we submitted batches to the queue using loops found in the bash scripts (files ending in .sh). WARNING: executing these scripts will consume many GPU days.
If you find this code useful, please cite our paper.
@article{modelsyms2022,
doi = {10.48550/ARXIV.2205.14258},
url = {https://arxiv.org/abs/2205.14258},
author = {Godfrey, Charles and Brown, Davis and Emerson, Tegan and Kvinge, Henry},
keywords = {Machine Learning (cs.LG), Artificial Intelligence (cs.AI), FOS: Computer and information sciences, FOS: Computer and information sciences},
title = {On the Symmetries of Deep Learning Models and their Internal Representations},
publisher = {arXiv},
year = {2022},
copyright = {Creative Commons Attribution 4.0 International}
}
This research was supported by the Mathematics for Artificial Reasoning in Science (MARS) initiative at Pacific Northwest National Laboratory. It was conducted under the Laboratory Directed Research and Development (LDRD) Program at Pacific Northwest National Laboratory (PNNL), a multiprogram National Laboratory operated by Battelle Memorial Institute for the U.S. Department of Energy under Contract DE-AC05-76RL01830.