Skip to content
This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

fix issues with pandas, pylint, and flake8 #427

Merged
merged 5 commits into from
Oct 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mlem/contrib/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,9 @@ def read_pickle_with_unnamed(*args, **kwargs):


def read_json_reset_index(*args, **kwargs):
return pd.read_json(*args, **kwargs).reset_index(drop=True)
return pd.read_json( # pylint: disable=no-member
*args, **kwargs
).reset_index(drop=True)


def read_html(*args, **kwargs):
Expand Down
13 changes: 12 additions & 1 deletion mlem/utils/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,10 @@ def wrapper(pickler: "RequirementAnalyzer", obj):
else:
pickler.save(o)

if is_from_installable_module(obj):
if (
is_from_installable_module(obj)
or get_object_base_module(obj) is mlem
):
return f(pickler, obj)

# to add from local imports inside user (non PIP package) code
Expand Down Expand Up @@ -514,6 +517,7 @@ def _should_ignore(self, mod: ModuleType):
or is_private_module(mod)
or is_pseudo_module(mod)
or is_builtin_module(mod)
or mod in self._modules
)

def add_requirement(self, obj_or_module):
Expand All @@ -533,6 +537,11 @@ def add_requirement(self, obj_or_module):
module = obj_or_module

if module is not None and not self._should_ignore(module):
base_module = get_base_module(module)
if is_installable_module(base_module):
if base_module in self._modules:
return
module = base_module
self._modules.add(module)
if is_local_module(module):
# add imports of this module
Expand All @@ -553,6 +562,8 @@ def save(self, obj, save_persistent_id=True):
if id(obj) in self.seen or isinstance(obj, IGNORE_TYPES_REQ):
return None
self.seen.add(id(obj))
if get_object_base_module(obj) in self._modules:
return None
self.add_requirement(obj)
try:
return super().save(obj, save_persistent_id)
Expand Down
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ ignore =
E266, # Too many leading '#' for block comment
W503, # Line break occurred before a binary operator
B008, # Do not perform function calls in argument defaults: conflicts with typer
P1, # unindexed parameters in the str.format, see:
P1, # unindexed parameters in the str.format, see:
B902, # Invalid first argument 'cls' used for instance method.
B024, # ABCs without methods
# https://pypi.org/project/flake8-string-format/
max_line_length = 79
max-complexity = 15
Expand Down
2 changes: 1 addition & 1 deletion tests/contrib/test_catboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_catboost_model(catboost_model_fixture, pandas_data, tmpdir, request):
),
)

expected_requirements = {"catboost", "pandas", "numpy", "scipy"}
expected_requirements = {"catboost", "pandas"}
reqs = set(cbmw.get_requirements().modules)
assert all(r in reqs for r in expected_requirements)
assert cbmw.model is catboost_model
Expand Down
2 changes: 1 addition & 1 deletion tests/contrib/test_lightgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def test_model__predict_not_dataset(model):
@long
def test_model__dump_load(tmpdir, model, data_np, local_fs):
# pandas is not required, but if it is installed, it is imported by lightgbm
expected_requirements = {"lightgbm", "numpy", "scipy", "pandas"}
expected_requirements = {"lightgbm", "numpy"}
assert set(model.get_requirements().modules) == expected_requirements

artifacts = model.dump(LOCAL_STORAGE, tmpdir)
Expand Down
2 changes: 1 addition & 1 deletion tests/contrib/test_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ def f(x):

sig = Signature.from_method(f, auto_infer=True, x=data)

assert set(get_object_requirements(sig).modules) == {"pandas", "numpy"}
assert set(get_object_requirements(sig).modules) == {"pandas"}


# Copyright 2019 Zyfra
Expand Down
11 changes: 8 additions & 3 deletions tests/contrib/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def test_model_type__dump_load(tmpdir, model, inp_data, request):
def test_model_type_lgb__dump_load(tmpdir, lgbm_model, inp_data):
model_type = ModelAnalyzer.analyze(lgbm_model, sample_data=inp_data)

expected_requirements = {"sklearn", "lightgbm", "pandas", "numpy", "scipy"}
expected_requirements = {"sklearn", "lightgbm", "numpy"}
reqs = model_type.get_requirements().expanded
assert set(reqs.modules) == expected_requirements
assert reqs.of_type(UnixPackageRequirement) == [
Expand All @@ -164,11 +164,16 @@ def test_model_type_lgb__dump_load(tmpdir, lgbm_model, inp_data):
]


def test_pipeline_requirements(lgbm_model):
def test_pipeline_requirements(lgbm_model, inp_data):
model = Pipeline(steps=[("model", lgbm_model)])
meta = MlemModel.from_obj(model)

expected_requirements = {"sklearn", "lightgbm", "pandas", "numpy", "scipy"}
expected_requirements = {"sklearn", "lightgbm"}
assert set(meta.requirements.modules) == expected_requirements

meta = MlemModel.from_obj(model, sample_data=np.array(inp_data))

expected_requirements = {"sklearn", "lightgbm", "numpy"}
assert set(meta.requirements.modules) == expected_requirements


Expand Down
3 changes: 1 addition & 2 deletions tests/contrib/test_xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,7 @@ def test_model__predict_not_dmatrix(model):

@long
def test_model__dump_load(tmpdir, model, dmatrix_np, local_fs):
# pandas is not required, but it is conditionally imported by some Booster methods
expected_requirements = {"xgboost", "numpy", "scipy", "pandas"}
expected_requirements = {"xgboost", "numpy"}
assert set(model.get_requirements().modules) == expected_requirements

artifacts = model.dump(LOCAL_STORAGE, tmpdir)
Expand Down