diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 39a4d8e9da57..0e3219208a1a 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -126,6 +126,21 @@ def is_numpy_1d_array(data):
     return isinstance(data, np.ndarray) and len(data.shape) == 1
 
 
+def is_numpy_column_array(data):
+    """Check whether data is a column numpy array."""
+    if not isinstance(data, np.ndarray):
+        return False
+    shape = data.shape
+    return len(shape) == 2 and shape[1] == 1
+
+
+def cast_numpy_1d_array_to_dtype(array, dtype):
+    """Cast numpy 1d array to given dtype."""
+    if array.dtype == dtype:
+        return array
+    return array.astype(dtype=dtype, copy=False)
+
+
 def is_1d_list(data):
     """Check whether data is a 1-D list."""
     return isinstance(data, list) and (not data or is_numeric(data[0]))
@@ -134,10 +149,11 @@ def is_1d_list(data):
 def list_to_1d_numpy(data, dtype=np.float32, name='list'):
     """Convert data to numpy 1-D array."""
     if is_numpy_1d_array(data):
-        if data.dtype == dtype:
-            return data
-        else:
-            return data.astype(dtype=dtype, copy=False)
+        return cast_numpy_1d_array_to_dtype(data, dtype)
+    elif is_numpy_column_array(data):
+        _log_warning('Converting column-vector to 1d array')
+        array = data.ravel()
+        return cast_numpy_1d_array_to_dtype(array, dtype)
     elif is_1d_list(data):
         return np.array(data, dtype=dtype, copy=False)
     elif isinstance(data, pd_Series):
diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py
index 2fb17758a599..92c75d8879ee 100644
--- a/tests/python_package_test/test_basic.py
+++ b/tests/python_package_test/test_basic.py
@@ -8,6 +8,7 @@
 from sklearn.model_selection import train_test_split
 
 import lightgbm as lgb
+from lightgbm.compat import PANDAS_INSTALLED, pd_Series
 
 from .utils import load_breast_cancer
 
@@ -375,3 +376,33 @@ def test_choose_param_value():
         "num_trees": 81
     }
     assert original_params == expected_params
+
+
+@pytest.mark.skipif(not PANDAS_INSTALLED, reason='pandas is not installed')
+@pytest.mark.parametrize(
+    'y',
+    [
+        np.random.rand(10),
+        np.random.rand(10, 1),
+        pd_Series(np.random.rand(10)),
+        pd_Series(['a', 'b']),
+        [1] * 10,
+        [[1], [2]]
+    ])
+@pytest.mark.parametrize('dtype', [np.float32, np.float64])
+def test_list_to_1d_numpy(y, dtype):
+    if isinstance(y, np.ndarray) and len(y.shape) == 2:
+        with pytest.warns(UserWarning, match='column-vector'):
+            lgb.basic.list_to_1d_numpy(y)
+            return
+    elif isinstance(y, list) and isinstance(y[0], list):
+        with pytest.raises(TypeError):
+            lgb.basic.list_to_1d_numpy(y)
+            return
+    elif isinstance(y, pd_Series) and y.dtype == object:
+        with pytest.raises(ValueError):
+            lgb.basic.list_to_1d_numpy(y)
+            return
+    result = lgb.basic.list_to_1d_numpy(y, dtype=dtype)
+    assert result.size == 10
+    assert result.dtype == dtype
diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py
index fcb9a6cb1966..67b564d71ed5 100644
--- a/tests/python_package_test/test_dask.py
+++ b/tests/python_package_test/test_dask.py
@@ -1198,6 +1198,43 @@ def test_dask_methods_and_sklearn_equivalents_have_similar_signatures(methods):
         assert dask_params[param].default == sklearn_params[param].default, error_msg
 
 
+@pytest.mark.parametrize('task', tasks)
+def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(
+    task,
+    client,
+):
+    if task == 'ranking':
+        _, _, _, _, dX, dy, dw, dg = _create_ranking_data(
+            output='dataframe',
+            group=None
+        )
+        model_factory = lgb.DaskLGBMRanker
+    else:
+        _, _, _, dX, dy, dw = _create_data(
+            objective=task,
+            output='dataframe',
+        )
+        dg = None
+        if task == 'classification':
+            model_factory = lgb.DaskLGBMClassifier
+        elif task == 'regression':
+            model_factory = lgb.DaskLGBMRegressor
+    dy = dy.to_dask_array(lengths=True)
+    dy_col_array = dy.reshape(-1, 1)
+    assert len(dy_col_array.shape) == 2 and dy_col_array.shape[1] == 1
+
+    params = {
+        'n_estimators': 1,
+        'num_leaves': 3,
+        'random_state': 0,
+        'time_out': 5
+    }
+    model = model_factory(**params)
+    model.fit(dX, dy_col_array, sample_weight=dw, group=dg)
+    assert model.fitted_
+    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
+
+
 def sklearn_checks_to_run():
     check_names = [
         "check_estimator_get_tags_default_keys",
diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py
index 187ae34b2eaa..042b88e9a233 100644
--- a/tests/python_package_test/test_sklearn.py
+++ b/tests/python_package_test/test_sklearn.py
@@ -18,7 +18,7 @@
 
 import lightgbm as lgb
 
-from .utils import load_boston, load_breast_cancer, load_digits, load_iris, load_linnerud
+from .utils import load_boston, load_breast_cancer, load_digits, load_iris, load_linnerud, make_ranking
 
 sk_version = parse_version(sk_version)
 if sk_version < parse_version("0.23"):
@@ -1192,3 +1192,36 @@ def test_parameters_default_constructible(estimator):
     name, Estimator = estimator.__class__.__name__, estimator.__class__
     # Test that estimators are default-constructible
     check_parameters_default_constructible(name, Estimator)
+
+
+@pytest.mark.parametrize('task', ['classification', 'ranking', 'regression'])
+def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(task):
+    pd = pytest.importorskip("pandas")
+    if task == 'ranking':
+        X, y, g = make_ranking()
+        g = np.bincount(g)
+        model_factory = lgb.LGBMRanker
+    elif task == 'classification':
+        X, y = load_iris(return_X_y=True)
+        model_factory = lgb.LGBMClassifier
+    elif task == 'regression':
+        X, y = load_boston(return_X_y=True)
+        model_factory = lgb.LGBMRegressor
+    X = pd.DataFrame(X)
+    y_col_array = y.reshape(-1, 1)
+    params = {
+        'n_estimators': 1,
+        'num_leaves': 3,
+        'random_state': 0
+    }
+    with pytest.warns(UserWarning, match='column-vector'):
+        if task == 'ranking':
+            model_1d = model_factory(**params).fit(X, y, group=g)
+            model_2d = model_factory(**params).fit(X, y_col_array, group=g)
+        else:
+            model_1d = model_factory(**params).fit(X, y)
+            model_2d = model_factory(**params).fit(X, y_col_array)
+
+    preds_1d = model_1d.predict(X)
+    preds_2d = model_2d.predict(X)
+    np.testing.assert_array_equal(preds_1d, preds_2d)
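Usage sketch (not part of the patch): a minimal illustration of the behaviour this diff enables, assuming a LightGBM build that includes it. Before the change, list_to_1d_numpy rejected a (n, 1) numpy array; with the change it ravels it to shape (n,) and emits a UserWarning, so the sklearn-style estimators also accept column-vector labels. list_to_1d_numpy is an internal helper, so its name and location may differ in other versions.

# Hedged example: demonstrates the raveling of a column-vector label
# added by this diff. list_to_1d_numpy is internal; the warning text
# 'Converting column-vector to 1d array' comes from the patch above.
import numpy as np
import lightgbm as lgb

X = np.random.rand(100, 4)
y_col = np.random.rand(100, 1)  # column-vector label, shape (100, 1)

# With the patch, the column vector is flattened and a UserWarning is raised.
flat = lgb.basic.list_to_1d_numpy(y_col, dtype=np.float32)
assert flat.shape == (100,)

# The sklearn wrapper therefore accepts a column-vector label as well.
model = lgb.LGBMRegressor(n_estimators=1, num_leaves=3).fit(X, y_col)
print(model.predict(X[:5]))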