Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

WIP / experimental: expose model internals to predictor #2211

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions allennlp/common/introspection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import inspect
from functools import wraps

# A global flag that indicates whether function results
# should be stored as part of the containing model.
_FLAGS = {'STORE_FUNCTION_RESULTS': False}


def store_function_results(flag: bool) -> None:
    """
    Switch the recording done by ``store_result`` on or off.

    Parameters
    ----------
    flag : ``bool``
        ``True`` to make decorated functions record their results,
        ``False`` to make the decorator a plain pass-through.
    """
    _FLAGS['STORE_FUNCTION_RESULTS'] = flag


def store_result(func):
    """
    Decorator for storing function results with the containing model.

    The wrapped function is always called and its result is always returned
    unchanged.  When the global flag is set (see ``store_function_results``),
    the wrapper additionally looks at the frame that called the function: if
    that frame has a local named ``self`` which is an instance of
    ``allennlp.models.model.Model``, then ``[func.__name__, result]`` is
    appended to that object's ``_stored_function_results`` list (the list is
    created on first use).

    Calls made from a frame without a ``self`` local (plain functions, module
    level code) or from a ``self`` that is not a ``Model`` are not recorded.
    """
    @wraps(func)
    def inner(*args, **kwargs):
        result = func(*args, **kwargs)

        if not _FLAGS['STORE_FUNCTION_RESULTS']:
            return result

        # Only the immediate caller matters, so walk one frame back instead of
        # materialising the whole stack with ``inspect.stack()``.
        frame = inspect.currentframe()
        caller = frame.f_back if frame is not None else None
        try:
            # ``.get`` rather than indexing: the caller need not be a method.
            calling_module = caller.f_locals.get('self') if caller is not None else None
        finally:
            # Drop the frame references promptly to avoid reference cycles.
            del frame, caller

        if calling_module is None:
            return result

        # Import here to avoid circularity
        from allennlp.models.model import Model

        if isinstance(calling_module, Model):
            stored = getattr(calling_module, '_stored_function_results', None)
            if stored is None:
                stored = []
                setattr(calling_module, '_stored_function_results', stored)
            stored.append([func.__name__, result])

        return result

    return inner
3 changes: 2 additions & 1 deletion allennlp/nn/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import torch

from allennlp.common.checks import ConfigurationError
from allennlp.common.introspection import store_result

logger = logging.getLogger(__name__) # pylint: disable=invalid-name

Expand Down Expand Up @@ -213,7 +214,7 @@ def get_dropout_mask(dropout_probability: float, tensor_for_masking: torch.Tenso
dropout_mask = binary_mask.float().div(1.0 - dropout_probability)
return dropout_mask


@store_result
def masked_softmax(vector: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
"""
``torch.nn.functional.softmax(vector)`` does not work if some elements of ``vector`` should be
Expand Down
54 changes: 49 additions & 5 deletions allennlp/predictors/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from allennlp.common import Registrable
from allennlp.common.checks import ConfigurationError
from allennlp.common.introspection import store_function_results
from allennlp.common.util import JsonDict, sanitize
from allennlp.data import DatasetReader, Instance
from allennlp.models import Model
Expand Down Expand Up @@ -32,9 +33,13 @@ class Predictor(Registrable):
a ``Predictor`` is a thin wrapper around an AllenNLP model that handles JSON -> JSON predictions
that can be used for serving models through the web API or making predictions in bulk.
"""
def __init__(self, model: Model, dataset_reader: DatasetReader) -> None:
def __init__(self,
             model: Model,
             dataset_reader: DatasetReader,
             return_model_internals: bool = False) -> None:
    """
    Parameters
    ----------
    model : ``Model``
        Stored as ``self._model``.
    dataset_reader : ``DatasetReader``
        Stored as ``self._dataset_reader``.
    return_model_internals : ``bool``, optional (default = False)
        Stored as ``self._return_model_internals``.
    """
    self._return_model_internals = return_model_internals
    self._dataset_reader = dataset_reader
    self._model = model

def load_line(self, line: str) -> JsonDict: # pylint: disable=no-self-use
"""
Expand All @@ -55,7 +60,40 @@ def predict_json(self, inputs: JsonDict) -> JsonDict:
return self.predict_instance(instance)

def predict_instance(self, instance: Instance) -> JsonDict:
    """
    Run the model on a single ``Instance`` and return the sanitized outputs.

    When ``self._return_model_internals`` is set, two extra keys may be added
    to the outputs before sanitizing:

    ``_internal_module_results``
        A dict filled by forward hooks registered on every module returned by
        ``self._model.modules()`` except the model itself.  Keys are the
        module's position in that enumeration; values are
        ``{"name": str(module), "output": <hook outputs>}``.
    ``_internal_function_results``
        A copy of whatever accumulated in the model's
        ``_stored_function_results`` attribute during the forward pass.

    Each key is only added when there is something to report.  The hooks, the
    global function-result flag and the stored function results are always
    cleaned up, even if the forward pass raises.
    """
    if not self._return_model_internals:
        return sanitize(self._model.forward_on_instance(instance))

    internal_module_results = {}

    def add_output(idx: int):
        # Bind ``idx`` per module so every hook writes to its own slot.
        def _add_output(mod, _, outputs):
            internal_module_results[idx] = {"name": str(mod), "output": outputs}
        return _add_output

    hooks = []
    store_function_results(True)
    try:
        for i, module in enumerate(self._model.modules()):
            # Identity check: skip the top-level model, hook everything else.
            if module is not self._model:
                hooks.append(module.register_forward_hook(add_output(i)))

        outputs = self._model.forward_on_instance(instance)
    finally:
        # Undo every piece of temporary state regardless of success, so a
        # failed prediction cannot leak hooks, the global flag, or stale
        # function results into later calls.
        store_function_results(False)
        for hook in hooks:
            hook.remove()
        stored_results = getattr(self._model, '_stored_function_results', [])
        internal_function_results = stored_results[:]
        stored_results.clear()

    # Collect results of modules
    if internal_module_results:
        outputs['_internal_module_results'] = internal_module_results

    # Collect results of functions
    if internal_function_results:
        outputs['_internal_function_results'] = internal_function_results

    return sanitize(outputs)

def _json_to_instance(self, json_dict: JsonDict) -> Instance:
Expand Down Expand Up @@ -89,7 +127,10 @@ def _batch_json_to_instances(self, json_dicts: List[JsonDict]) -> List[Instance]
return instances

@classmethod
def from_path(cls, archive_path: str, predictor_name: str = None) -> 'Predictor':
def from_path(cls,
archive_path: str,
predictor_name: str = None,
return_model_internals: bool = False) -> 'Predictor':
"""
Instantiate a :class:`Predictor` from an archive path.

Expand All @@ -104,10 +145,13 @@ def from_path(cls, archive_path: str, predictor_name: str = None) -> 'Predictor'
-------
A Predictor instance.
"""
return Predictor.from_archive(load_archive(archive_path), predictor_name)
return Predictor.from_archive(load_archive(archive_path), predictor_name, return_model_internals)

@classmethod
def from_archive(cls, archive: Archive, predictor_name: str = None) -> 'Predictor':
def from_archive(cls,
archive: Archive,
predictor_name: str = None,
return_model_internals: bool = False) -> 'Predictor':
"""
Instantiate a :class:`Predictor` from an :class:`~allennlp.models.archival.Archive`;
that is, from the result of training a model. Optionally specify which `Predictor`
Expand All @@ -129,4 +173,4 @@ def from_archive(cls, archive: Archive, predictor_name: str = None) -> 'Predicto
model = archive.model
model.eval()

return Predictor.by_name(predictor_name)(model, dataset_reader)
return Predictor.by_name(predictor_name)(model, dataset_reader, return_model_internals)
23 changes: 23 additions & 0 deletions allennlp/tests/predictors/bidaf_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# pylint: disable=no-self-use,invalid-name
from pytest import approx

from allennlp.common.introspection import store_function_results
from allennlp.common.testing import AllenNlpTestCase
from allennlp.models.archival import load_archive
from allennlp.predictors import Predictor
Expand Down Expand Up @@ -35,6 +36,28 @@ def test_uses_named_inputs(self):
assert all(isinstance(x, float) for x in probs)
assert sum(probs) == approx(1.0)

def test_model_internals(self):
    """
    A predictor built with ``return_model_internals=True`` should report both
    the per-module hook outputs and the recorded function results.
    """
    archive_file = self.FIXTURES_ROOT / 'bidaf' / 'serialization' / 'model.tar.gz'
    predictor = Predictor.from_archive(load_archive(archive_file),
                                       'machine-comprehension',
                                       return_model_internals=True)

    result = predictor.predict_json({
            "question": "What kind of test succeeded on its first attempt?",
            "passage": "One time I was writing a unit test, and it succeeded on the first attempt."
    })

    # Module-level internals collected through the forward hooks.
    module_results = result.get('_internal_module_results')
    assert module_results is not None
    assert len(module_results) == 25

    # Spot-check one hooked module: the 50 -> 1 linear layer at index 23.
    linear_50_1 = module_results[23]
    assert "Linear(in_features=50, out_features=1, bias=True)" in linear_50_1["name"]
    linear_output = linear_50_1['output']
    assert len(linear_output) == 17
    assert all(len(row) == 1 for row in linear_output)

    # Function-level internals: entries are (name, value) pairs.
    function_results = result.get('_internal_function_results')
    assert any(entry_name == 'masked_softmax' for entry_name, _ in function_results)

def test_batch_prediction(self):
inputs = [
{
Expand Down