diff --git a/eli5/base.py b/eli5/base.py
index f71e9e75..921ef59f 100644
--- a/eli5/base.py
+++ b/eli5/base.py
@@ -2,57 +2,78 @@
 
 from typing import Dict, List, Tuple, Union
 
-import attr
+from .base_utils import attrs
 
 
-@attr.s
+@attrs
 class Explanation(object):
     """ An explanation for classifier or regressor,
     it can either explain weights or a single prediction.
     """
-    # Explanation meta-information
-    estimator = attr.ib()  # type: str
-    description = attr.ib(default=None)  # type: str
-    error = attr.ib(default=None)  # type: str
-    method = attr.ib(default=None)  # type: str
-    is_regression = attr.ib(default=False)  # type: bool
-    # Actual explanations
-    targets = attr.ib(default=None)  # type: List[TargetExplanation]
-    feature_importances = attr.ib(default=None)  # type: FeatureWeights
-    decision_tree = attr.ib(default=None)  # type: TreeInfo
+    def __init__(self,
+                 estimator,  # type: str
+                 description=None,  # type: str
+                 error=None,  # type: str
+                 method=None,  # type: str
+                 is_regression=False,  # type: bool
+                 targets=None,  # type: List[TargetExplanation]
+                 feature_importances=None,  # type: FeatureWeights
+                 decision_tree=None,  # type: TreeInfo
+                 ):
+        self.estimator = estimator
+        self.description = description
+        self.error = error
+        self.method = method
+        self.is_regression = is_regression
+        self.targets = targets
+        self.feature_importances = feature_importances
+        self.decision_tree = decision_tree
 
     def _repr_html_(self):
         from eli5.formatters import format_as_html, fields
         return format_as_html(self, force_weights=False, show=fields.WEIGHTS)
 
 
-@attr.s
+@attrs
 class TargetExplanation(object):
     """ Explanation for a single target or class.
     Feature weights are stored in the :feature_weights: attribute,
     and features highlighted in text in the :weighted_spans: attribute.
     """
-    target = attr.ib()  # type: str
-    feature_weights = attr.ib()  # type: FeatureWeights
-    proba = attr.ib(default=None)  # type: float
-    score = attr.ib(default=None)  # type: float
-    weighted_spans = attr.ib(default=None)  # type: WeightedSpans
+    def __init__(self,
+                 target,  # type: str
+                 feature_weights,  # type: FeatureWeights
+                 proba=None,  # type: float
+                 score=None,  # type: float
+                 weighted_spans=None,  # type: WeightedSpans
+                 ):
+        self.target = target
+        self.feature_weights = feature_weights
+        self.proba = proba
+        self.score = score
+        self.weighted_spans = weighted_spans
 
 
 Feature = Union[str, Dict]  # Dict is currently used for unhashed features
 
 
-@attr.s
+@attrs
 class FeatureWeights(object):
     """ Weights for top features, :pos: for positive and :neg: for negative,
     sorted by descending absolute value.
     Number of remaining positive and negative features are stored in
     :pos_remaining: and :neg_remaining: attributes.
     """
-    pos = attr.ib()  # type: List[Tuple[Feature, float]]
-    neg = attr.ib()  # type: List[Tuple[Feature, float]]
-    pos_remaining = attr.ib(default=0)  # type: int
-    neg_remaining = attr.ib(default=0)  # type: int
+    def __init__(self,
+                 pos,  # type: List[Tuple[Feature, float]]
+                 neg,  # type: List[Tuple[Feature, float]]
+                 pos_remaining=0,  # type: int
+                 neg_remaining=0,  # type: int
+                 ):
+        self.pos = pos
+        self.neg = neg
+        self.pos_remaining = pos_remaining
+        self.neg_remaining = neg_remaining
 
 
 WeightedSpan = Tuple[
@@ -62,7 +83,7 @@ class FeatureWeights(object):
 ]
 
 
-@attr.s
+@attrs
 class WeightedSpans(object):
     """ Features highlighted in text. :analyzer: is a type of the analyzer
     (for example "char" or "word"), and :document: is a pre-processed document
@@ -70,38 +91,62 @@ class WeightedSpans(object):
     (see above) for features found in text (span indices correspond to
     :document:), and :other: holds weights for features not highlighted in text.
     """
-    analyzer = attr.ib()  # type: str
-    document = attr.ib()  # type: str
-    weighted_spans = attr.ib()  # type: List[WeightedSpan]
-    other = attr.ib(default=None)  # type: FeatureWeights
-
-
-@attr.s
+    def __init__(self,
+                 analyzer,  # type: str
+                 document,  # type: str
+                 weighted_spans,  # type: List[WeightedSpan]
+                 other=None,  # type: FeatureWeights
+                 ):
+        self.analyzer = analyzer
+        self.document = document
+        self.weighted_spans = weighted_spans
+        self.other = other
+
+
+@attrs
 class TreeInfo(object):
     """ Information about the decision tree. :criterion: is the name of
     the function to measure the quality of a split, :tree: holds all nodes
     of the tree, and :graphviz: is the tree rendered in graphviz .dot format.
     """
-    criterion = attr.ib()  # type: str
-    tree = attr.ib()  # type: NodeInfo
-    graphviz = attr.ib()  # type: str
+    def __init__(self,
+                 criterion,  # type: str
+                 tree,  # type: NodeInfo
+                 graphviz,  # type: str
+                 ):
+        self.criterion = criterion
+        self.tree = tree
+        self.graphviz = graphviz
 
 
-@attr.s
+@attrs
 class NodeInfo(object):
     """ A node in a binary tree.
     Pointers to left and right children are in :left: and :right: attributes.
     """
-    id = attr.ib()
-    is_leaf = attr.ib()  # type: bool
-    value = attr.ib()
-    value_ratio = attr.ib()
-    impurity = attr.ib()
-    samples = attr.ib()
-    sample_ratio = attr.ib()
-    feature_name = attr.ib(default=None)
-    # for non-leafs
-    feature_id = attr.ib(default=None)
-    threshold = attr.ib(default=None)
-    left = attr.ib(default=None)  # type: NodeInfo
-    right = attr.ib(default=None)  # type: NodeInfo
+    def __init__(self,
+                 id,
+                 is_leaf,  # type: bool
+                 value,
+                 value_ratio,
+                 impurity,
+                 samples,
+                 sample_ratio,
+                 feature_name=None,
+                 feature_id=None,
+                 threshold=None,
+                 left=None,  # type: NodeInfo
+                 right=None,  # type: NodeInfo
+                 ):
+        self.id = id
+        self.is_leaf = is_leaf
+        self.value = value
+        self.value_ratio = value_ratio
+        self.impurity = impurity
+        self.samples = samples
+        self.sample_ratio = sample_ratio
+        self.feature_name = feature_name
+        self.feature_id = feature_id
+        self.threshold = threshold
+        self.left = left
+        self.right = right
diff --git a/eli5/base_utils.py b/eli5/base_utils.py
new file mode 100644
index 00000000..d7263efa
--- /dev/null
+++ b/eli5/base_utils.py
@@ -0,0 +1,32 @@
+import inspect
+
+import attr
+
+
+def attrs(class_):
+    """ Like attr.s, but with attributes extracted from __init__ method signature.
+    It is useful if we still want __init__ for proper type-checking and
+    do not want to repeat attribute definitions in the class body.
+
+    The class keeps its own ``__init__`` (``init=False`` is passed to attr.s),
+    and ``__repr__``, ``__cmp__`` or ``__hash__`` defined in the class body
+    are left untouched.
+    """
+    attrs_kwargs = {}
+    for method in ['repr', 'cmp', 'hash']:
+        if '__{}__'.format(method) in class_.__dict__:
+            # Allow to redefine a special method (or else attr.s will do it)
+            attrs_kwargs[method] = False
+    # inspect.getargspec was removed in Python 3.11: use getfullargspec
+    # where it exists and fall back to getargspec only on Python 2.
+    getargspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
+    init_args = getargspec(class_.__init__)
+    defaults_shift = len(init_args.args) - len(init_args.defaults or []) - 1
+    these = {}
+    for idx, arg in enumerate(init_args.args[1:]):
+        attrib_kwargs = {}
+        if idx >= defaults_shift:
+            # Defaults align with the tail of the argument list.
+            attrib_kwargs['default'] = init_args.defaults[idx - defaults_shift]
+        these[arg] = attr.ib(**attrib_kwargs)
+    return attr.s(class_, these=these, init=False, **attrs_kwargs)
diff --git a/tests/test_base_utils.py b/tests/test_base_utils.py
new file mode 100644
index 00000000..b7d80c92
--- /dev/null
+++ b/tests/test_base_utils.py
@@ -0,0 +1,50 @@
+import attr
+
+from eli5.base_utils import attrs
+
+
+def test_attrs_with_default():
+
+    @attrs
+    class WithDefault(object):
+        def __init__(self, x, y=1):
+            self.x = x
+            self.y = y
+
+    x_attr, y_attr = attr.fields(WithDefault)
+    assert x_attr.name == 'x'
+    assert y_attr.name == 'y'
+    assert x_attr.default is attr.NOTHING
+    assert y_attr.default == 1
+
+    assert WithDefault(1) == WithDefault(1)
+    assert WithDefault(1, 1) != WithDefault(1, 2)
+
+
+def test_attrs_without_default():
+
+    @attrs
+    class WithoutDefault(object):
+        def __init__(self, x):
+            self.x = x
+
+    x_attr, = attr.fields(WithoutDefault)
+    assert x_attr.name == 'x'
+    assert x_attr.default is attr.NOTHING
+
+    assert WithoutDefault(1) == WithoutDefault(1)
+    assert WithoutDefault(1) != WithoutDefault(2)
+
+
+def test_attrs_with_repr():
+
+    @attrs
+    class WithRepr(object):
+        def __init__(self, x):
+            self.x = x
+
+        def __repr__(self):
+            return 'foo'
+
+    assert hash(WithRepr(1)) == hash(WithRepr(1))
+    assert repr(WithRepr(2)) == 'foo'