-
Notifications
You must be signed in to change notification settings - Fork 332
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add __init__ methods, extract attrs from them
We need __init__ methods for type-checking, and extract attrs from them to avoid repetition.
- Loading branch information
Showing
3 changed files
with
167 additions
and
48 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import inspect | ||
|
||
import attr | ||
|
||
|
||
def attrs(class_): | ||
""" Like attr.s, but with attributes extracted from __init__ method signature. | ||
It is useful if we still want __init__ for proper type-checking and | ||
do not want to repeat attribute definitions in the class body. | ||
""" | ||
attrs_kwargs = {} | ||
for method in ['repr', 'cmp', 'hash']: | ||
if '__{}__'.format(method) in class_.__dict__: | ||
# Allow to redefine a special method (or else attr.s will do it) | ||
attrs_kwargs[method] = False | ||
init_args = inspect.getargspec(class_.__init__) | ||
defaults_shift = len(init_args.args) - len(init_args.defaults or []) - 1 | ||
these = {} | ||
for idx, arg in enumerate(init_args.args[1:]): | ||
attrib_kwargs = {} | ||
if idx >= defaults_shift: | ||
attrib_kwargs['default'] = init_args.defaults[idx - defaults_shift] | ||
these[arg] = attr.ib(**attrib_kwargs) | ||
return attr.s(class_, these=these, init=False, **attrs_kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import attr | ||
|
||
from eli5.base_utils import attrs | ||
|
||
|
||
def test_attrs_with_default(): | ||
|
||
@attrs | ||
class WithDefault(object): | ||
def __init__(self, x, y=1): | ||
self.x = x | ||
self.y = y | ||
|
||
x_attr, y_attr = attr.fields(WithDefault) | ||
assert x_attr.name == 'x' | ||
assert y_attr.name == 'y' | ||
assert x_attr.default is attr.NOTHING | ||
assert y_attr.default == 1 | ||
|
||
assert WithDefault(1) == WithDefault(1) | ||
assert WithDefault(1, 1) != WithDefault(1, 2) | ||
|
||
|
||
def test_attrs_without_default(): | ||
|
||
@attrs | ||
class WithoutDefault(object): | ||
def __init__(self, x): | ||
self.x = x | ||
|
||
x_attr, = attr.fields(WithoutDefault) | ||
assert x_attr.name == 'x' | ||
assert x_attr.default is attr.NOTHING | ||
|
||
assert WithoutDefault(1) == WithoutDefault(1) | ||
assert WithoutDefault(1) != WithoutDefault(2) | ||
|
||
|
||
def test_attrs_with_repr(): | ||
|
||
@attrs | ||
class WithRepr(object): | ||
def __init__(self, x): | ||
self.x = x | ||
|
||
def __repr__(self): | ||
return 'foo' | ||
|
||
assert hash(WithRepr(1)) == hash(WithRepr(1)) | ||
assert repr(WithRepr(2)) == 'foo' |