Skip to content
This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

Commit

Permalink
fix issues with pandas pylint and flake8 (#427)
Browse files: browse the repository at this point in the history
* fix issues with pandas pylint and flake8

* fix requirements

* fix req tests

* remove comment

* fix catboost req tests

Co-authored-by: mike0sv <[email protected]>
  • Loading branch information
madhur-tandon and mike0sv authored Oct 3, 2022
1 parent 0dbe116 commit 779b8b3
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 11 deletions.
4 changes: 3 additions & 1 deletion mlem/contrib/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,9 @@ def read_pickle_with_unnamed(*args, **kwargs):


def read_json_reset_index(*args, **kwargs):
return pd.read_json(*args, **kwargs).reset_index(drop=True)
return pd.read_json( # pylint: disable=no-member
*args, **kwargs
).reset_index(drop=True)


def read_html(*args, **kwargs):
Expand Down
13 changes: 12 additions & 1 deletion mlem/utils/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,10 @@ def wrapper(pickler: "RequirementAnalyzer", obj):
else:
pickler.save(o)

if is_from_installable_module(obj):
if (
is_from_installable_module(obj)
or get_object_base_module(obj) is mlem
):
return f(pickler, obj)

# to add from local imports inside user (non PIP package) code
Expand Down Expand Up @@ -514,6 +517,7 @@ def _should_ignore(self, mod: ModuleType):
or is_private_module(mod)
or is_pseudo_module(mod)
or is_builtin_module(mod)
or mod in self._modules
)

def add_requirement(self, obj_or_module):
Expand All @@ -533,6 +537,11 @@ def add_requirement(self, obj_or_module):
module = obj_or_module

if module is not None and not self._should_ignore(module):
base_module = get_base_module(module)
if is_installable_module(base_module):
if base_module in self._modules:
return
module = base_module
self._modules.add(module)
if is_local_module(module):
# add imports of this module
Expand All @@ -553,6 +562,8 @@ def save(self, obj, save_persistent_id=True):
if id(obj) in self.seen or isinstance(obj, IGNORE_TYPES_REQ):
return None
self.seen.add(id(obj))
if get_object_base_module(obj) in self._modules:
return None
self.add_requirement(obj)
try:
return super().save(obj, save_persistent_id)
Expand Down
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ ignore =
E266, # Too many leading '#' for block comment
W503, # Line break occurred before a binary operator
B008, # Do not perform function calls in argument defaults: conflicts with typer
P1, # unindexed parameters in the str.format, see:
P1, # unindexed parameters in the str.format, see:
B902, # Invalid first argument 'cls' used for instance method.
B024, # ABCs without methods
# https://pypi.org/project/flake8-string-format/
max_line_length = 79
max-complexity = 15
Expand Down
2 changes: 1 addition & 1 deletion tests/contrib/test_catboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_catboost_model(catboost_model_fixture, pandas_data, tmpdir, request):
),
)

expected_requirements = {"catboost", "pandas", "numpy", "scipy"}
expected_requirements = {"catboost", "pandas"}
reqs = set(cbmw.get_requirements().modules)
assert all(r in reqs for r in expected_requirements)
assert cbmw.model is catboost_model
Expand Down
2 changes: 1 addition & 1 deletion tests/contrib/test_lightgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def test_model__predict_not_dataset(model):
@long
def test_model__dump_load(tmpdir, model, data_np, local_fs):
# pandas is not required, but if it is installed, it is imported by lightgbm
expected_requirements = {"lightgbm", "numpy", "scipy", "pandas"}
expected_requirements = {"lightgbm", "numpy"}
assert set(model.get_requirements().modules) == expected_requirements

artifacts = model.dump(LOCAL_STORAGE, tmpdir)
Expand Down
2 changes: 1 addition & 1 deletion tests/contrib/test_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ def f(x):

sig = Signature.from_method(f, auto_infer=True, x=data)

assert set(get_object_requirements(sig).modules) == {"pandas", "numpy"}
assert set(get_object_requirements(sig).modules) == {"pandas"}


# Copyright 2019 Zyfra
Expand Down
11 changes: 8 additions & 3 deletions tests/contrib/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def test_model_type__dump_load(tmpdir, model, inp_data, request):
def test_model_type_lgb__dump_load(tmpdir, lgbm_model, inp_data):
model_type = ModelAnalyzer.analyze(lgbm_model, sample_data=inp_data)

expected_requirements = {"sklearn", "lightgbm", "pandas", "numpy", "scipy"}
expected_requirements = {"sklearn", "lightgbm", "numpy"}
reqs = model_type.get_requirements().expanded
assert set(reqs.modules) == expected_requirements
assert reqs.of_type(UnixPackageRequirement) == [
Expand All @@ -164,11 +164,16 @@ def test_model_type_lgb__dump_load(tmpdir, lgbm_model, inp_data):
]


def test_pipeline_requirements(lgbm_model):
def test_pipeline_requirements(lgbm_model, inp_data):
model = Pipeline(steps=[("model", lgbm_model)])
meta = MlemModel.from_obj(model)

expected_requirements = {"sklearn", "lightgbm", "pandas", "numpy", "scipy"}
expected_requirements = {"sklearn", "lightgbm"}
assert set(meta.requirements.modules) == expected_requirements

meta = MlemModel.from_obj(model, sample_data=np.array(inp_data))

expected_requirements = {"sklearn", "lightgbm", "numpy"}
assert set(meta.requirements.modules) == expected_requirements


Expand Down
3 changes: 1 addition & 2 deletions tests/contrib/test_xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,7 @@ def test_model__predict_not_dmatrix(model):

@long
def test_model__dump_load(tmpdir, model, dmatrix_np, local_fs):
# pandas is not required, but it is conditionally imported by some Booster methods
expected_requirements = {"xgboost", "numpy", "scipy", "pandas"}
expected_requirements = {"xgboost", "numpy"}
assert set(model.get_requirements().modules) == expected_requirements

artifacts = model.dump(LOCAL_STORAGE, tmpdir)
Expand Down

0 comments on commit 779b8b3

Please sign in to comment.