refactor tests and fix pickle imports
danielenricocahall committed Jan 24, 2021
1 parent 1ad3867 commit ca3948b
Showing 5 changed files with 97 additions and 194 deletions.
2 changes: 1 addition & 1 deletion elephas/utils/sockets.py
@@ -1,4 +1,4 @@
-from six.moves import cPickle as pickle
+import pickle
 from socket import gethostbyname, gethostname
 import os

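For context on this change: in Python 3 the standard-library pickle module transparently uses the C-accelerated _pickle implementation when it is available, so the six.moves.cPickle shim is redundant once Python 2 support is dropped. A minimal sketch of the equivalence (the payload below is an illustrative assumption, not an elephas internal):

import pickle

# stand-in for the kind of object elephas serializes over sockets
payload = {"weights": [0.1, 0.2, 0.3]}

blob = pickle.dumps(payload)  # uses the C accelerator under the hood when present
assert pickle.loads(blob) == payload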
3 changes: 2 additions & 1 deletion examples/hyperparam_optimization.py
@@ -1,8 +1,9 @@
+import pickle
+
 from pyspark import SparkContext, SparkConf

 from hyperopt import STATUS_OK
 from hyperas.distributions import choice, uniform
-import six.moves.cPickle as pickle

 from elephas.hyperparam import HyperParamModel

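As orientation for the example file touched above: elephas's HyperParamModel distributes a hyperopt/hyperas-style search over Spark. The sketch below is hedged — the data() and model() bodies are placeholder assumptions (the real example marks hyperparameters with hyperas {{choice(...)}}/{{uniform(...)}} templates inside model()), and only the minimize(model, data, max_evals) call pattern is taken from the elephas examples.

from pyspark import SparkContext, SparkConf
from hyperopt import STATUS_OK
from elephas.hyperparam import HyperParamModel


def data():
    # placeholder data function: returns train/test splits as numpy arrays
    import numpy as np
    x = np.random.rand(100, 8).astype('float32')
    y = np.random.rand(100, 1).astype('float32')
    return x[:80], y[:80], x[80:], y[80:]


def model(x_train, y_train, x_test, y_test):
    # placeholder model function: builds, trains, and scores a small keras model,
    # returning the hyperopt-style result dict the search expects
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Dense
    m = Sequential([Dense(8, activation='relu', input_shape=(8,)), Dense(1)])
    m.compile(optimizer='adam', loss='mse')
    m.fit(x_train, y_train, epochs=1, verbose=0)
    return {'loss': m.evaluate(x_test, y_test, verbose=0), 'status': STATUS_OK, 'model': m.to_yaml()}


sc = SparkContext(conf=SparkConf().setAppName('hyperparam-sketch').setMaster('local[4]'))
hyperparam_model = HyperParamModel(sc)
hyperparam_model.minimize(model=model, data=data, max_evals=5)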
94 changes: 94 additions & 0 deletions tests/integration/test_end_to_end.py
@@ -0,0 +1,94 @@
import random
from math import isclose

from tensorflow.keras.optimizers import SGD

from elephas.spark_model import SparkModel
from elephas.utils.rdd_utils import to_simple_rdd

import pytest
import numpy as np


# enumerate possible combinations for training mode and parameter server for a classification model
@pytest.mark.parametrize('mode,parameter_server_mode', [('synchronous', None),
('asynchronous', 'http'),
('asynchronous', 'socket'),
('hogwild', 'http'),
('hogwild', 'socket')])
def test_training_classification(spark_context, mode, parameter_server_mode, mnist_data, classification_model):
# Define basic parameters
batch_size = 64
epochs = 10

# Load data
x_train, y_train, x_test, y_test = mnist_data
x_train = x_train[:1000]
y_train = y_train[:1000]

sgd = SGD(lr=0.1)
classification_model.compile(sgd, 'categorical_crossentropy', ['acc'])

# Build RDD from numpy features and labels
rdd = to_simple_rdd(spark_context, x_train, y_train)

# Initialize SparkModel from keras model and Spark context
spark_model = SparkModel(classification_model, frequency='epoch',
mode=mode, parameter_server_mode=parameter_server_mode, port=4000 + random.randint(0, 500))

# Train Spark model
spark_model.fit(rdd, epochs=epochs, batch_size=batch_size,
verbose=0, validation_split=0.1)

# run inference on trained spark model
predictions = spark_model.predict(x_test)
# run evaluation on trained spark model
evals = spark_model.evaluate(x_test, y_test)

    # assert predictions from an rdd input match predictions from a numpy array input
test_rdd = spark_context.parallelize(x_test)
assert [np.argmax(x) for x in predictions] == [np.argmax(x) for x in spark_model.predict(test_rdd)]

    # assert we get the same predictions when calling predict on the underlying keras model directly
assert [np.argmax(x) for x in predictions] == [np.argmax(x) for x in spark_model.master_network.predict(x_test)]

    # assert we get the same evaluation results when calling evaluate on the underlying keras model directly
assert isclose(evals, spark_model.master_network.evaluate(x_test, y_test)[0], abs_tol=0.01)


# enumerate possible combinations for training mode and parameter server for a regression model
@pytest.mark.parametrize('mode,parameter_server_mode', [('synchronous', None),
('asynchronous', 'http'),
('asynchronous', 'socket'),
('hogwild', 'http'),
('hogwild', 'socket')])
def test_training_regression(spark_context, mode, parameter_server_mode, boston_housing_dataset, regression_model):
x_train, y_train, x_test, y_test = boston_housing_dataset
rdd = to_simple_rdd(spark_context, x_train, y_train)

# Define basic parameters
batch_size = 64
epochs = 10
sgd = SGD(lr=0.0000001)
regression_model.compile(sgd, 'mse', ['mae'])
spark_model = SparkModel(regression_model, frequency='epoch', mode=mode,
parameter_server_mode=parameter_server_mode, port=4000 + random.randint(0, 500))

# Train Spark model
spark_model.fit(rdd, epochs=epochs, batch_size=batch_size,
verbose=0, validation_split=0.1)

# run inference on trained spark model
predictions = spark_model.predict(x_test)
# run evaluation on trained spark model
evals = spark_model.evaluate(x_test, y_test)

    # assert predictions from an rdd input match predictions from a numpy array input
test_rdd = spark_context.parallelize(x_test)
assert all(np.isclose(x, y, 0.01) for x, y in zip(predictions, spark_model.predict(test_rdd)))

    # assert we get the same predictions when calling predict on the underlying keras model directly
assert all(np.isclose(x, y, 0.01) for x, y in zip(predictions, spark_model.master_network.predict(x_test)))

    # assert we get the same evaluation results when calling evaluate on the underlying keras model directly
assert isclose(evals, spark_model.master_network.evaluate(x_test, y_test)[0], abs_tol=0.01)
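The tests above rely on pytest fixtures — spark_context, mnist_data, boston_housing_dataset, classification_model, and regression_model — defined elsewhere in the suite, presumably in a conftest.py. A minimal sketch of what such fixtures might look like; every body below is an assumption for illustration, not the repository's actual conftest:

import pytest
from pyspark import SparkContext, SparkConf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.datasets import mnist, boston_housing
from tensorflow.keras.utils import to_categorical


@pytest.fixture(scope='session')
def spark_context():
    # one local Spark context shared across the test session
    sc = SparkContext(conf=SparkConf().setMaster('local[*]').setAppName('elephas-tests'))
    yield sc
    sc.stop()


@pytest.fixture
def mnist_data():
    # flattened, normalized MNIST with one-hot labels
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = x_train.reshape(-1, 784).astype('float32') / 255
    x_test = x_test.reshape(-1, 784).astype('float32') / 255
    return x_train, to_categorical(y_train, 10), x_test, to_categorical(y_test, 10)


@pytest.fixture
def boston_housing_dataset():
    (x_train, y_train), (x_test, y_test) = boston_housing.load_data()
    return x_train, y_train, x_test, y_test


@pytest.fixture
def classification_model():
    # uncompiled on purpose: the tests compile it with their own optimizer
    return Sequential([Dense(128, activation='relu', input_shape=(784,)),
                       Dense(10, activation='softmax')])


@pytest.fixture
def regression_model():
    # 13 input features in the boston housing data
    return Sequential([Dense(64, activation='relu', input_shape=(13,)),
                       Dense(1)])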
107 changes: 0 additions & 107 deletions tests/integration/test_prediction.py

This file was deleted.

85 changes: 0 additions & 85 deletions tests/integration/test_training.py

This file was deleted.
