Skip to content

Commit

Permalink
[dask] reduce test times (#3786)
Browse files Browse the repository at this point in the history
* speed up tests

* [dask] reduce test times
  • Loading branch information
jameslamb authored Jan 18, 2021
1 parent d2c5545 commit c871496
Showing 1 changed file with 48 additions and 28 deletions.
76 changes: 48 additions & 28 deletions tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,22 +72,29 @@ def _create_data(objective, n_samples=100, centers=2, output='array', chunk_size
def test_classifier(output, centers, client, listen_port):
    """Distributed DaskLGBMClassifier should match a locally-trained LGBMClassifier.

    Fits a DaskLGBMClassifier on Dask collections and an LGBMClassifier on the
    equivalent in-memory data with identical hyperparameters, then compares
    accuracy scores, hard predictions, and predicted probabilities.

    Parameters
    ----------
    output : str
        Collection type produced by ``_create_data`` ('array', 'dataframe', ...).
    centers : int
        Number of blob centers, i.e. number of classes in the synthetic data.
    client, listen_port : pytest fixtures
        Dask client and a free port for LightGBM's distributed training.
    """
    X, y, w, dX, dy, dw = _create_data('classification', output=output, centers=centers)

    # Small model (few trees, few leaves) keeps distributed training fast in CI.
    dask_classifier = dlgbm.DaskLGBMClassifier(
        time_out=5,
        local_listen_port=listen_port,
        n_estimators=10,
        num_leaves=10
    )
    dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client)
    p1 = dask_classifier.predict(dX)
    p1_proba = dask_classifier.predict_proba(dX).compute()
    # accuracy is computed on the lazy collections; materialize p1 afterwards
    s1 = accuracy_score(dy, p1)
    p1 = p1.compute()

    # The local reference model must use the same hyperparameters for a fair comparison.
    local_classifier = lightgbm.LGBMClassifier(n_estimators=10, num_leaves=10)
    local_classifier.fit(X, y, sample_weight=w)
    p2 = local_classifier.predict(X)
    p2_proba = local_classifier.predict_proba(X)
    s2 = local_classifier.score(X, y)

    assert_eq(s1, s2)
    assert_eq(p1, p2)
    assert_eq(y, p1)
    assert_eq(y, p2)
    # probabilities from distributed vs. local training only agree loosely
    assert_eq(p1_proba, p2_proba, atol=0.3)


def test_training_does_not_fail_on_port_conflicts(client):
Expand All @@ -98,7 +105,9 @@ def test_training_does_not_fail_on_port_conflicts(client):

dask_classifier = dlgbm.DaskLGBMClassifier(
time_out=5,
local_listen_port=12400
local_listen_port=12400,
n_estimators=5,
num_leaves=5
)
for i in range(5):
dask_classifier.fit(
Expand All @@ -110,31 +119,19 @@ def test_training_does_not_fail_on_port_conflicts(client):
assert dask_classifier.booster_


@pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('centers', data_centers)
def test_classifier_proba(output, centers, client, listen_port):
    """predict_proba() from the Dask classifier should roughly agree with a local fit.

    Trains a distributed classifier and a plain LGBMClassifier on equivalent
    data and compares the class-probability outputs within a loose tolerance.
    """
    X, y, w, dX, dy, dw = _create_data('classification', output=output, centers=centers)

    distributed_model = dlgbm.DaskLGBMClassifier(time_out=5, local_listen_port=listen_port)
    distributed_model = distributed_model.fit(dX, dy, sample_weight=dw, client=client)
    proba_dask = distributed_model.predict_proba(dX)
    proba_dask = proba_dask.compute()

    reference_model = lightgbm.LGBMClassifier()
    reference_model.fit(X, y, sample_weight=w)
    proba_local = reference_model.predict_proba(X)

    # loose tolerance: distributed and local training are not bit-identical
    assert_eq(proba_dask, proba_local, atol=0.3)


def test_classifier_local_predict(client, listen_port):
X, y, w, dX, dy, dw = _create_data('classification', output='array')

dask_classifier = dlgbm.DaskLGBMClassifier(time_out=5, local_listen_port=listen_port)
dask_classifier = dlgbm.DaskLGBMClassifier(
time_out=5,
local_listen_port=listen_port,
n_estimators=10,
num_leaves=10
)
dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client)
p1 = dask_classifier.to_local().predict(dX)

local_classifier = lightgbm.LGBMClassifier()
local_classifier = lightgbm.LGBMClassifier(n_estimators=10, num_leaves=10)
local_classifier.fit(X, y, sample_weight=w)
p2 = local_classifier.predict(X)

Expand All @@ -147,14 +144,19 @@ def test_classifier_local_predict(client, listen_port):
def test_regressor(output, client, listen_port):
X, y, w, dX, dy, dw = _create_data('regression', output=output)

dask_regressor = dlgbm.DaskLGBMRegressor(time_out=5, local_listen_port=listen_port, seed=42)
dask_regressor = dlgbm.DaskLGBMRegressor(
time_out=5,
local_listen_port=listen_port,
seed=42,
num_leaves=10
)
dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw)
p1 = dask_regressor.predict(dX)
if output != 'dataframe':
s1 = r2_score(dy, p1)
p1 = p1.compute()

local_regressor = lightgbm.LGBMRegressor(seed=42)
local_regressor = lightgbm.LGBMRegressor(seed=42, num_leaves=10)
local_regressor.fit(X, y, sample_weight=w)
s2 = local_regressor.score(X, y)
p2 = local_regressor.predict(X)
Expand All @@ -173,12 +175,25 @@ def test_regressor(output, client, listen_port):
def test_regressor_quantile(output, client, listen_port, alpha):
X, y, w, dX, dy, dw = _create_data('regression', output=output)

dask_regressor = dlgbm.DaskLGBMRegressor(local_listen_port=listen_port, seed=42, objective='quantile', alpha=alpha)
dask_regressor = dlgbm.DaskLGBMRegressor(
local_listen_port=listen_port,
seed=42,
objective='quantile',
alpha=alpha,
n_estimators=10,
num_leaves=10
)
dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw)
p1 = dask_regressor.predict(dX).compute()
q1 = np.count_nonzero(y < p1) / y.shape[0]

local_regressor = lightgbm.LGBMRegressor(seed=42, objective='quantile', alpha=alpha)
local_regressor = lightgbm.LGBMRegressor(
seed=42,
objective='quantile',
alpha=alpha,
    n_estimators=10,
num_leaves=10
)
local_regressor.fit(X, y, sample_weight=w)
p2 = local_regressor.predict(X)
q2 = np.count_nonzero(y < p2) / y.shape[0]
Expand All @@ -191,7 +206,12 @@ def test_regressor_quantile(output, client, listen_port, alpha):
def test_regressor_local_predict(client, listen_port):
X, y, w, dX, dy, dw = _create_data('regression', output='array')

dask_regressor = dlgbm.DaskLGBMRegressor(local_listen_port=listen_port, seed=42)
dask_regressor = dlgbm.DaskLGBMRegressor(
local_listen_port=listen_port,
seed=42,
n_estimators=10,
num_leaves=10
)
dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw, client=client)
p1 = dask_regressor.predict(dX)
p2 = dask_regressor.to_local().predict(X)
Expand Down

0 comments on commit c871496

Please sign in to comment.