Skip to content

Commit

Permalink
Add __init__ methods, extract attrs from them
Browse files Browse the repository at this point in the history
We need __init__ methods for type-checking, and extract attrs
from them to avoid repetition.
  • Loading branch information
lopuhin committed Nov 7, 2016
1 parent d7c8d45 commit 1ed2fe7
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 48 deletions.
141 changes: 93 additions & 48 deletions eli5/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,57 +2,78 @@

from typing import Dict, List, Tuple, Union

import attr
from .base_utils import attrs


@attr.s
@attrs
class Explanation(object):
""" An explanation for classifier or regressor,
it can either explain weights or a single prediction.
"""
# Explanation meta-information
estimator = attr.ib() # type: str
description = attr.ib(default=None) # type: str
error = attr.ib(default=None) # type: str
method = attr.ib(default=None) # type: str
is_regression = attr.ib(default=False) # type: bool
# Actual explanations
targets = attr.ib(default=None) # type: List[TargetExplanation]
feature_importances = attr.ib(default=None) # type: FeatureWeights
decision_tree = attr.ib(default=None) # type: TreeInfo
def __init__(self,
estimator, # type: str
description=None, # type: str
error=None, # type: str
method=None, # type: str
is_regression=False, # type: bool
targets=None, # type: List[TargetExplanation]
feature_importances=None, # type: FeatureWeights
decision_tree=None, # type: TreeInfo
):
self.estimator = estimator
self.description = description
self.error = error
self.method = method
self.is_regression = is_regression
self.targets = targets
self.feature_importances = feature_importances
self.decision_tree = decision_tree

def _repr_html_(self):
from eli5.formatters import format_as_html, fields
return format_as_html(self, force_weights=False, show=fields.WEIGHTS)


@attr.s
@attrs
class TargetExplanation(object):
""" Explanation for a single target or class.
Feature weights are stored in the :feature_weights: attribute,
and features highlighted in text in the :weighted_spans: attribute.
"""
target = attr.ib() # type: str
feature_weights = attr.ib() # type: FeatureWeights
proba = attr.ib(default=None) # type: float
score = attr.ib(default=None) # type: float
weighted_spans = attr.ib(default=None) # type: WeightedSpans
def __init__(self,
target, # type: str
feature_weights, # type: FeatureWeights
proba=None, # type: float
score=None, # type: float
weighted_spans=None, # type: WeightedSpans
):
self.target = target
self.feature_weights = feature_weights
self.proba = proba
self.score = score
self.weighted_spans = weighted_spans


Feature = Union[str, Dict] # Dict is currently used for unhashed features


@attr.s
@attrs
class FeatureWeights(object):
""" Weights for top features, :pos: for positive and :neg: for negative,
sorted by descending absolute value.
Number of remaining positive and negative features are stored in
:pos_remaining: and :neg_remaining: attributes.
"""
pos = attr.ib() # type: List[Tuple[Feature, float]]
neg = attr.ib() # type: List[Tuple[Feature, float]]
pos_remaining = attr.ib(default=0) # type: int
neg_remaining = attr.ib(default=0) # type: int
def __init__(self,
pos, # type: List[Tuple[Feature, float]]
neg, # type: List[Tuple[Feature, float]]
pos_remaining=0, # type: int
neg_remaining=0, # type: int
):
self.pos = pos
self.neg = neg
self.pos_remaining = pos_remaining
self.neg_remaining = neg_remaining


WeightedSpan = Tuple[
Expand All @@ -62,46 +83,70 @@ class FeatureWeights(object):
]


@attr.s
@attrs
class WeightedSpans(object):
""" Features highlighted in text. :analyzer: is a type of the analyzer
(for example "char" or "word"), and :document: is a pre-processed document
before applying the analyzed. :weighted_spans: holds a list of spans
(see above) for features found in text (span indices correspond to
:document:), and :other: holds weights for features not highlighted in text.
"""
analyzer = attr.ib() # type: str
document = attr.ib() # type: str
weighted_spans = attr.ib() # type: List[WeightedSpan]
other = attr.ib(default=None) # type: FeatureWeights


@attr.s
def __init__(self,
analyzer, # type: str
document, # type: str
weighted_spans, # type: List[WeightedSpan]
other=None, # type: FeatureWeights
):
self.analyzer = analyzer
self.document = document
self.weighted_spans = weighted_spans
self.other = other


@attrs
class TreeInfo(object):
""" Information about the decision tree. :criterion: is the name of
the function to measure the quality of a split, :tree: holds all nodes
of the tree, and :graphviz: is the tree rendered in graphviz .dot format.
"""
criterion = attr.ib() # type: str
tree = attr.ib() # type: NodeInfo
graphviz = attr.ib() # type: str
def __init__(self,
criterion, # type: str
tree, # type: NodeInfo
graphviz, # type: str
):
self.criterion = criterion
self.tree = tree
self.graphviz = graphviz


@attr.s
@attrs
class NodeInfo(object):
""" A node in a binary tree.
Pointers to left and right children are in :left: and :right: attributes.
"""
id = attr.ib()
is_leaf = attr.ib() # type: bool
value = attr.ib()
value_ratio = attr.ib()
impurity = attr.ib()
samples = attr.ib()
sample_ratio = attr.ib()
feature_name = attr.ib(default=None)
# for non-leafs
feature_id = attr.ib(default=None)
threshold = attr.ib(default=None)
left = attr.ib(default=None) # type: NodeInfo
right = attr.ib(default=None) # type: NodeInfo
def __init__(self,
id,
is_leaf, # type: bool
value,
value_ratio,
impurity,
samples,
sample_ratio,
feature_name=None,
feature_id=None,
threshold=None,
left=None, # type: NodeInfo
right=None, # type: NodeInfo
):
self.id = id
self.is_leaf = is_leaf
self.value = value
self.value_ratio = value_ratio
self.impurity = impurity
self.samples = samples
self.sample_ratio = sample_ratio
self.feature_name = feature_name
self.feature_id = feature_id
self.threshold = threshold
self.left = left
self.right = right
24 changes: 24 additions & 0 deletions eli5/base_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import inspect

import attr


def attrs(class_):
""" Like attr.s, but with attributes extracted from __init__ method signature.
It is useful if we still want __init__ for proper type-checking and
do not want to repeat attribute definitions in the class body.
"""
attrs_kwargs = {}
for method in ['repr', 'cmp', 'hash']:
if '__{}__'.format(method) in class_.__dict__:
# Allow to redefine a special method (or else attr.s will do it)
attrs_kwargs[method] = False
init_args = inspect.getargspec(class_.__init__)
defaults_shift = len(init_args.args) - len(init_args.defaults or []) - 1
these = {}
for idx, arg in enumerate(init_args.args[1:]):
attrib_kwargs = {}
if idx >= defaults_shift:
attrib_kwargs['default'] = init_args.defaults[idx - defaults_shift]
these[arg] = attr.ib(**attrib_kwargs)
return attr.s(class_, these=these, init=False, **attrs_kwargs)
50 changes: 50 additions & 0 deletions tests/test_base_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import attr

from eli5.base_utils import attrs


def test_attrs_with_default():

@attrs
class WithDefault(object):
def __init__(self, x, y=1):
self.x = x
self.y = y

x_attr, y_attr = attr.fields(WithDefault)
assert x_attr.name == 'x'
assert y_attr.name == 'y'
assert x_attr.default is attr.NOTHING
assert y_attr.default == 1

assert WithDefault(1) == WithDefault(1)
assert WithDefault(1, 1) != WithDefault(1, 2)


def test_attrs_without_default():

@attrs
class WithoutDefault(object):
def __init__(self, x):
self.x = x

x_attr, = attr.fields(WithoutDefault)
assert x_attr.name == 'x'
assert x_attr.default is attr.NOTHING

assert WithoutDefault(1) == WithoutDefault(1)
assert WithoutDefault(1) != WithoutDefault(2)


def test_attrs_with_repr():

@attrs
class WithRepr(object):
def __init__(self, x):
self.x = x

def __repr__(self):
return 'foo'

assert hash(WithRepr(1)) == hash(WithRepr(1))
assert repr(WithRepr(2)) == 'foo'

0 comments on commit 1ed2fe7

Please sign in to comment.