diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 979226b11ff5f..b6909b3386b71 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -340,8 +340,10 @@ object Word2VecModel extends MLReadable[Word2VecModel] {
       val wordVectors = instance.wordVectors.getVectors
       val dataSeq = wordVectors.toSeq.map { case (word, vector) => Data(word, vector) }
       val dataPath = new Path(path, "data").toString
+      val bufferSizeInBytes = Utils.byteStringAsBytes(
+        sc.conf.get("spark.kryoserializer.buffer.max", "64m"))
       val numPartitions = Word2VecModelWriter.calculateNumberOfPartitions(
-        sc, instance.wordVectors.wordIndex.size, instance.getVectorSize)
+        bufferSizeInBytes, instance.wordVectors.wordIndex.size, instance.getVectorSize)
       sparkSession.createDataFrame(dataSeq)
         .repartition(numPartitions)
         .write
@@ -351,16 +353,20 @@ object Word2VecModel extends MLReadable[Word2VecModel] {
   }
 
   private[feature] object Word2VecModelWriter {
+    /**
+     * Calculate the number of partitions to use in saving the model.
+     * [SPARK-11994] - We want to partition the model in partitions smaller than
+     * spark.kryoserializer.buffer.max
+     * @param bufferSizeInBytes  Set to spark.kryoserializer.buffer.max
+     * @param numWords  Vocab size
+     * @param vectorSize  Vector length for each word
+     */
     def calculateNumberOfPartitions(
-        sc: SparkContext,
+        bufferSizeInBytes: Long,
         numWords: Int,
         vectorSize: Int): Int = {
       val floatSize = 4L  // Use Long to help avoid overflow
       val averageWordSize = 15
-      // [SPARK-11994] - We want to partition the model in partitions smaller than
-      // spark.kryoserializer.buffer.max
-      val bufferSizeInBytes = Utils.byteStringAsBytes(
-        sc.conf.get("spark.kryoserializer.buffer.max", "64m"))
       // Calculate the approximate size of the model.
       // Assuming an average word size of 15 bytes, the formula is:
       // (floatSize * vectorSize + 15) * numWords
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index ea6d8d1606f04..6183606a7b2ac 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.Row
+import org.apache.spark.util.Utils
 
 class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
@@ -189,12 +190,12 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
   }
 
   test("Word2Vec read/write numPartitions calculation") {
-    val tinyModelNumPartitions = Word2VecModel.Word2VecModelWriter.calculateNumberOfPartitions(
-      sc, numWords = 10, vectorSize = 5)
-    assert(tinyModelNumPartitions === 1)
-    val mediumModelNumPartitions = Word2VecModel.Word2VecModelWriter.calculateNumberOfPartitions(
-      sc, numWords = 1000000, vectorSize = 5000)
-    assert(mediumModelNumPartitions > 1)
+    val smallModelNumPartitions = Word2VecModel.Word2VecModelWriter.calculateNumberOfPartitions(
+      Utils.byteStringAsBytes("64m"), numWords = 10, vectorSize = 5)
+    assert(smallModelNumPartitions === 1)
+    val largeModelNumPartitions = Word2VecModel.Word2VecModelWriter.calculateNumberOfPartitions(
+      Utils.byteStringAsBytes("64m"), numWords = 1000000, vectorSize = 5000)
+    assert(largeModelNumPartitions > 1)
   }
 
   test("Word2Vec read/write") {
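
For reviewers who want to sanity-check the sizing math outside Spark: below is a minimal standalone sketch of the calculation the patch factors out of SparkContext. The object name and the main harness are illustrative only, not part of the patch; the formula and the two inputs mirror the test cases above, which is why the test expects 1 partition for the small model and more than 1 (here 299) for the large one.

object PartitionSizingSketch {
  // Mirrors the patched helper: model bytes ~= (4 bytes per float * vectorSize
  // + ~15 bytes per word string) * numWords, then one partition per buffer-full.
  def numPartitions(bufferSizeInBytes: Long, numWords: Int, vectorSize: Int): Int = {
    val floatSize = 4L       // Long arithmetic avoids Int overflow for large models
    val averageWordSize = 15
    val approximateSizeInBytes = (floatSize * vectorSize + averageWordSize) * numWords
    ((approximateSizeInBytes / bufferSizeInBytes) + 1).toInt
  }

  def main(args: Array[String]): Unit = {
    val buffer = 64L * 1024 * 1024  // "64m", the spark.kryoserializer.buffer.max default
    // (4 * 5 + 15) * 10 = 350 bytes -> fits in one buffer -> 1 partition
    println(numPartitions(buffer, numWords = 10, vectorSize = 5))
    // (4 * 5000 + 15) * 1000000 ~= 20 GB -> 20015000000 / 67108864 + 1 = 299 partitions
    println(numPartitions(buffer, numWords = 1000000, vectorSize = 5000))
  }
}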