Skip to content

Commit

Permalink
expand random range
Browse files Browse the repository at this point in the history
  • Loading branch information
danielenricocahall committed Jan 19, 2021
1 parent 6bd56ce commit 648b4a6
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions tests/integration/test_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_classification_prediction(spark_context, mode, mnist_data, classificati
train_rdd = to_simple_rdd(spark_context, x_train, y_train)

# Initialize SparkModel from keras model and Spark context
spark_model = SparkModel(classification_model, frequency='epoch', mode=mode, port=4000 + random.randint(0, 100))
spark_model = SparkModel(classification_model, frequency='epoch', mode=mode, port=4000 + random.randint(0, 200))

# Train Spark model
spark_model.fit(train_rdd, epochs=epochs, batch_size=batch_size,
Expand Down Expand Up @@ -59,7 +59,7 @@ def test_regression_prediction(spark_context, mode, boston_housing_dataset, regr
sgd = SGD(lr=0.000001)
regression_model.compile(sgd, 'mse', ['mae'])
# Initialize SparkModel from keras model and Spark context
spark_model = SparkModel(regression_model, frequency='epoch', mode=mode, port=4000 + random.randint(0, 100))
spark_model = SparkModel(regression_model, frequency='epoch', mode=mode, port=4000 + random.randint(0, 200))

# Train Spark model
spark_model.fit(train_rdd, epochs=epochs, batch_size=batch_size,
Expand Down Expand Up @@ -96,7 +96,7 @@ def test_evaluate(spark_context, mode, mnist_data, classification_model):
train_rdd = to_simple_rdd(spark_context, x_train, y_train)

# Initialize SparkModel from keras model and Spark context
spark_model = SparkModel(classification_model, frequency='epoch', mode=mode, port=4000 + random.randint(0, 100))
spark_model = SparkModel(classification_model, frequency='epoch', mode=mode, port=4000 + random.randint(0, 200))

# Train Spark model
spark_model.fit(train_rdd, epochs=epochs, batch_size=batch_size,
Expand Down
6 changes: 3 additions & 3 deletions tests/integration/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_training_modes_classification(spark_context, mode, mnist_data, classifi
rdd = to_simple_rdd(spark_context, x_train, y_train)

# Initialize SparkModel from keras model and Spark context
spark_model = SparkModel(classification_model, frequency='epoch', mode=mode, port=4000 + random.randint(0, 100))
spark_model = SparkModel(classification_model, frequency='epoch', mode=mode, port=4000 + random.randint(0, 200))

# Train Spark model
spark_model.fit(rdd, epochs=epochs, batch_size=batch_size,
Expand All @@ -47,7 +47,7 @@ def test_training_modes_regression(spark_context, mode, boston_housing_dataset,
epochs = 10
sgd = SGD(lr=0.0000001)
regression_model.compile(sgd, 'mse', ['mae'])
spark_model = SparkModel(regression_model, frequency='epoch', mode=mode, port=4000 + random.randint(0, 100))
spark_model = SparkModel(regression_model, frequency='epoch', mode=mode, port=4000 + random.randint(0, 200))

# Train Spark model
spark_model.fit(rdd, epochs=epochs, batch_size=batch_size,
Expand All @@ -74,7 +74,7 @@ def test_training_asynchronous_socket(spark_context, mode, mnist_data, classific

# Initialize SparkModel from keras model and Spark context
spark_model = SparkModel(classification_model, frequency='epoch',
mode=mode, parameter_server_mode='socket', port=4000 + random.randint(0, 100))
mode=mode, parameter_server_mode='socket', port=4000 + random.randint(0, 200))

# Train Spark model
spark_model.fit(rdd, epochs=epochs, batch_size=batch_size,
Expand Down

0 comments on commit 648b4a6

Please sign in to comment.