NOTE(review): line breaks and indentation in this patch were reconstructed
from a whitespace-collapsed copy. Hunk line counts were re-checked against the
@@ headers; exact whitespace inside context lines is inferred -- if a hunk is
rejected, retry with `git apply --ignore-whitespace`.

diff --git a/eli5/_feature_weights.py b/eli5/_feature_weights.py
index 47d71f9f..620f48fe 100644
--- a/eli5/_feature_weights.py
+++ b/eli5/_feature_weights.py
@@ -24,7 +24,7 @@ def _get_top_features(feature_names, coef, top):
       no more than ``num_neg`` negative features.
     """
     if isinstance(top, (list, tuple)):
-        num_pos, num_neg = top
+        num_pos, num_neg = list(top)  # "list" is just for mypy
         pos = _get_top_positive_features(feature_names, coef, num_pos)
         neg = _get_top_negative_features(feature_names, coef, num_neg)
     else:
diff --git a/eli5/base.py b/eli5/base.py
index f71e9e75..d4f101d7 100644
--- a/eli5/base.py
+++ b/eli5/base.py
@@ -2,57 +2,82 @@
 
 from typing import Dict, List, Tuple, Union
 
-import attr
+from .base_utils import attrs
 
 
-@attr.s
+# @attrs decorator used in this file calls @attr.s(slots=True),
+# creating attr.ib entries based on the signature of __init__.
+
+
+@attrs
 class Explanation(object):
     """ An explanation for classifier or regressor,
     it can either explain weights or a single prediction.
     """
-    # Explanation meta-information
-    estimator = attr.ib()  # type: str
-    description = attr.ib(default=None)  # type: str
-    error = attr.ib(default=None)  # type: str
-    method = attr.ib(default=None)  # type: str
-    is_regression = attr.ib(default=False)  # type: bool
-    # Actual explanations
-    targets = attr.ib(default=None)  # type: List[TargetExplanation]
-    feature_importances = attr.ib(default=None)  # type: FeatureWeights
-    decision_tree = attr.ib(default=None)  # type: TreeInfo
+    def __init__(self,
+                 estimator,  # type: str
+                 description=None,  # type: str
+                 error=None,  # type: str
+                 method=None,  # type: str
+                 is_regression=False,  # type: bool
+                 targets=None,  # type: List[TargetExplanation]
+                 feature_importances=None,  # type: FeatureWeights
+                 decision_tree=None,  # type: TreeInfo
+                 ):
+        self.estimator = estimator
+        self.description = description
+        self.error = error
+        self.method = method
+        self.is_regression = is_regression
+        self.targets = targets
+        self.feature_importances = feature_importances
+        self.decision_tree = decision_tree
 
     def _repr_html_(self):
         from eli5.formatters import format_as_html, fields
         return format_as_html(self, force_weights=False, show=fields.WEIGHTS)
 
 
-@attr.s
+@attrs
 class TargetExplanation(object):
     """ Explanation for a single target or class.
     Feature weights are stored in the :feature_weights: attribute,
     and features highlighted in text in the :weighted_spans: attribute.
     """
-    target = attr.ib()  # type: str
-    feature_weights = attr.ib()  # type: FeatureWeights
-    proba = attr.ib(default=None)  # type: float
-    score = attr.ib(default=None)  # type: float
-    weighted_spans = attr.ib(default=None)  # type: WeightedSpans
+    def __init__(self,
+                 target,  # type: str
+                 feature_weights,  # type: FeatureWeights
+                 proba=None,  # type: float
+                 score=None,  # type: float
+                 weighted_spans=None,  # type: WeightedSpans
+                 ):
+        self.target = target
+        self.feature_weights = feature_weights
+        self.proba = proba
+        self.score = score
+        self.weighted_spans = weighted_spans
 
 
 Feature = Union[str, Dict]  # Dict is currently used for unhashed features
 
 
-@attr.s
+@attrs
 class FeatureWeights(object):
     """ Weights for top features, :pos: for positive and :neg: for negative,
     sorted by descending absolute value.
     Number of remaining positive and negative features are stored in
     :pos_remaining: and :neg_remaining: attributes.
     """
-    pos = attr.ib()  # type: List[Tuple[Feature, float]]
-    neg = attr.ib()  # type: List[Tuple[Feature, float]]
-    pos_remaining = attr.ib(default=0)  # type: int
-    neg_remaining = attr.ib(default=0)  # type: int
+    def __init__(self,
+                 pos,  # type: List[Tuple[Feature, float]]
+                 neg,  # type: List[Tuple[Feature, float]]
+                 pos_remaining=0,  # type: int
+                 neg_remaining=0,  # type: int
+                 ):
+        self.pos = pos
+        self.neg = neg
+        self.pos_remaining = pos_remaining
+        self.neg_remaining = neg_remaining
 
 
 WeightedSpan = Tuple[
@@ -62,7 +87,7 @@ class FeatureWeights(object):
 ]
 
 
-@attr.s
+@attrs
 class WeightedSpans(object):
     """ Features highlighted in text. :analyzer: is a type of the analyzer
     (for example "char" or "word"), and :document: is a pre-processed document
@@ -70,38 +95,62 @@ class WeightedSpans(object):
     (see above) for features found in text (span indices correspond to
     :document:), and :other: holds weights for features not highlighted in text.
     """
-    analyzer = attr.ib()  # type: str
-    document = attr.ib()  # type: str
-    weighted_spans = attr.ib()  # type: List[WeightedSpan]
-    other = attr.ib(default=None)  # type: FeatureWeights
-
-
-@attr.s
+    def __init__(self,
+                 analyzer,  # type: str
+                 document,  # type: str
+                 weighted_spans,  # type: List[WeightedSpan]
+                 other=None,  # type: FeatureWeights
+                 ):
+        self.analyzer = analyzer
+        self.document = document
+        self.weighted_spans = weighted_spans
+        self.other = other
+
+
+@attrs
 class TreeInfo(object):
     """ Information about the decision tree. :criterion: is the name of
     the function to measure the quality of a split, :tree: holds all nodes
     of the tree, and :graphviz: is the tree rendered in graphviz .dot format.
     """
-    criterion = attr.ib()  # type: str
-    tree = attr.ib()  # type: NodeInfo
-    graphviz = attr.ib()  # type: str
+    def __init__(self,
+                 criterion,  # type: str
+                 tree,  # type: NodeInfo
+                 graphviz,  # type: str
+                 ):
+        self.criterion = criterion
+        self.tree = tree
+        self.graphviz = graphviz
 
 
-@attr.s
+@attrs
 class NodeInfo(object):
     """ A node in a binary tree.
     Pointers to left and right children are in :left: and :right: attributes.
     """
-    id = attr.ib()
-    is_leaf = attr.ib()  # type: bool
-    value = attr.ib()
-    value_ratio = attr.ib()
-    impurity = attr.ib()
-    samples = attr.ib()
-    sample_ratio = attr.ib()
-    feature_name = attr.ib(default=None)
-    # for non-leafs
-    feature_id = attr.ib(default=None)
-    threshold = attr.ib(default=None)
-    left = attr.ib(default=None)  # type: NodeInfo
-    right = attr.ib(default=None)  # type: NodeInfo
+    def __init__(self,
+                 id,
+                 is_leaf,  # type: bool
+                 value,
+                 value_ratio,
+                 impurity,
+                 samples,
+                 sample_ratio,
+                 feature_name=None,
+                 feature_id=None,
+                 threshold=None,
+                 left=None,  # type: NodeInfo
+                 right=None,  # type: NodeInfo
+                 ):
+        self.id = id
+        self.is_leaf = is_leaf
+        self.value = value
+        self.value_ratio = value_ratio
+        self.impurity = impurity
+        self.samples = samples
+        self.sample_ratio = sample_ratio
+        self.feature_name = feature_name
+        self.feature_id = feature_id
+        self.threshold = threshold
+        self.left = left
+        self.right = right
diff --git a/eli5/base_utils.py b/eli5/base_utils.py
new file mode 100644
index 00000000..930ee2d6
--- /dev/null
+++ b/eli5/base_utils.py
@@ -0,0 +1,27 @@
+import inspect
+
+import attr
+
+
+def attrs(class_):
+    """ Like attr.s with slots=True,
+    but with attributes extracted from __init__ method signature.
+    slots=True ensures that signature matches what really happens
+    (we can't define different attributes on self).
+    It is useful if we still want __init__ for proper type-checking and
+    do not want to repeat attribute definitions in the class body.
+    """
+    attrs_kwargs = {}
+    for method in ['repr', 'cmp', 'hash']:
+        if '__{}__'.format(method) in class_.__dict__:
+            # Allow to redefine a special method (or else attr.s will do it)
+            attrs_kwargs[method] = False
+    init_args = inspect.getargspec(class_.__init__)
+    defaults_shift = len(init_args.args) - len(init_args.defaults or []) - 1
+    these = {}
+    for idx, arg in enumerate(init_args.args[1:]):
+        attrib_kwargs = {}
+        if idx >= defaults_shift:
+            attrib_kwargs['default'] = init_args.defaults[idx - defaults_shift]
+        these[arg] = attr.ib(**attrib_kwargs)
+    return attr.s(class_, these=these, init=False, slots=True, **attrs_kwargs)
diff --git a/eli5/formatters/text.py b/eli5/formatters/text.py
index 4c721558..d301b2df 100644
--- a/eli5/formatters/text.py
+++ b/eli5/formatters/text.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 from __future__ import absolute_import
 import six
+from typing import List
 
 from . import fields
 from .features import FormattedFeatureName
@@ -14,7 +15,7 @@
 
 
 def format_as_text(expl, show=fields.ALL):
-    lines = []
+    lines = []  # type: List[str]
 
     if expl.error:  # always shown
         lines.extend(_error_lines(expl))
diff --git a/eli5/lime/samplers.py b/eli5/lime/samplers.py
index d09fbca7..4d501d1d 100644
--- a/eli5/lime/samplers.py
+++ b/eli5/lime/samplers.py
@@ -2,7 +2,7 @@
 from __future__ import absolute_import
 import abc
 import six
-from typing import Tuple
+from typing import List, Tuple
 
 import numpy as np
 
@@ -145,8 +145,8 @@ class UnivariateKernelDensitySampler(_BaseKernelDensitySampler):
     of the features instead of generating totally new examples.
     """
     def fit(self, X, y=None):
-        self.kdes_ = []
-        self.grids_ = []
+        self.kdes_ = []  # type: List[KernelDensity]
+        self.grids_ = []  # type: List[GridSearchCV]
         num_features = X.shape[-1]
         for i in range(num_features):
             grid, kde = self._fit_kde(self.kde, X[:, i].reshape(-1, 1))
diff --git a/eli5/sklearn/text.py b/eli5/sklearn/text.py
index 1d193ba4..c928a6c6 100644
--- a/eli5/sklearn/text.py
+++ b/eli5/sklearn/text.py
@@ -1,4 +1,5 @@
 import re
+from typing import Set, Tuple
 
 from six.moves import xrange
 from sklearn.feature_extraction.text import VectorizerMixin
@@ -56,7 +57,7 @@ def _get_features(feature):
 def _get_other(feature_weights, feature_weights_dict, found_features):
     # search for items that were not accounted at all.
     other_items = []
-    accounted_keys = set()
+    accounted_keys = set()  # type: Set[Tuple[str, int]]
     for feature, (_, key) in feature_weights_dict.items():
         if key not in found_features and key not in accounted_keys:
             group, idx = key
diff --git a/eli5/sklearn/unhashing.py b/eli5/sklearn/unhashing.py
index a77398d5..9d49d8b3 100644
--- a/eli5/sklearn/unhashing.py
+++ b/eli5/sklearn/unhashing.py
@@ -6,7 +6,7 @@
 
 from collections import defaultdict, Counter
 from itertools import chain
-from typing import List, Iterable, Any
+from typing import List, Iterable, Any, Dict
 
 import numpy as np
 from sklearn.base import BaseEstimator, TransformerMixin
@@ -183,7 +183,7 @@ def _get_collisions(indices):
     Return a dict ``{column_id: [possible term ids]}``
     with collision information.
     """
-    collisions = defaultdict(list)
+    collisions = defaultdict(list)  # type: Dict[int, List[int]]
     for term_id, hash_id in enumerate(indices):
         collisions[hash_id].append(term_id)
     return dict(collisions)
diff --git a/eli5/sklearn/utils.py b/eli5/sklearn/utils.py
index 377bff15..62dc6a10 100644
--- a/eli5/sklearn/utils.py
+++ b/eli5/sklearn/utils.py
@@ -96,7 +96,7 @@ def get_feature_names(clf, vec=None, bias_name='', feature_names=None):
         if feature_names.n_features != num_features:
             raise ValueError("feature_names has a wrong n_features: "
                              "expected=%d, got=%d" % (num_features,
-                                                      len(feature_names)))
+                                                      feature_names.n_features))
         # Make a shallow copy setting proper bias_name
         return FeatureNames(
             feature_names.feature_names,
diff --git a/tests/test_base_utils.py b/tests/test_base_utils.py
new file mode 100644
index 00000000..3658c646
--- /dev/null
+++ b/tests/test_base_utils.py
@@ -0,0 +1,62 @@
+import attr
+import pytest
+
+from eli5.base_utils import attrs
+
+
+def test_attrs_with_default():
+
+    @attrs
+    class WithDefault(object):
+        def __init__(self, x, y=1):
+            self.x = x
+            self.y = y
+
+    x_attr, y_attr = attr.fields(WithDefault)
+    assert x_attr.name == 'x'
+    assert y_attr.name == 'y'
+    assert x_attr.default is attr.NOTHING
+    assert y_attr.default == 1
+
+    assert WithDefault(1) == WithDefault(1)
+    assert WithDefault(1, 1) != WithDefault(1, 2)
+
+
+def test_attrs_without_default():
+
+    @attrs
+    class WithoutDefault(object):
+        def __init__(self, x):
+            self.x = x
+
+    x_attr, = attr.fields(WithoutDefault)
+    assert x_attr.name == 'x'
+    assert x_attr.default is attr.NOTHING
+
+    assert WithoutDefault(1) == WithoutDefault(1)
+    assert WithoutDefault(1) != WithoutDefault(2)
+
+
+def test_attrs_with_repr():
+
+    @attrs
+    class WithRepr(object):
+        def __init__(self, x):
+            self.x = x
+
+        def __repr__(self):
+            return 'foo'
+
+    assert hash(WithRepr(1)) == hash(WithRepr(1))
+    assert repr(WithRepr(2)) == 'foo'
+
+
+def test_bad_init():
+
+    @attrs
+    class BadInit(object):
+        def __init__(self, x):
+            self._x = x
+
+    with pytest.raises(AttributeError):
+        BadInit(1)
diff --git a/tox.ini b/tox.ini
index ba30054e..df7c7d61 100644
--- a/tox.ini
+++ b/tox.ini
@@ -33,4 +33,4 @@ deps=
     {[testenv]deps}
     mypy-lang
 commands=
-    mypy --silent-imports eli5
+    mypy --silent-imports --check-untyped-defs eli5